File tree Expand file tree Collapse file tree 1 file changed +9
-0
lines changed Expand file tree Collapse file tree 1 file changed +9
-0
lines changed Original file line number Diff line number Diff line change @@ -15,6 +15,7 @@ def __init__(
1515 1024 ,
1616 ],
1717 use_adam_optimizer : bool = False ,
18+ pretrained_checkpoint : Optional [str ] = None ,
1819 ** kwargs ,
1920 ):
2021 super ().__init__ (** kwargs )
@@ -32,6 +33,14 @@ def __init__(
3233 layers .append (torch .nn .Linear (current_layer_input_size , self .out_dim ))
3334 self .model = nn .Sequential (* layers )
3435
36+ if pretrained_checkpoint is not None :
37+ self .model .load_state_dict (
38+ torch .load (
39+ pretrained_checkpoint , map_location = self .device , weights_only = False
40+ )
41+ )
42+ print (f"Loaded pretrained checkpoint from { pretrained_checkpoint } " )
43+
3544 def _get_prediction_and_labels (self , data , labels , model_output ):
3645 d = model_output ["logits" ]
3746 loss_kwargs = data .get ("loss_kwargs" , dict ())
You can’t perform that action at this time.
0 commit comments