We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 09e3b60 commit 1c0e09aCopy full SHA for 1c0e09a
chebai/models/ffn.py
@@ -34,12 +34,11 @@ def __init__(
34
self.model = nn.Sequential(*layers)
35
36
if pretrained_checkpoint is not None:
37
- self.model.load_state_dict(
38
- torch.load(
39
- pretrained_checkpoint, map_location=self.device, weights_only=False
40
- )
+ ckpt_file = torch.load(
+ pretrained_checkpoint, map_location=self.device, weights_only=False
41
)
42
- print(f"Loaded pretrained checkpoint from {pretrained_checkpoint}")
+ self.model.load_state_dict(ckpt_file["state_dict"])
+ print(f"Loaded pretrained weights from {pretrained_checkpoint}")
43
44
def _get_prediction_and_labels(self, data, labels, model_output):
45
d = model_output["logits"]
0 commit comments