Skip to content

Commit 0c693dd

Browse files
CeliaBenquetstes
authored andcommitted
Make save/load cleaner
1 parent d08e400 commit 0c693dd

File tree

1 file changed

+14
-6
lines changed

1 file changed

+14
-6
lines changed

cebra/solver/base.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,7 @@ def state_dict(self) -> dict:
296296
the model was trained with.
297297
"""
298298

299-
return {
299+
state_dict = {
300300
"model": self.model.state_dict(),
301301
"optimizer": self.optimizer.state_dict(),
302302
"loss": torch.tensor(self.history),
@@ -306,6 +306,13 @@ def state_dict(self) -> dict:
306306
"log": self.log,
307307
}
308308

309+
if hasattr(self, "n_features"):
310+
state_dict["n_features"] = self.n_features
311+
if hasattr(self, "num_sessions"):
312+
state_dict["num_sessions"] = self.num_sessions
313+
314+
return state_dict
315+
309316
def load_state_dict(self, state_dict: dict, strict: bool = True):
310317
"""Update the solver state with the given state_dict.
311318
@@ -343,6 +350,12 @@ def _get(key):
343350
if _contains("log"):
344351
self.log = _get("log")
345352

353+
# Not defined if the model was saved before being fitted.
354+
if "n_features" in state_dict:
355+
self.n_features = _get("n_features")
356+
if "num_sessions" in state_dict:
357+
self.num_sessions = _get("num_sessions")
358+
346359
@property
347360
def num_parameters(self) -> int:
348361
"""Total number of parameters in the encoder and criterion."""
@@ -633,11 +646,6 @@ def load(self, logdir, filename="checkpoint.pth"):
633646
checkpoint = torch.load(savepath, map_location=self.device)
634647
self.load_state_dict(checkpoint, strict=True)
635648

636-
n_features = self.n_features
637-
self.n_features = ([
638-
session_n_features for session_n_features in n_features
639-
] if isinstance(n_features, list) else n_features)
640-
641649
def save(self, logdir, filename="checkpoint.pth"):
642650
"""Save the model and optimizer params.
643651

0 commit comments

Comments
 (0)