Commit 09e3b60

pretrained weights for mlp
1 parent 1c629ff commit 09e3b60

File tree

1 file changed: +9 -0 lines changed

chebai/models/ffn.py

Lines changed: 9 additions & 0 deletions
@@ -15,6 +15,7 @@ def __init__(
             1024,
         ],
         use_adam_optimizer: bool = False,
+        pretrained_checkpoint: Optional[str] = None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -32,6 +33,14 @@ def __init__(
         layers.append(torch.nn.Linear(current_layer_input_size, self.out_dim))
         self.model = nn.Sequential(*layers)
 
+        if pretrained_checkpoint is not None:
+            self.model.load_state_dict(
+                torch.load(
+                    pretrained_checkpoint, map_location=self.device, weights_only=False
+                )
+            )
+            print(f"Loaded pretrained checkpoint from {pretrained_checkpoint}")
+
     def _get_prediction_and_labels(self, data, labels, model_output):
         d = model_output["logits"]
         loss_kwargs = data.get("loss_kwargs", dict())
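For context, the pattern this commit adds can be illustrated with a minimal standalone sketch. The build_mlp helper below, its arguments, and any paths are hypothetical and not part of chebai; only the torch.load / load_state_dict pattern mirrors the diff. Note that weights_only=False tells torch.load to allow full pickle deserialization, and since its result is passed directly to load_state_dict, the checkpoint is expected to hold a plain state_dict rather than, say, a Lightning checkpoint that nests the weights under a "state_dict" key.

import torch
import torch.nn as nn
from typing import List, Optional


def build_mlp(
    in_dim: int,
    hidden_dims: List[int],
    out_dim: int,
    pretrained_checkpoint: Optional[str] = None,
    device: str = "cpu",
) -> nn.Sequential:
    # Hypothetical helper illustrating the commit's loading pattern;
    # it is not part of chebai/models/ffn.py.
    layers = []
    current = in_dim
    for hidden in hidden_dims:
        layers.append(nn.Linear(current, hidden))
        layers.append(nn.ReLU())
        current = hidden
    layers.append(nn.Linear(current, out_dim))
    model = nn.Sequential(*layers)

    if pretrained_checkpoint is not None:
        # map_location keeps tensors on the intended device;
        # weights_only=False matches the diff and permits pickled objects.
        state = torch.load(
            pretrained_checkpoint, map_location=device, weights_only=False
        )
        model.load_state_dict(state)
        print(f"Loaded pretrained checkpoint from {pretrained_checkpoint}")
    return model

A caller would save weights during pretraining with torch.save(model.state_dict(), path) and pass that same path as pretrained_checkpoint when re-instantiating the model; the path name is a placeholder.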
