Skip to content

Commit c9c53de

Browse files
marcopeix and elephaint
authored
[FEAT] Add XLinear (#1445)
Co-authored-by: elephaint <osprangers@gmail.com>
1 parent cf89402 commit c9c53de

File tree

8 files changed

+526
-3
lines changed

8 files changed

+526
-3
lines changed

docs/mintlify/docs.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,8 @@
120120
"models.timexer.html",
121121
"models.tsmixer.html",
122122
"models.tsmixerx.html",
123-
"models.vanillatransformer.html"
123+
"models.vanillatransformer.html",
124+
"models.xlinear.html"
124125
]
125126
},
126127
"models.html",
686 KB
Loading

docs/models.xlinear.html.md

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
---
2+
description: >-
3+
XLinear: An MLP-based model for multivariate forecasting with exogenous features.
4+
output-file: models.xlinear.html
5+
title: XLinear
6+
---
7+
8+
XLinear is an MLP-based model for multivariate time series forecasting
9+
that uses gating mechanisms for temporal and cross-channel interactions.
10+
The architecture consists of temporal gating with a global token to capture
11+
global temporal patterns, followed by cross-channel gating to model
12+
dependencies between different time series.
13+
14+
**References**
15+
16+
- [Xinyang, C., et al. "XLinear: A Lightweight and Accurate MLP-Based Model for Long-Term Time Series Forecasting with Exogenous Inputs"](https://arxiv.org/abs/2601.09237)
17+
18+
![Figure 1. Architecture of XLinear](imgs_models/xlinear.png)
19+
*Figure 1. Architecture of XLinear*
20+
21+
## XLinear
22+
23+
::: neuralforecast.models.xlinear.XLinear
24+
options:
25+
members:
26+
- fit
27+
- predict
28+
heading_level: 3
29+
30+
### Usage Example
31+
32+
33+
```python
34+
import pandas as pd
35+
import matplotlib.pyplot as plt
36+
37+
from neuralforecast import NeuralForecast
38+
from neuralforecast.models import XLinear
39+
from neuralforecast.losses.pytorch import MAE
40+
from neuralforecast.utils import AirPassengersPanel, AirPassengersStatic
41+
42+
Y_train_df = AirPassengersPanel[AirPassengersPanel.ds<AirPassengersPanel['ds'].values[-12]] # 132 train
43+
Y_test_df = AirPassengersPanel[AirPassengersPanel.ds>=AirPassengersPanel['ds'].values[-12]].reset_index(drop=True) # 12 test
44+
45+
model = XLinear(h=12,
46+
input_size=24,
47+
n_series=2,
48+
stat_exog_list=['airline1'],
49+
hist_exog_list=["y_[lag12]"],
50+
futr_exog_list=['trend'],
51+
loss = MAE(),
52+
scaler_type='robust',
53+
learning_rate=1e-3,
54+
max_steps=200,
55+
val_check_steps=10,
56+
early_stop_patience_steps=2)
57+
58+
fcst = NeuralForecast(
59+
models=[model],
60+
freq='ME'
61+
)
62+
fcst.fit(df=Y_train_df, static_df=AirPassengersStatic, val_size=12)
63+
forecasts = fcst.predict(futr_df=Y_test_df)
64+
65+
# Plot predictions
66+
Y_hat_df = forecasts.reset_index(drop=False).drop(columns=['unique_id','ds'])
67+
plot_df = pd.concat([Y_test_df, Y_hat_df], axis=1)
68+
plot_df = pd.concat([Y_train_df, plot_df])
69+
70+
plot_df = plot_df[plot_df.unique_id=='Airline1'].drop('unique_id', axis=1)
71+
plt.plot(plot_df['ds'], plot_df['y'], c='black', label='True')
72+
plt.plot(plot_df['ds'], plot_df['XLinear'], c='blue', label='median')
73+
plt.grid()
74+
plt.legend()
75+
plt.plot()
76+
```

nbs/docs/capabilities/overview.ipynb

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
"|`TSMixer` | `AutoTSMixer` | MLP | Multivariate | Direct | - | \n",
4949
"|`TSMixerx` | `AutoTSMixerx` | MLP | Multivariate | Direct | F/H/S | \n",
5050
"|`VanillaTransformer` | `AutoVanillaTransformer` | Transformer | Univariate | Direct | F | \n",
51+
"|`XLinear` | `AutoXLinear` | MLP | Multivariate| Direct | F/H/S |\n",
5152
"|`xLSTM` | `AutoxLSTM` | mLSTM | Univariate | Direct | F/H/S | \n",
5253
"\n",
5354
"\n",

neuralforecast/auto.py

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
'AutoNBEATS', 'AutoNBEATSx', 'AutoNHITS', 'AutoDLinear', 'AutoNLinear', 'AutoTiDE', 'AutoDeepNPTS',
66
'AutoKAN', 'AutoTFT', 'AutoVanillaTransformer', 'AutoInformer', 'AutoAutoformer', 'AutoFEDformer',
77
'AutoPatchTST', 'AutoiTransformer', 'AutoTimeXer', 'AutoTimesNet', 'AutoStemGNN', 'AutoHINT', 'AutoTSMixer',
8-
'AutoTSMixerx', 'AutoMLPMultivariate', 'AutoSOFTS', 'AutoTimeMixer', 'AutoRMoK']
8+
'AutoTSMixerx', 'AutoMLPMultivariate', 'AutoSOFTS', 'AutoTimeMixer', 'AutoRMoK', 'AutoXLinear']
99

1010

1111
from os import cpu_count
@@ -50,6 +50,7 @@
5050
from .models.tsmixerx import TSMixerx
5151
from .models.vanillatransformer import VanillaTransformer
5252
from .models.xlstm import xLSTM
53+
from .models.xlinear import XLinear
5354

5455

5556
class AutoRNN(BaseAuto):
@@ -2556,3 +2557,88 @@ def get_default_config(cls, h, backend, n_series):
25562557
config = cls._ray_config_to_optuna(config)
25572558

25582559
return config
2560+
2561+
2562+
class AutoXLinear(BaseAuto):
2563+
2564+
default_config = {
2565+
"input_size_multiplier": [1, 2, 3, 4, 5],
2566+
"h": None,
2567+
"n_series": None,
2568+
"hidden_size": tune.choice([64, 128, 256]),
2569+
"use_norm": tune.choice([True, False]),
2570+
"learning_rate": tune.loguniform(1e-4, 1e-1),
2571+
"scaler_type": tune.choice([None, "robust", "standard"]),
2572+
"max_steps": tune.choice([500, 1000]),
2573+
"batch_size": tune.choice([32, 64, 128, 256]),
2574+
"loss": None,
2575+
"random_seed": tune.randint(1, 20),
2576+
}
2577+
2578+
def __init__(
2579+
self,
2580+
h,
2581+
n_series,
2582+
loss=MAE(),
2583+
valid_loss=None,
2584+
config=None,
2585+
search_alg=BasicVariantGenerator(random_state=1),
2586+
num_samples=10,
2587+
refit_with_val=False,
2588+
cpus=cpu_count(),
2589+
gpus=torch.cuda.device_count(),
2590+
verbose=False,
2591+
alias=None,
2592+
backend="ray",
2593+
callbacks=None,
2594+
):
2595+
2596+
# Define search space, input/output sizes
2597+
if config is None:
2598+
config = self.get_default_config(h=h, backend=backend, n_series=n_series)
2599+
2600+
# Always use n_series from parameters, raise exception with Optuna because we can't enforce it
2601+
if backend == "ray":
2602+
config["n_series"] = n_series
2603+
elif backend == "optuna":
2604+
mock_trial = MockTrial()
2605+
if (
2606+
"n_series" in config(mock_trial)
2607+
and config(mock_trial)["n_series"] != n_series
2608+
) or ("n_series" not in config(mock_trial)):
2609+
raise Exception(f"config needs 'n_series': {n_series}")
2610+
2611+
super(AutoXLinear, self).__init__(
2612+
cls_model=XLinear,
2613+
h=h,
2614+
loss=loss,
2615+
valid_loss=valid_loss,
2616+
config=config,
2617+
search_alg=search_alg,
2618+
num_samples=num_samples,
2619+
refit_with_val=refit_with_val,
2620+
cpus=cpus,
2621+
gpus=gpus,
2622+
verbose=verbose,
2623+
alias=alias,
2624+
backend=backend,
2625+
callbacks=callbacks,
2626+
)
2627+
2628+
@classmethod
2629+
def get_default_config(cls, h, backend, n_series):
2630+
config = cls.default_config.copy()
2631+
config["input_size"] = tune.choice(
2632+
[h * x for x in config["input_size_multiplier"]]
2633+
)
2634+
2635+
# Rolling windows with step_size=1 or step_size=h
2636+
# See `BaseWindows` and `BaseRNN`'s create_windows
2637+
config["step_size"] = tune.choice([1, h])
2638+
del config["input_size_multiplier"]
2639+
if backend == "optuna":
2640+
# Always use n_series from parameters
2641+
config["n_series"] = n_series
2642+
config = cls._ray_config_to_optuna(config)
2643+
2644+
return config

neuralforecast/models/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
'TFT', 'VanillaTransformer', 'Informer', 'Autoformer', 'PatchTST', 'FEDformer',
44
'StemGNN', 'HINT', 'TimesNet', 'TimeLLM', 'TSMixer', 'TSMixerx', 'MLPMultivariate',
55
'iTransformer', 'BiTCN', 'TiDE', 'DeepNPTS', 'SOFTS', 'TimeMixer', 'KAN', 'RMoK',
6-
'TimeXer', 'xLSTM'
6+
'TimeXer', 'xLSTM', 'XLinear'
77
]
88

99
from .rnn import RNN
@@ -41,3 +41,4 @@
4141
from .rmok import RMoK
4242
from .timexer import TimeXer
4343
from .xlstm import xLSTM
44+
from .xlinear import XLinear

0 commit comments

Comments (0)