Skip to content

Commit 8e5f933

Browse files
CeliaBenquetstes
authored andcommitted
Fix save/load
1 parent 9db3e37 commit 8e5f933

File tree

3 files changed

+72
-6
lines changed

3 files changed

+72
-6
lines changed

cebra/integrations/sklearn/cebra.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1417,6 +1417,11 @@ def load(cls,
14171417
else:
14181418
cebra_ = _check_type_checkpoint(checkpoint)
14191419

1420+
n_features = cebra_.n_features_
1421+
cebra_.solver_.n_features = ([
1422+
session_n_features for session_n_features in n_features
1423+
] if isinstance(n_features, list) else n_features)
1424+
14201425
return cebra_
14211426

14221427
def to(self, device: Union[str, torch.device]):

cebra/solver/base.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -633,13 +633,12 @@ def load(self, logdir, filename="checkpoint.pth"):
633633
checkpoint = torch.load(savepath, map_location=self.device)
634634
self.load_state_dict(checkpoint, strict=True)
635635

636-
if hasattr(self.model, "n_features"):
637-
n_features = self.model.n_features
638-
self.n_features = ([
639-
session_n_features for session_n_features in n_features
640-
] if isinstance(n_features, list) else n_features)
636+
n_features = self.n_features
637+
self.n_features = ([
638+
session_n_features for session_n_features in n_features
639+
] if isinstance(n_features, list) else n_features)
641640

642-
def save(self, logdir, filename="checkpoint_last.pth"):
641+
def save(self, logdir, filename="checkpoint.pth"):
643642
"""Save the model and optimizer params.
644643
645644
Args:

tests/test_solver.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919
# See the License for the specific language governing permissions and
2020
# limitations under the License.
2121
#
22+
import copy
2223
import itertools
24+
import tempfile
2325

2426
import numpy as np
2527
import pytest
@@ -91,6 +93,48 @@ def _make_model(dataset, model_architecture="offset10-model"):
9193
# )
9294

9395

96+
def _assert_same_state_dict(first, second):
97+
assert first.keys() == second.keys()
98+
for key in first:
99+
if isinstance(first[key], torch.Tensor):
100+
assert torch.allclose(first[key], second[key]), key
101+
elif isinstance(first[key], dict):
102+
_assert_same_state_dict(first[key], second[key]), key
103+
else:
104+
assert first[key] == second[key]
105+
106+
107+
def check_if_fit(model):
108+
"""Check if a model was already fit.
109+
110+
Args:
111+
model: The model to check.
112+
113+
Returns:
114+
True if the model was already fit.
115+
"""
116+
return hasattr(model, "n_features_")
117+
118+
119+
def _assert_equal(original_solver, loaded_solver):
120+
for k in original_solver.model.state_dict():
121+
assert original_solver.model.state_dict()[k].all(
122+
) == loaded_solver.model.state_dict()[k].all()
123+
assert check_if_fit(loaded_solver) == check_if_fit(original_solver)
124+
125+
if check_if_fit(loaded_solver):
126+
_assert_same_state_dict(original_solver.state_dict_,
127+
loaded_solver.state_dict_)
128+
X = np.random.normal(0, 1, (100, 1))
129+
130+
if loaded_solver.num_sessions is not None:
131+
assert np.allclose(loaded_solver.transform(X, session_id=0),
132+
original_solver.transform(X, session_id=0))
133+
else:
134+
assert np.allclose(loaded_solver.transform(X),
135+
original_solver.transform(X))
136+
137+
94138
@pytest.mark.parametrize(
95139
"data_name, loader_initfunc, model_architecture, solver_initfunc",
96140
single_session_tests)
@@ -144,6 +188,12 @@ def test_single_session(data_name, loader_initfunc, model_architecture,
144188
for param in solver.parameters():
145189
assert isinstance(param, torch.Tensor)
146190

191+
fitted_solver = copy.deepcopy(solver)
192+
with tempfile.TemporaryDirectory() as temp_dir:
193+
solver.save(temp_dir)
194+
solver.load(temp_dir)
195+
_assert_equal(fitted_solver, solver)
196+
147197

148198
@pytest.mark.parametrize(
149199
"data_name, loader_initfunc, model_architecture, solver_initfunc",
@@ -225,6 +275,12 @@ def test_single_session_hybrid(data_name, loader_initfunc, model_architecture,
225275
for param in solver.parameters():
226276
assert isinstance(param, torch.Tensor)
227277

278+
fitted_solver = copy.deepcopy(solver)
279+
with tempfile.TemporaryDirectory() as temp_dir:
280+
solver.save(temp_dir)
281+
solver.load(temp_dir)
282+
_assert_equal(fitted_solver, solver)
283+
228284

229285
@pytest.mark.parametrize(
230286
"data_name, loader_initfunc, model_architecture, solver_initfunc",
@@ -302,6 +358,12 @@ def test_multi_session(data_name, loader_initfunc, model_architecture,
302358
for param in solver.parameters():
303359
assert isinstance(param, torch.Tensor)
304360

361+
fitted_solver = copy.deepcopy(solver)
362+
with tempfile.TemporaryDirectory() as temp_dir:
363+
solver.save(temp_dir)
364+
solver.load(temp_dir)
365+
_assert_equal(fitted_solver, solver)
366+
305367

306368
@pytest.mark.parametrize(
307369
"inputs, add_padding, offset, start_batch_idx, end_batch_idx, expected_output",

0 commit comments

Comments
 (0)