@@ -296,7 +296,7 @@ def state_dict(self) -> dict:
296296 the model was trained with.
297297 """
298298
299- return {
299+ state_dict = {
300300 "model" : self .model .state_dict (),
301301 "optimizer" : self .optimizer .state_dict (),
302302 "loss" : torch .tensor (self .history ),
@@ -306,6 +306,13 @@ def state_dict(self) -> dict:
306306 "log" : self .log ,
307307 }
308308
309+ if hasattr (self , "n_features" ):
310+ state_dict ["n_features" ] = self .n_features
311+ if hasattr (self , "num_sessions" ):
312+ state_dict ["num_sessions" ] = self .num_sessions
313+
314+ return state_dict
315+
309316 def load_state_dict (self , state_dict : dict , strict : bool = True ):
310317 """Update the solver state with the given state_dict.
311318
@@ -343,6 +350,12 @@ def _get(key):
343350 if _contains ("log" ):
344351 self .log = _get ("log" )
345352
353+ # Not defined if the model was saved before being fitted.
354+ if "n_features" in state_dict :
355+ self .n_features = _get ("n_features" )
356+ if "num_sessions" in state_dict :
357+ self .num_sessions = _get ("num_sessions" )
358+
346359 @property
347360 def num_parameters (self ) -> int :
348361 """Total number of parameters in the encoder and criterion."""
@@ -633,11 +646,6 @@ def load(self, logdir, filename="checkpoint.pth"):
633646 checkpoint = torch .load (savepath , map_location = self .device )
634647 self .load_state_dict (checkpoint , strict = True )
635648
636- n_features = self .n_features
637- self .n_features = ([
638- session_n_features for session_n_features in n_features
639- ] if isinstance (n_features , list ) else n_features )
640-
641649 def save (self , logdir , filename = "checkpoint.pth" ):
642650 """Save the model and optimizer params.
643651
0 commit comments