Commit 2abe915

weights_only=False when adding extra_args
1 parent 0430e22 commit 2abe915

File tree

1 file changed: +8 −4 lines


tests/tests_pytorch/models/test_hparams.py

Lines changed: 8 additions & 4 deletions
@@ -369,13 +369,17 @@ class DictConfSubClassBoringModel: ...
         BoringModelWithMixinAndInit,
     ],
 )
-def test_collect_init_arguments(tmp_path, cls):
+def test_collect_init_arguments(tmp_path, cls: BoringModel):
     """Test that the model automatically saves the arguments passed into the constructor."""
     extra_args = {}
+    weights_only = True
+
     if cls is AggSubClassBoringModel:
         extra_args.update(my_loss=torch.nn.CosineEmbeddingLoss())
+        weights_only = False
     elif cls is DictConfSubClassBoringModel:
         extra_args.update(dict_conf=OmegaConf.create({"my_param": "anything"}))
+        weights_only = False
 
     model = cls(**extra_args)
     assert model.hparams.batch_size == 64
@@ -394,12 +398,12 @@ def test_collect_init_arguments(tmp_path, cls):
 
     raw_checkpoint_path = _raw_checkpoint_path(trainer)
 
-    raw_checkpoint = torch.load(raw_checkpoint_path, weights_only=False)
+    raw_checkpoint = torch.load(raw_checkpoint_path, weights_only=weights_only)
     assert LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in raw_checkpoint
     assert raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]["batch_size"] == 179
 
     # verify that model loads correctly
-    model = cls.load_from_checkpoint(raw_checkpoint_path)
+    model = cls.load_from_checkpoint(raw_checkpoint_path, weights_only=weights_only)
     assert model.hparams.batch_size == 179
 
     if isinstance(model, AggSubClassBoringModel):
@@ -410,7 +414,7 @@ def test_collect_init_arguments(tmp_path, cls):
     assert model.hparams.dict_conf["my_param"] == "anything"
 
     # verify that we can overwrite whatever we want
-    model = cls.load_from_checkpoint(raw_checkpoint_path, batch_size=99)
+    model = cls.load_from_checkpoint(raw_checkpoint_path, batch_size=99, weights_only=weights_only)
     assert model.hparams.batch_size == 99
 