Skip to content

Commit 0430e22

Browse files
committed
add weights_only arg to _run_standard_hparams_test
1 parent 5eacb6e commit 0430e22

File tree

1 file changed

+6
-9
lines changed

1 file changed

+6
-9
lines changed

tests/tests_pytorch/models/test_hparams.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def __init__(self, hparams, *my_args, **my_kwargs):
9494
# -------------------------
9595
# STANDARD TESTS
9696
# -------------------------
97-
def _run_standard_hparams_test(tmp_path, model, cls, datamodule=None, try_overwrite=False):
97+
def _run_standard_hparams_test(tmp_path, model, cls, datamodule=None, try_overwrite=False, weights_only=True):
9898
"""Tests for the existence of an arg 'test_arg=14'."""
9999
obj = datamodule if issubclass(cls, LightningDataModule) else model
100100

@@ -108,22 +108,20 @@ def _run_standard_hparams_test(tmp_path, model, cls, datamodule=None, try_overwr
108108

109109
# make sure the raw checkpoint saved the properties
110110
raw_checkpoint_path = _raw_checkpoint_path(trainer)
111-
raw_checkpoint = torch.load(raw_checkpoint_path, weights_only=False)
112-
# with torch.serialization.safe_globals([Container, DictConfig]):
113-
# raw_checkpoint = torch.load(raw_checkpoint_path, weights_only=True)
111+
raw_checkpoint = torch.load(raw_checkpoint_path, weights_only=weights_only)
114112

115113
assert cls.CHECKPOINT_HYPER_PARAMS_KEY in raw_checkpoint
116114
assert raw_checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY]["test_arg"] == 14
117115

118116
# verify that model loads correctly
119-
obj2 = cls.load_from_checkpoint(raw_checkpoint_path)
117+
obj2 = cls.load_from_checkpoint(raw_checkpoint_path, weights_only=weights_only)
120118
assert obj2.hparams.test_arg == 14
121119

122120
assert isinstance(obj2.hparams, hparam_type)
123121

124122
if try_overwrite:
125123
# verify that we can overwrite the property
126-
obj3 = cls.load_from_checkpoint(raw_checkpoint_path, test_arg=78)
124+
obj3 = cls.load_from_checkpoint(raw_checkpoint_path, test_arg=78, weights_only=weights_only)
127125
assert obj3.hparams.test_arg == 78
128126

129127
return raw_checkpoint_path
@@ -178,9 +176,8 @@ def test_omega_conf_hparams(tmp_path, cls):
178176
assert isinstance(obj.hparams, Container)
179177

180178
# run standard test suite
181-
# with torch.serialization.safe_globals([Container, DictConfig]):
182-
raw_checkpoint_path = _run_standard_hparams_test(tmp_path, model, cls, datamodule=datamodule)
183-
obj2 = cls.load_from_checkpoint(raw_checkpoint_path)
179+
raw_checkpoint_path = _run_standard_hparams_test(tmp_path, model, cls, datamodule=datamodule, weights_only=False)
180+
obj2 = cls.load_from_checkpoint(raw_checkpoint_path, weights_only=False)
184181

185182
assert isinstance(obj2.hparams, Container)
186183

0 commit comments

Comments
 (0)