Skip to content

Commit c6179ad

Browse files
committed
Fix save/load
1 parent 3e91459 commit c6179ad

File tree

2 files changed

+18
-6
lines changed

2 files changed

+18
-6
lines changed

cebra/solver/base.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -668,13 +668,12 @@ def load(self, logdir, filename="checkpoint.pth"):
668668
checkpoint = torch.load(savepath, map_location=self.device)
669669
self.load_state_dict(checkpoint, strict=True)
670670

671-
if hasattr(self.model, "n_features"):
672-
n_features = self.model.n_features
673-
self.n_features = ([
674-
session_n_features for session_n_features in n_features
675-
] if isinstance(n_features, list) else n_features)
671+
n_features = self.n_features
672+
self.n_features = ([
673+
session_n_features for session_n_features in n_features
674+
] if isinstance(n_features, list) else n_features)
676675

677-
def save(self, logdir, filename="checkpoint_last.pth"):
676+
def save(self, logdir, filename="checkpoint.pth"):
678677
"""Save the model and optimizer params.
679678
680679
Args:

tests/test_solver.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
# limitations under the License.
2121
#
2222
import copy
23+
import itertools
2324
import tempfile
2425

2526
import numpy as np
@@ -189,6 +190,12 @@ def test_single_session(data_name, loader_initfunc, model_architecture,
189190
for param in solver.parameters():
190191
assert isinstance(param, torch.Tensor)
191192

193+
fitted_solver = copy.deepcopy(solver)
194+
with tempfile.TemporaryDirectory() as temp_dir:
195+
solver.save(temp_dir)
196+
solver.load(temp_dir)
197+
_assert_equal(fitted_solver, solver)
198+
192199

193200
embedding = solver.transform(X)
194201
assert isinstance(embedding, torch.Tensor)
@@ -380,6 +387,12 @@ def test_multi_session(data_name, loader_initfunc, model_architecture,
380387
for param in solver.parameters():
381388
assert isinstance(param, torch.Tensor)
382389

390+
fitted_solver = copy.deepcopy(solver)
391+
with tempfile.TemporaryDirectory() as temp_dir:
392+
solver.save(temp_dir)
393+
solver.load(temp_dir)
394+
_assert_equal(fitted_solver, solver)
395+
383396

384397
@pytest.mark.parametrize(
385398
"inputs, add_padding, offset, start_batch_idx, end_batch_idx, expected_output",

0 commit comments

Comments
 (0)