File tree Expand file tree Collapse file tree 1 file changed +8
-1
lines changed Expand file tree Collapse file tree 1 file changed +8
-1
lines changed Original file line number Diff line number Diff line change 55
66from chebai .models import ChebaiBaseNet
77
8+ from .electra import filter_dict
9+
810
911class FFN (ChebaiBaseNet ):
1012 # Reference: https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/models.py#L121-L139
@@ -16,6 +18,7 @@ def __init__(
1618 ],
1719 use_adam_optimizer : bool = False ,
1820 pretrained_checkpoint : Optional [str ] = None ,
21+ load_prefix : Optional [str ] = "model." ,
1922 ** kwargs ,
2023 ):
2124 super ().__init__ (** kwargs )
@@ -37,7 +40,11 @@ def __init__(
3740 ckpt_file = torch .load (
3841 pretrained_checkpoint , map_location = self .device , weights_only = False
3942 )
40- self .model .load_state_dict (ckpt_file ["state_dict" ])
43+ if load_prefix is not None :
44+ state_dict = filter_dict (ckpt_file ["state_dict" ], load_prefix )
45+ else :
46+ state_dict = ckpt_file ["state_dict" ]
47+ self .model .load_state_dict (state_dict )
4148 print (f"Loaded pretrained weights from { pretrained_checkpoint } " )
4249
4350 def _get_prediction_and_labels (self , data , labels , model_output ):
You can’t perform that action at this time.
0 commit comments