|
19 | 19 | # See the License for the specific language governing permissions and |
20 | 20 | # limitations under the License. |
21 | 21 | # |
| 22 | +import copy |
22 | 23 | import itertools |
| 24 | +import tempfile |
23 | 25 |
|
24 | 26 | import numpy as np |
25 | 27 | import pytest |
@@ -91,6 +93,48 @@ def _make_model(dataset, model_architecture="offset10-model"): |
91 | 93 | # ) |
92 | 94 |
|
93 | 95 |
|
| 96 | +def _assert_same_state_dict(first, second): |
| 97 | + assert first.keys() == second.keys() |
| 98 | + for key in first: |
| 99 | + if isinstance(first[key], torch.Tensor): |
| 100 | + assert torch.allclose(first[key], second[key]), key |
| 101 | + elif isinstance(first[key], dict): |
| 102 | + _assert_same_state_dict(first[key], second[key]), key |
| 103 | + else: |
| 104 | + assert first[key] == second[key] |
| 105 | + |
| 106 | + |
| 107 | +def check_if_fit(model): |
| 108 | + """Check if a model was already fit. |
| 109 | +
|
| 110 | + Args: |
| 111 | + model: The model to check. |
| 112 | +
|
| 113 | + Returns: |
| 114 | + True if the model was already fit. |
| 115 | + """ |
| 116 | + return hasattr(model, "n_features_") |
| 117 | + |
| 118 | + |
| 119 | +def _assert_equal(original_solver, loaded_solver): |
| 120 | + for k in original_solver.model.state_dict(): |
| 121 | + assert original_solver.model.state_dict()[k].all( |
| 122 | + ) == loaded_solver.model.state_dict()[k].all() |
| 123 | + assert check_if_fit(loaded_solver) == check_if_fit(original_solver) |
| 124 | + |
| 125 | + if check_if_fit(loaded_solver): |
| 126 | + _assert_same_state_dict(original_solver.state_dict_, |
| 127 | + loaded_solver.state_dict_) |
| 128 | + X = np.random.normal(0, 1, (100, 1)) |
| 129 | + |
| 130 | + if loaded_solver.num_sessions is not None: |
| 131 | + assert np.allclose(loaded_solver.transform(X, session_id=0), |
| 132 | + original_solver.transform(X, session_id=0)) |
| 133 | + else: |
| 134 | + assert np.allclose(loaded_solver.transform(X), |
| 135 | + original_solver.transform(X)) |
| 136 | + |
| 137 | + |
94 | 138 | @pytest.mark.parametrize( |
95 | 139 | "data_name, loader_initfunc, model_architecture, solver_initfunc", |
96 | 140 | single_session_tests) |
@@ -144,6 +188,12 @@ def test_single_session(data_name, loader_initfunc, model_architecture, |
144 | 188 | for param in solver.parameters(): |
145 | 189 | assert isinstance(param, torch.Tensor) |
146 | 190 |
|
| 191 | + fitted_solver = copy.deepcopy(solver) |
| 192 | + with tempfile.TemporaryDirectory() as temp_dir: |
| 193 | + solver.save(temp_dir) |
| 194 | + solver.load(temp_dir) |
| 195 | + _assert_equal(fitted_solver, solver) |
| 196 | + |
147 | 197 |
|
148 | 198 | @pytest.mark.parametrize( |
149 | 199 | "data_name, loader_initfunc, model_architecture, solver_initfunc", |
@@ -225,6 +275,12 @@ def test_single_session_hybrid(data_name, loader_initfunc, model_architecture, |
225 | 275 | for param in solver.parameters(): |
226 | 276 | assert isinstance(param, torch.Tensor) |
227 | 277 |
|
| 278 | + fitted_solver = copy.deepcopy(solver) |
| 279 | + with tempfile.TemporaryDirectory() as temp_dir: |
| 280 | + solver.save(temp_dir) |
| 281 | + solver.load(temp_dir) |
| 282 | + _assert_equal(fitted_solver, solver) |
| 283 | + |
228 | 284 |
|
229 | 285 | @pytest.mark.parametrize( |
230 | 286 | "data_name, loader_initfunc, model_architecture, solver_initfunc", |
@@ -302,6 +358,12 @@ def test_multi_session(data_name, loader_initfunc, model_architecture, |
302 | 358 | for param in solver.parameters(): |
303 | 359 | assert isinstance(param, torch.Tensor) |
304 | 360 |
|
| 361 | + fitted_solver = copy.deepcopy(solver) |
| 362 | + with tempfile.TemporaryDirectory() as temp_dir: |
| 363 | + solver.save(temp_dir) |
| 364 | + solver.load(temp_dir) |
| 365 | + _assert_equal(fitted_solver, solver) |
| 366 | + |
305 | 367 |
|
306 | 368 | @pytest.mark.parametrize( |
307 | 369 | "inputs, add_padding, offset, start_batch_idx, end_batch_idx, expected_output", |
|
0 commit comments