
Commit 0add04b

filter keys prefix from state dicts

1 parent 1c0e09a

1 file changed: 8 additions, 1 deletion

chebai/models/ffn.py

Lines changed: 8 additions & 1 deletion
@@ -5,6 +5,8 @@
 
 from chebai.models import ChebaiBaseNet
 
+from .electra import filter_dict
+
 
 class FFN(ChebaiBaseNet):
     # Reference: https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/models.py#L121-L139
@@ -16,6 +18,7 @@ def __init__(
         ],
         use_adam_optimizer: bool = False,
         pretrained_checkpoint: Optional[str] = None,
+        load_prefix: Optional[str] = "model.",
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -37,7 +40,11 @@ def __init__(
             ckpt_file = torch.load(
                 pretrained_checkpoint, map_location=self.device, weights_only=False
             )
-            self.model.load_state_dict(ckpt_file["state_dict"])
+            if load_prefix is not None:
+                state_dict = filter_dict(ckpt_file["state_dict"], load_prefix)
+            else:
+                state_dict = ckpt_file["state_dict"]
+            self.model.load_state_dict(state_dict)
             print(f"Loaded pretrained weights from {pretrained_checkpoint}")
 
     def _get_prediction_and_labels(self, data, labels, model_output):
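
Note: the new loading path relies on filter_dict, imported from chebai/models/electra.py but not shown in this diff. Judging from how it is called here (filter_dict(ckpt_file["state_dict"], load_prefix) with a default prefix of "model."), it presumably keeps only the checkpoint entries whose keys start with the given prefix and strips that prefix so the remaining keys match the bare sub-module. A minimal sketch of that assumed behavior, not the actual implementation in electra.py:

from typing import Any, Dict

def filter_dict(d: Dict[str, Any], filter_key: str) -> Dict[str, Any]:
    # Assumed behavior: keep only entries whose key starts with filter_key
    # and drop that prefix, e.g. "model.layers.0.weight" -> "layers.0.weight".
    return {
        key[len(filter_key):]: value
        for key, value in d.items()
        if key.startswith(filter_key)
    }

Under that assumption, a Lightning checkpoint whose state_dict keys are prefixed with "model." loads cleanly into self.model, while passing load_prefix=None keeps the previous behavior of loading the state dict unchanged.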
