@@ -183,7 +183,7 @@ class TimeXer(BaseModel):
183183 """
184184
185185 # Class attributes
186- EXOGENOUS_FUTR = True
186+ EXOGENOUS_FUTR = False
187187 EXOGENOUS_HIST = True
188188 EXOGENOUS_STAT = True
189189 MULTIVARIATE = True # If the model produces multivariate forecasts (True) or univariate (False)
@@ -370,41 +370,32 @@ def forecast(self, x_enc, x_mark_enc):
370370 def forward (self , windows_batch ):
371371 insample_y = windows_batch ["insample_y" ] # [B, L, N]
372372 hist_exog = windows_batch ["hist_exog" ] # [B, X, L, N]
373- futr_exog = windows_batch ["futr_exog" ] # [B, F, L+h, N]
374373 stat_exog = windows_batch ["stat_exog" ] # [N, S]
375-
376- B , L , N = insample_y .shape
377-
374+
375+ B , L , _ = insample_y .shape
376+
378377 # Build exogenous features for the input sequence
379378 exog_list = []
380-
379+
381380 # Add historical exogenous
382381 if self .hist_exog_size > 0 :
383382 # hist_exog: [B, X, L, N] -> [B, L, X, N] -> [B, L, X*N]
384383 hist_exog_input = hist_exog .permute (0 , 2 , 1 , 3 ) # [B, L, X, N]
385384 hist_exog_input = hist_exog_input .reshape (B , L , - 1 ) # [B, L, X*N]
386385 exog_list .append (hist_exog_input )
387-
388- # Add future exogenous
389- if self .futr_exog_size > 0 :
390- # Take only the input sequence part of future exogenous
391- futr_exog_input = futr_exog [:, :, :L , :] # [B, F, L, N]
392- futr_exog_input = futr_exog_input .permute (0 , 2 , 1 , 3 ) # [B, L, F, N]
393- futr_exog_input = futr_exog_input .reshape (B , L , - 1 ) # [B, L, F*N]
394- exog_list .append (futr_exog_input )
395-
386+
396387 # Combine all exogenous features
397388 if len (exog_list ) > 0 :
398- x_mark_enc = torch .cat (exog_list , dim = - 1 ) # [B, L, (X+F) *N]
389+ x_mark_enc = torch .cat (exog_list , dim = - 1 ) # [B, L, X *N]
399390 else :
400391 x_mark_enc = None
401-
392+
402393 # Add static exogenous to the cross-attention context
403394 if self .stat_exog_size > 0 and x_mark_enc is not None :
404395 # stat_exog: [N, S] -> [B, L, N*S]
405396 stat_exog_expanded = stat_exog .reshape (- 1 ).unsqueeze (0 ).unsqueeze (0 ) # [1, 1, N*S]
406397 stat_exog_expanded = stat_exog_expanded .repeat (B , L , 1 ) # [B, L, N*S]
407- x_mark_enc = torch .cat ([x_mark_enc , stat_exog_expanded ], dim = - 1 ) # [B, L, (X+F) *N + N*S]
398+ x_mark_enc = torch .cat ([x_mark_enc , stat_exog_expanded ], dim = - 1 ) # [B, L, X *N + N*S]
408399 elif self .stat_exog_size > 0 :
409400 # Only static exogenous available
410401 stat_exog_expanded = stat_exog .reshape (- 1 ).unsqueeze (0 ).unsqueeze (0 )
0 commit comments