@@ -94,7 +94,7 @@ def __init__(self, hparams, *my_args, **my_kwargs):
9494# -------------------------
9595# STANDARD TESTS
9696# -------------------------
97- def _run_standard_hparams_test (tmp_path , model , cls , datamodule = None , try_overwrite = False ):
97+ def _run_standard_hparams_test (tmp_path , model , cls , datamodule = None , try_overwrite = False , weights_only = True ):
9898 """Tests for the existence of an arg 'test_arg=14'."""
9999 obj = datamodule if issubclass (cls , LightningDataModule ) else model
100100
@@ -108,22 +108,20 @@ def _run_standard_hparams_test(tmp_path, model, cls, datamodule=None, try_overwr
108108
109109 # make sure the raw checkpoint saved the properties
110110 raw_checkpoint_path = _raw_checkpoint_path (trainer )
111- raw_checkpoint = torch .load (raw_checkpoint_path , weights_only = False )
112- # with torch.serialization.safe_globals([Container, DictConfig]):
113- # raw_checkpoint = torch.load(raw_checkpoint_path, weights_only=True)
111+ raw_checkpoint = torch .load (raw_checkpoint_path , weights_only = weights_only )
114112
115113 assert cls .CHECKPOINT_HYPER_PARAMS_KEY in raw_checkpoint
116114 assert raw_checkpoint [cls .CHECKPOINT_HYPER_PARAMS_KEY ]["test_arg" ] == 14
117115
118116 # verify that model loads correctly
119- obj2 = cls .load_from_checkpoint (raw_checkpoint_path )
117+ obj2 = cls .load_from_checkpoint (raw_checkpoint_path , weights_only = weights_only )
120118 assert obj2 .hparams .test_arg == 14
121119
122120 assert isinstance (obj2 .hparams , hparam_type )
123121
124122 if try_overwrite :
125123 # verify that we can overwrite the property
126- obj3 = cls .load_from_checkpoint (raw_checkpoint_path , test_arg = 78 )
124+ obj3 = cls .load_from_checkpoint (raw_checkpoint_path , test_arg = 78 , weights_only = weights_only )
127125 assert obj3 .hparams .test_arg == 78
128126
129127 return raw_checkpoint_path
@@ -178,9 +176,8 @@ def test_omega_conf_hparams(tmp_path, cls):
178176 assert isinstance (obj .hparams , Container )
179177
180178 # run standard test suite
181- # with torch.serialization.safe_globals([Container, DictConfig]):
182- raw_checkpoint_path = _run_standard_hparams_test (tmp_path , model , cls , datamodule = datamodule )
183- obj2 = cls .load_from_checkpoint (raw_checkpoint_path )
179+ raw_checkpoint_path = _run_standard_hparams_test (tmp_path , model , cls , datamodule = datamodule , weights_only = False )
180+ obj2 = cls .load_from_checkpoint (raw_checkpoint_path , weights_only = False )
184181
185182 assert isinstance (obj2 .hparams , Container )
186183
0 commit comments