Skip to content

Commit 1c0e09a

Browse files
committed
fix weights error
1 parent 09e3b60 commit 1c0e09a

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

chebai/models/ffn.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,11 @@ def __init__(
3434
self.model = nn.Sequential(*layers)
3535

3636
if pretrained_checkpoint is not None:
37-
self.model.load_state_dict(
38-
torch.load(
39-
pretrained_checkpoint, map_location=self.device, weights_only=False
40-
)
37+
ckpt_file = torch.load(
38+
pretrained_checkpoint, map_location=self.device, weights_only=False
4139
)
42-
print(f"Loaded pretrained checkpoint from {pretrained_checkpoint}")
40+
self.model.load_state_dict(ckpt_file["state_dict"])
41+
print(f"Loaded pretrained weights from {pretrained_checkpoint}")
4342

4443
def _get_prediction_and_labels(self, data, labels, model_output):
4544
d = model_output["logits"]

0 commit comments

Comments
 (0)