Description
I'm trying to load a saved checkpoint and use it for predictions on the test dataset, but I'm running into an issue. It would be great if someone could help me fix it, and also explain how we can use the YAML parameter file in order to load the checkpoint.
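```python
import torch
from pytorch_lightning.loggers import CSVLogger

# Assumption: run from the repo's transformer/ directory, whose estimator.py defines TransformerEstimator
from estimator import TransformerEstimator

# freq and prediction_length are defined earlier in the notebook
model = TransformerEstimator(
    freq=freq,
    prediction_length=prediction_length,
    context_length=prediction_length * 2,
    num_feat_static_cat=0,
    cardinality=[1],
    embedding_dimension=[3],
    nhead=8,
    num_encoder_layers=2,
    num_decoder_layers=1,
    dim_feedforward=2048,
    activation="gelu",
    scaling=True,
    batch_size=32,
    num_batches_per_epoch=20,
    # distr_output=ImplicitQuantileNetworkOutput("positive"),
    # loss=QuantileLoss(),
    trainer_kwargs=dict(max_epochs=30, logger=CSVLogger(".", "lightning_logs/")),
)

# Load the saved Lightning checkpoint and try to restore its weights
checkpoint = torch.load('/Workspace/Repo/pytorch-transformer-ts/transformer/lightning_logs/version_4/checkpoints/epoch=7-step=160.ckpt')
model.load_state_dict(checkpoint['state_dict'])
```

This fails with:

```
AttributeError: 'TransformerEstimator' object has no attribute 'load_state_dict'
```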
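A possible fix, sketched under assumptions: `TransformerEstimator` appears to follow GluonTS's `PyTorchLightningEstimator` pattern, so it is a factory that builds a Lightning module and a predictor rather than a `torch.nn.Module` itself, which would explain why it has no `load_state_dict`. The checkpoint's `state_dict` belongs to the Lightning module the estimator created during training. The helper below (the name `predictor_from_checkpoint` is hypothetical) re-creates that module from an estimator configured with the same arguments as the training run, loads the weights into it, and wraps it in a predictor. It assumes the estimator exposes GluonTS's `create_transformation()`, `create_lightning_module()` and `create_predictor(transformation, module)` methods; check the repo's `estimator.py` for the exact names. On this path the hyperparameters come from re-creating the estimator, so `hparams.yaml` is not needed.

```python
import torch


def predictor_from_checkpoint(estimator, ckpt_path, map_location="cpu"):
    """Hypothetical helper: rebuild the estimator's Lightning module and load checkpoint weights.

    Assumes `estimator` follows GluonTS's PyTorchLightningEstimator interface and was
    constructed with exactly the same arguments as the run that wrote the checkpoint.
    """
    # Same input transformation the estimator applies during training/prediction
    transformation = estimator.create_transformation()
    # Fresh, untrained module with the same architecture (and thus the same state_dict keys)
    module = estimator.create_lightning_module()
    # Lightning checkpoints also hold non-tensor entries (hyperparameters, loop state),
    # so weights_only=False is used here; it requires torch>=1.13. Only load trusted files.
    checkpoint = torch.load(ckpt_path, map_location=map_location, weights_only=False)
    # The weights live under the "state_dict" key of the checkpoint dict
    module.load_state_dict(checkpoint["state_dict"])
    module.eval()  # switch off dropout etc. for inference
    return estimator.create_predictor(transformation, module)


# Usage with the estimator configured above (assumes `test_dataset` is a GluonTS dataset):
# predictor = predictor_from_checkpoint(
#     model,
#     "/Workspace/Repo/pytorch-transformer-ts/transformer/lightning_logs/version_4/checkpoints/epoch=7-step=160.ckpt",
# )
# forecasts = list(predictor.predict(test_dataset))
```

If any constructor argument differs from the training run (for example `context_length` or `nhead`), `load_state_dict` raises a `RuntimeError` listing the missing, unexpected, or mismatched keys, which is a quick way to spot a configuration mismatch.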