Skip to content

Commit cf89402

Browse files
authored
[FIX] Exogenous support in TimeXer (#1444)
1 parent 69e4314 commit cf89402

File tree

4 files changed

+19
-22
lines changed

4 files changed

+19
-22
lines changed

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,4 +41,6 @@ docs/mintlify/examples
4141
nbs/_extensions
4242
.quarto
4343

44-
*.png
44+
*.png
45+
46+
CLAUDE.md

docs/models.timexer.html.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ Y_test_df = AirPassengersPanel[AirPassengersPanel.ds>=AirPassengersPanel['ds'].v
3636
model = TimeXer(h=12,
3737
input_size=24,
3838
n_series=2,
39-
futr_exog_list=["trend", "month"],
39+
stat_exog_list=['airline1'],
4040
patch_len=12,
4141
hidden_size=128,
4242
n_heads=16,

nbs/docs/capabilities/overview.ipynb

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
"|`TimeMixer` | `AutoTimeMixer` | MLP | Multivariate | Direct | - | \n",
4545
"|`TimeLLM` | - | LLM | Univariate | Direct | - | \n",
4646
"|`TimesNet` | `AutoTimesNet` | CNN | Univariate | Direct | F |\n",
47-
"|`TimeXer` | `AutoTimeXer` | Transformer | Multivariate | Direct | F | \n",
47+
"|`TimeXer` | `AutoTimeXer` | Transformer | Multivariate | Direct | H/S | \n",
4848
"|`TSMixer` | `AutoTSMixer` | MLP | Multivariate | Direct | - | \n",
4949
"|`TSMixerx` | `AutoTSMixerx` | MLP | Multivariate | Direct | F/H/S | \n",
5050
"|`VanillaTransformer` | `AutoVanillaTransformer` | Transformer | Univariate | Direct | F | \n",
@@ -67,7 +67,11 @@
6767
"source": []
6868
}
6969
],
70-
"metadata": {},
70+
"metadata": {
71+
"language_info": {
72+
"name": "python"
73+
}
74+
},
7175
"nbformat": 4,
7276
"nbformat_minor": 2
7377
}

neuralforecast/models/timexer.py

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ class TimeXer(BaseModel):
183183
"""
184184

185185
# Class attributes
186-
EXOGENOUS_FUTR = True
186+
EXOGENOUS_FUTR = False
187187
EXOGENOUS_HIST = True
188188
EXOGENOUS_STAT = True
189189
MULTIVARIATE = True # If the model produces multivariate forecasts (True) or univariate (False)
@@ -370,41 +370,32 @@ def forecast(self, x_enc, x_mark_enc):
370370
def forward(self, windows_batch):
371371
insample_y = windows_batch["insample_y"] # [B, L, N]
372372
hist_exog = windows_batch["hist_exog"] # [B, X, L, N]
373-
futr_exog = windows_batch["futr_exog"] # [B, F, L+h, N]
374373
stat_exog = windows_batch["stat_exog"] # [N, S]
375-
376-
B, L, N = insample_y.shape
377-
374+
375+
B, L, _ = insample_y.shape
376+
378377
# Build exogenous features for the input sequence
379378
exog_list = []
380-
379+
381380
# Add historical exogenous
382381
if self.hist_exog_size > 0:
383382
# hist_exog: [B, X, L, N] -> [B, L, X, N] -> [B, L, X*N]
384383
hist_exog_input = hist_exog.permute(0, 2, 1, 3) # [B, L, X, N]
385384
hist_exog_input = hist_exog_input.reshape(B, L, -1) # [B, L, X*N]
386385
exog_list.append(hist_exog_input)
387-
388-
# Add future exogenous
389-
if self.futr_exog_size > 0:
390-
# Take only the input sequence part of future exogenous
391-
futr_exog_input = futr_exog[:, :, :L, :] # [B, F, L, N]
392-
futr_exog_input = futr_exog_input.permute(0, 2, 1, 3) # [B, L, F, N]
393-
futr_exog_input = futr_exog_input.reshape(B, L, -1) # [B, L, F*N]
394-
exog_list.append(futr_exog_input)
395-
386+
396387
# Combine all exogenous features
397388
if len(exog_list) > 0:
398-
x_mark_enc = torch.cat(exog_list, dim=-1) # [B, L, (X+F)*N]
389+
x_mark_enc = torch.cat(exog_list, dim=-1) # [B, L, X*N]
399390
else:
400391
x_mark_enc = None
401-
392+
402393
# Add static exogenous to the cross-attention context
403394
if self.stat_exog_size > 0 and x_mark_enc is not None:
404395
# stat_exog: [N, S] -> [B, L, N*S]
405396
stat_exog_expanded = stat_exog.reshape(-1).unsqueeze(0).unsqueeze(0) # [1, 1, N*S]
406397
stat_exog_expanded = stat_exog_expanded.repeat(B, L, 1) # [B, L, N*S]
407-
x_mark_enc = torch.cat([x_mark_enc, stat_exog_expanded], dim=-1) # [B, L, (X+F)*N + N*S]
398+
x_mark_enc = torch.cat([x_mark_enc, stat_exog_expanded], dim=-1) # [B, L, X*N + N*S]
408399
elif self.stat_exog_size > 0:
409400
# Only static exogenous available
410401
stat_exog_expanded = stat_exog.reshape(-1).unsqueeze(0).unsqueeze(0)

0 commit comments

Comments
 (0)