@@ -369,13 +369,17 @@ class DictConfSubClassBoringModel: ...
369369 BoringModelWithMixinAndInit ,
370370 ],
371371)
372- def test_collect_init_arguments (tmp_path , cls ):
372+ def test_collect_init_arguments (tmp_path , cls : BoringModel ):
373373 """Test that the model automatically saves the arguments passed into the constructor."""
374374 extra_args = {}
375+ weights_only = True
376+
375377 if cls is AggSubClassBoringModel :
376378 extra_args .update (my_loss = torch .nn .CosineEmbeddingLoss ())
379+ weights_only = False
377380 elif cls is DictConfSubClassBoringModel :
378381 extra_args .update (dict_conf = OmegaConf .create ({"my_param" : "anything" }))
382+ weights_only = False
379383
380384 model = cls (** extra_args )
381385 assert model .hparams .batch_size == 64
@@ -394,12 +398,12 @@ def test_collect_init_arguments(tmp_path, cls):
394398
395399 raw_checkpoint_path = _raw_checkpoint_path (trainer )
396400
397- raw_checkpoint = torch .load (raw_checkpoint_path , weights_only = False )
401+ raw_checkpoint = torch .load (raw_checkpoint_path , weights_only = weights_only )
398402 assert LightningModule .CHECKPOINT_HYPER_PARAMS_KEY in raw_checkpoint
399403 assert raw_checkpoint [LightningModule .CHECKPOINT_HYPER_PARAMS_KEY ]["batch_size" ] == 179
400404
401405 # verify that model loads correctly
402- model = cls .load_from_checkpoint (raw_checkpoint_path )
406+ model = cls .load_from_checkpoint (raw_checkpoint_path , weights_only = weights_only )
403407 assert model .hparams .batch_size == 179
404408
405409 if isinstance (model , AggSubClassBoringModel ):
@@ -410,7 +414,7 @@ def test_collect_init_arguments(tmp_path, cls):
410414 assert model .hparams .dict_conf ["my_param" ] == "anything"
411415
412416 # verify that we can overwrite whatever we want
413- model = cls .load_from_checkpoint (raw_checkpoint_path , batch_size = 99 )
417+ model = cls .load_from_checkpoint (raw_checkpoint_path , batch_size = 99 , weights_only = weights_only )
414418 assert model .hparams .batch_size == 99
415419
416420
0 commit comments