Commit 7498de3

Update readme
1 parent 724a8ed commit 7498de3

2 files changed: +110 -8 lines changed


README.md

Lines changed: 34 additions & 8 deletions
@@ -1,10 +1,11 @@
 <!-- TITLE -->
-# Positional Encoding Benchmark for Time Series Classification
+## Positional Encoding Benchmark for Time Series Classification
+
 
-[![arXiv](https://img.shields.io/badge/arXiv-2502.12370-b31b1b.svg)](https://arxiv.org/abs/2502.12370)
-[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
 [![Python 3.10](https://img.shields.io/badge/python-3.10-blue.svg)](https://www.python.org/downloads/release/python-3100/)
-[![PyTorch](https://img.shields.io/badge/PyTorch-2.4.1-ee4c2c.svg)](https://pytorch.org/)
+[![PyTorch](https://img.shields.io/badge/PyTorch-2.0+-ee4c2c?logo=pytorch&logoColor=white)](https://pytorch.org/)
+[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
+[![arXiv](https://img.shields.io/badge/arXiv-2502.12370-b31b1b.svg)](https://arxiv.org/abs/2502.12370)
 
 This repository provides a comprehensive evaluation framework for positional encoding methods in transformer-based time series models, along with implementations and benchmarking results.
 
@@ -70,9 +71,34 @@ python examples/run_benchmark.py
 python examples/run_benchmark.py --config path/to/custom_config.yaml
 ```
 
+## Usage
+```python
+from encodings.positional_encodings import PE_Name
+from models.transformer import TimeSeriesTransformer
+
+# Use in transformer
+model = TimeSeriesTransformer(
+    input_timesteps=SEQ_LENGTH,         # Sequence length
+    in_channels=INPUT_CHANNELS,         # Number of input channels
+    patch_size=PATCH_SIZE,              # Patch size for embedding
+    embedding_dim=EMBED_DIM,            # Embedding dimension
+    num_transformer_layers=NUM_LAYERS,  # Number of transformer layers (4, 8, etc.)
+    num_heads=N_HEADS,                  # Number of attention heads
+    num_layers=NUM_LAYERS,              # Number of transformer layers
+    dim_feedforward=DIM_FF,             # Feedforward dimension
+    dropout=DROPOUT,                    # Dropout rate (0.1, 0.2, etc.)
+    num_classes=NUM_CLASSES,            # Number of output classes
+    pos_encoding='PE_Name',             # Positional encoding type
+)
+
+# Forward pass
+x = torch.randn(BATCH_SIZE, SEQ_LENGTH, INPUT_CHANNELS)  # (batch, sequence, features)
+output = model(x)
+```
+
 ## Results
 
-Our experimental evaluation encompasses eight distinct positional encoding methods tested across eleven diverse time series datasets using two transformer architectures.
+Our experimental evaluation encompasses ten distinct positional encoding methods tested across eleven diverse time series datasets using two transformer architectures.
 
 ### Key Findings
 
@@ -86,9 +112,9 @@ Our experimental evaluation encompasses eight distinct positional encoding metho
 - **Patch Embedding**: More balanced performance among top methods
 
 #### 🏆 Average Rankings
-- **SPE**: 1.727 (batch norm), 2.090 (patch embed)
-- **TUPE**: 1.909 (batch norm), 2.272 (patch embed)
-- **T-PE**: 2.636 (batch norm), 2.363 (patch embed)
+- **SPE**: 1.727 (TST), 2.090 (patch embed)
+- **TUPE**: 1.909 (TST), 2.272 (patch embed)
+- **T-PE**: 2.636 (TST), 2.363 (patch embed)
 
 ### Performance Analysis
 

src/encodings/positional_encodings.py

Lines changed: 76 additions & 0 deletions
@@ -45,6 +45,82 @@ def __init__(self, d_model, dropout=0.1, max_len=1024, scale_factor=1.0):
     def forward(self, x):
         x = x + self.pe[:, :x.size(1)]
         return self.dropout(x)
+
+class RotaryPositionalEncoding(nn.Module):
+    """Rotary Position Embedding (RoPE) - used in models like LLaMA"""
+    def __init__(self, d_model, dropout=0.1, max_len=5000):
+        super(RotaryPositionalEncoding, self).__init__()
+        self.dropout = nn.Dropout(p=dropout)
+        self.d_model = d_model
+
+        # Inverse frequencies, one per pair of channels
+        inv_freq = 1.0 / (10000 ** (torch.arange(0, d_model, 2).float() / d_model))
+        self.register_buffer('inv_freq', inv_freq)
+
+    def forward(self, x):
+        seq_len = x.shape[1]
+        device = x.device
+
+        # Generate position indices
+        position = torch.arange(seq_len, device=device).float()
+
+        # Rotation angles for all positions
+        freqs = torch.outer(position, self.inv_freq)
+        # freqs: (seq_len, d_model // 2), one angle per channel pair
+
+        # Cosine and sine of each angle
+        cos_freqs = freqs.cos()
+        sin_freqs = freqs.sin()
+
+        # Reshape for broadcasting over the batch dimension
+        cos_freqs = cos_freqs.unsqueeze(0)
+        sin_freqs = sin_freqs.unsqueeze(0)
+
+        # Apply rotation
+        x_rotated = self.apply_rotary_pos_emb(x, cos_freqs, sin_freqs)
+        return self.dropout(x_rotated)
+
+    def apply_rotary_pos_emb(self, x, cos, sin):
+        # Split the channels into interleaved (even, odd) pairs
+        x1, x2 = x[..., ::2], x[..., 1::2]
+
+        # Rotate each pair by its position-dependent angle
+        rotated = torch.zeros_like(x)
+        rotated[..., ::2] = x1 * cos - x2 * sin
+        rotated[..., 1::2] = x1 * sin + x2 * cos
+
+        return rotated
+
+class RelativePositionalEncoding(nn.Module):
+    """Relative Positional Encoding - focuses on relative distances between tokens"""
+    def __init__(self, d_model, dropout=0.1, max_len=5000):
+        super(RelativePositionalEncoding, self).__init__()
+        self.dropout = nn.Dropout(p=dropout)
+        self.d_model = d_model
+        self.max_len = max_len
+
+        # Learnable relative position embeddings
+        self.relative_positions = nn.Parameter(
+            torch.randn(2 * max_len - 1, d_model) * 0.02
+        )
+
+    def forward(self, x):
+        batch_size, seq_len, d_model = x.shape
+
+        # Create relative position matrix
+        positions = torch.arange(seq_len, device=x.device)
+        relative_positions = positions[:, None] - positions[None, :]
+        relative_positions += self.max_len - 1  # Shift to positive indices
+
+        # Get relative position embeddings
+        rel_pos_emb = self.relative_positions[relative_positions]
+
+        # Average the relative position embeddings for each position
+        pos_encoding = rel_pos_emb.mean(dim=1).unsqueeze(0).expand(batch_size, -1, -1)
+
+        x = x + pos_encoding
+        return self.dropout(x)
+
 
 class AbsolutePositionalEncoding(nn.Module):
     def __init__(self, d_model, dropout=0.1, max_len=1024, scale_factor=1.0):
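Below is a minimal smoke-test sketch for the two encodings added in this file; it is not part of the commit. It assumes the classes are importable via `encodings.positional_encodings`, mirroring the README's Usage snippet, and the tensor sizes are arbitrary placeholders.

```python
# Hypothetical smoke test for the newly added encodings (not part of this commit).
import torch

# Assumed import path, mirroring the README's Usage snippet.
from encodings.positional_encodings import (
    RotaryPositionalEncoding,
    RelativePositionalEncoding,
)

batch_size, seq_len, d_model = 8, 128, 64       # arbitrary example sizes
x = torch.randn(batch_size, seq_len, d_model)   # (batch, sequence, features)

rope = RotaryPositionalEncoding(d_model=d_model, dropout=0.1, max_len=5000)
rel_pe = RelativePositionalEncoding(d_model=d_model, dropout=0.1, max_len=5000)

# Both modules inject positional information and return a tensor of the input's shape.
print(rope(x).shape)    # torch.Size([8, 128, 64])
print(rel_pe(x).shape)  # torch.Size([8, 128, 64])
```

As the diff shows, RotaryPositionalEncoding rotates pairs of input channels by position-dependent angles, while RelativePositionalEncoding adds an averaged learnable relative embedding; both preserve the (batch, sequence, d_model) layout used by the benchmark's transformer models.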
