
TimeXer cannot load after saving #1303

@alemkakti


What happened + What you expected to happen

I trained several models with the neuralforecast library and saved each one with `nf.save`. Loading them back with `NeuralForecast.load` works for every model except TimeXer: it saves without any error, but loading fails with:
`TypeError: neuralforecast.common._base_model.BaseModel.__init__() got multiple values for keyword argument 'exclude_insample_y'`

I tried other models on the AirPassengers example and they load correctly, but TimeXer fails, so something seems to be wrong in the way TimeXer is saved or loaded. I don't know exactly what the cause is, but I would like to learn how to save and load this model correctly.
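The TypeError itself is a generic Python error. As a minimal sketch (the classes `Base`, `Child`, `FixedChild` and the helper `reload` below are hypothetical stand-ins, not neuralforecast's actual code), it appears when a subclass passes a keyword explicitly to its parent's `__init__` while also forwarding `**kwargs`, and the hyperparameters replayed on load already contain that same keyword. I am assuming, not confirming, that TimeXer's save/load path follows this pattern.

```python
class Base:
    """Hypothetical stand-in for a library base model."""

    def __init__(self, h, exclude_insample_y=False, **kwargs):
        self.h = h
        self.exclude_insample_y = exclude_insample_y
        # Record every constructor argument, as a save routine might,
        # so the object can later be rebuilt with cls(**hparams).
        self.hparams = {"h": h, "exclude_insample_y": exclude_insample_y, **kwargs}


class Child(Base):
    """Hypothetical model that hard-codes exclude_insample_y."""

    def __init__(self, h, **kwargs):
        # exclude_insample_y is passed explicitly here AND may also
        # arrive inside **kwargs when the object is rebuilt.
        super().__init__(h=h, exclude_insample_y=False, **kwargs)


class FixedChild(Base):
    """Same model, but it discards any stored value before forwarding."""

    def __init__(self, h, **kwargs):
        # Drop the duplicate so the keyword reaches Base only once.
        kwargs.pop("exclude_insample_y", None)
        super().__init__(h=h, exclude_insample_y=False, **kwargs)


def reload(obj):
    """Hypothetical 'load': rebuild an object from its saved hyperparameters."""
    return type(obj)(**obj.hparams)


def try_reload(obj):
    """Return the TypeError message if rebuilding fails, else None."""
    try:
        reload(obj)
    except TypeError as err:
        return str(err)
    return None


# Constructing Child directly works; rebuilding it from its hparams does not.
message = try_reload(Child(h=12))
print(message)

# FixedChild round-trips without error.
restored = reload(FixedChild(h=12))
print(type(restored).__name__, restored.exclude_insample_y)
```

Popping the duplicate key before forwarding, as `FixedChild` does, is one common way to avoid this kind of clash; it only illustrates the pattern and is not a patch for TimeXer.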

Versions / Dependencies

I am using a Google Cloud notebook on Vertex AI.

Reproduction script

```python
import os

from neuralforecast import NeuralForecast
from neuralforecast.models import TimeXer
from neuralforecast.losses.pytorch import MSE, MAE
from neuralforecast.utils import AirPassengersPanel, AirPassengersStatic, augment_calendar_df

# Add calendar features (e.g. "month") to the panel
AirPassengersPanel, calendar_cols = augment_calendar_df(df=AirPassengersPanel, freq='M')

Y_train_df = AirPassengersPanel[AirPassengersPanel.ds<AirPassengersPanel['ds'].values[-12]] # 132 train
Y_test_df = AirPassengersPanel[AirPassengersPanel.ds>=AirPassengersPanel['ds'].values[-12]].reset_index(drop=True) # 12 test

model = TimeXer(h=12,
                input_size=24,
                n_series=2,
                futr_exog_list=["trend", "month"],
                patch_len=12,
                hidden_size=128,
                n_heads=16,
                e_layers=2,
                d_ff=256,
                factor=1,
                dropout=0.1,
                use_norm=True,
                loss=MSE(),
                valid_loss=MAE(),
                early_stop_patience_steps=3,
                batch_size=32)

fcst = NeuralForecast(models=[model], freq='ME')
fcst.fit(df=Y_train_df, static_df=AirPassengersStatic, val_size=12)

# Define save path (exist_ok=True so the script can be re-run)
save_path = "airpassengers_2/"
os.makedirs(save_path, exist_ok=True)

# Save the trained model and configuration
fcst.save(path=save_path, save_dataset=True, overwrite=True, model_index=None)  # this works
print(f"Model saved to {save_path}")

nf = NeuralForecast.load(save_path)  # this fails with the TypeError above
```

Issue Severity

High: It blocks me from completing my task.
