@@ -175,4 +175,32 @@ def forward(self, x):
175175 x = x .permute (0 , 2 , 1 ) # Change shape to (batch_size, in_features, seq_len)
176176 x = self .conv (x ) # Apply convolution
177177 x = x .permute (0 , 2 , 1 ) # Change shape back to (batch_size, seq_len, in_features)
178- return x
178+ return x
179+
180+
181+ # Temporal Positional Encoding (T-PE)
182+ class TemporalPositionalEncoding (nn .Module ):
183+ def __init__ (self , d_model , max_len = 896 ): # Assuming 896 timesteps
184+ super (TemporalPositionalEncoding , self ).__init__ ()
185+ pe = torch .zeros (max_len , d_model )
186+ position = torch .arange (0 , max_len , dtype = torch .float ).unsqueeze (1 )
187+ div_term = torch .exp (torch .arange (0 , d_model , 2 ).float () * (- math .log (10000.0 ) / d_model ))
188+
189+ pe [:, 0 ::2 ] = torch .sin (position * div_term )
190+ pe [:, 1 ::2 ] = torch .cos (position * div_term )
191+ self .register_buffer ('pe' , pe )
192+
193+ def forward (self , x ):
194+ seq_len = x .size (1 )
195+ return self .pe [:seq_len , :].unsqueeze (0 ).expand (x .size (0 ), - 1 , - 1 )
196+
197+
198+ # Variable Positional Encoding for handling multivariate data
199+ class VariablePositionalEncoding (nn .Module ):
200+ def __init__ (self , d_model , num_variables ):
201+ super (VariablePositionalEncoding , self ).__init__ ()
202+ self .variable_embedding = nn .Embedding (num_variables , d_model )
203+
204+ def forward (self , x , variable_idx ):
205+ variable_embed = self .variable_embedding (variable_idx )
206+ return x + variable_embed .unsqueeze (0 )
0 commit comments