Skip to content

Commit 5eacb6e

Browse files
committed
wip: try safe_globals context manager for tests
1 parent 12bd0d6 commit 5eacb6e

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

tests/tests_pytorch/models/test_hparams.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,9 @@ def _run_standard_hparams_test(tmp_path, model, cls, datamodule=None, try_overwr
109109
# make sure the raw checkpoint saved the properties
110110
raw_checkpoint_path = _raw_checkpoint_path(trainer)
111111
raw_checkpoint = torch.load(raw_checkpoint_path, weights_only=False)
112+
# with torch.serialization.safe_globals([Container, DictConfig]):
113+
# raw_checkpoint = torch.load(raw_checkpoint_path, weights_only=True)
114+
112115
assert cls.CHECKPOINT_HYPER_PARAMS_KEY in raw_checkpoint
113116
assert raw_checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY]["test_arg"] == 14
114117

@@ -175,8 +178,10 @@ def test_omega_conf_hparams(tmp_path, cls):
175178
assert isinstance(obj.hparams, Container)
176179

177180
# run standard test suite
181+
# with torch.serialization.safe_globals([Container, DictConfig]):
178182
raw_checkpoint_path = _run_standard_hparams_test(tmp_path, model, cls, datamodule=datamodule)
179183
obj2 = cls.load_from_checkpoint(raw_checkpoint_path)
184+
180185
assert isinstance(obj2.hparams, Container)
181186

182187
# config specific tests

0 commit comments

Comments
 (0)