Description
What happened + What you expected to happen
I trained several models with the neuralforecast library and saved each one with `nf.save`. Loading them back with `nf.load` worked for all of them except TimeXer: it saved without errors but could not be loaded. I got this error:
TypeError: neuralforecast.common._base_model.BaseModel.__init__() got multiple values for keyword argument 'exclude_insample_y'
I tried other models on the AirPassengers example and they save and load correctly, but TimeXer fails. So something seems to be wrong in how TimeXer is saved or loaded. I don't know exactly what, but I would like to learn how to save and load it correctly.
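The error message points to a general Python mechanism: a keyword argument that reaches a function twice, once explicitly and once through `**kwargs`. The sketch below reproduces that class of error with hypothetical stand-in classes (`Base`, `BrokenModel`, `FixedModel` and `saved_hparams` are illustrative names, not neuralforecast code). One plausible but unconfirmed explanation is that TimeXer hard-codes `exclude_insample_y` when calling `BaseModel.__init__`, while the saved hyperparameters also contain that key, so loading passes it a second time.

```python
# Hypothetical classes illustrating the "multiple values for keyword
# argument" TypeError. This is NOT neuralforecast's actual code.


class Base:
    """Stand-in for a base model that accepts exclude_insample_y."""

    def __init__(self, h, exclude_insample_y=False, **kwargs):
        self.h = h
        self.exclude_insample_y = exclude_insample_y
        self.extra = kwargs


class BrokenModel(Base):
    """Hard-codes exclude_insample_y and also forwards **kwargs."""

    def __init__(self, h, **kwargs):
        # If kwargs already contains 'exclude_insample_y' (e.g. restored
        # from saved hyperparameters), the keyword is passed twice and
        # Python raises TypeError before Base.__init__ even runs.
        super().__init__(h=h, exclude_insample_y=False, **kwargs)


class FixedModel(Base):
    """Accepts exclude_insample_y explicitly, so it cannot be duplicated."""

    def __init__(self, h, exclude_insample_y=False, **kwargs):
        super().__init__(h=h, exclude_insample_y=exclude_insample_y, **kwargs)


# Simulated "saved" hyperparameters that include the hard-coded keyword.
saved_hparams = {"h": 12, "exclude_insample_y": False}

try:
    BrokenModel(**saved_hparams)
except TypeError as err:
    print(err)  # exact wording varies with the Python version

model = FixedModel(**saved_hparams)  # re-creating from saved hparams works
print(model.exclude_insample_y)  # False
```

If this is indeed the cause, the fix would likely belong in the model's constructor (as in `FixedModel`) rather than in how `save` and `load` are called.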
Versions / Dependencies
I am using a Google Cloud notebook on Vertex AI.
Reproduction script
```python
import os

from neuralforecast import NeuralForecast
from neuralforecast.models import TimeXer
from neuralforecast.losses.pytorch import MSE, MAE
from neuralforecast.utils import AirPassengersPanel, AirPassengersStatic, augment_calendar_df

# Add calendar features (including "month") to the panel
AirPassengersPanel, calendar_cols = augment_calendar_df(df=AirPassengersPanel, freq='M')

Y_train_df = AirPassengersPanel[AirPassengersPanel.ds < AirPassengersPanel['ds'].values[-12]]  # 132 train
Y_test_df = AirPassengersPanel[AirPassengersPanel.ds >= AirPassengersPanel['ds'].values[-12]].reset_index(drop=True)  # 12 test

model = TimeXer(h=12,
                input_size=24,
                n_series=2,
                futr_exog_list=["trend", "month"],
                patch_len=12,
                hidden_size=128,
                n_heads=16,
                e_layers=2,
                d_ff=256,
                factor=1,
                dropout=0.1,
                use_norm=True,
                loss=MSE(),
                valid_loss=MAE(),
                early_stop_patience_steps=3,
                batch_size=32)

fcst = NeuralForecast(models=[model], freq='ME')
fcst.fit(df=Y_train_df, static_df=AirPassengersStatic, val_size=12)

# Define save path (exist_ok avoids FileExistsError when re-running)
save_path = "airpassengers_2/"
os.makedirs(save_path, exist_ok=True)

# Save the trained model and configuration
fcst.save(path=save_path, save_dataset=True, overwrite=True, model_index=None)  # this works
print(f"Model saved to {save_path}")

nf = NeuralForecast.load(save_path)  # this fails with the TypeError above
```
Issue Severity
High: It blocks me from completing my task.