1919from argparse import Namespace
2020from dataclasses import dataclass , field
2121from enum import Enum
22+ from typing import Optional
2223from unittest import mock
2324
2425import cloudpickle
@@ -94,7 +95,9 @@ def __init__(self, hparams, *my_args, **my_kwargs):
9495# -------------------------
9596# STANDARD TESTS
9697# -------------------------
97- def _run_standard_hparams_test (tmp_path , model , cls , datamodule = None , try_overwrite = False ):
98+ def _run_standard_hparams_test (
99+ tmp_path , model , cls , datamodule = None , try_overwrite = False , weights_only : Optional [bool ] = None
100+ ):
98101 """Tests for the existence of an arg 'test_arg=14'."""
99102 obj = datamodule if issubclass (cls , LightningDataModule ) else model
100103
@@ -108,20 +111,20 @@ def _run_standard_hparams_test(tmp_path, model, cls, datamodule=None, try_overwr
108111
109112 # make sure the raw checkpoint saved the properties
110113 raw_checkpoint_path = _raw_checkpoint_path (trainer )
111- raw_checkpoint = torch .load (raw_checkpoint_path )
114+ raw_checkpoint = torch .load (raw_checkpoint_path , weights_only = weights_only )
112115
113116 assert cls .CHECKPOINT_HYPER_PARAMS_KEY in raw_checkpoint
114117 assert raw_checkpoint [cls .CHECKPOINT_HYPER_PARAMS_KEY ]["test_arg" ] == 14
115118
116119 # verify that model loads correctly
117- obj2 = cls .load_from_checkpoint (raw_checkpoint_path )
120+ obj2 = cls .load_from_checkpoint (raw_checkpoint_path , weights_only = weights_only )
118121 assert obj2 .hparams .test_arg == 14
119122
120123 assert isinstance (obj2 .hparams , hparam_type )
121124
122125 if try_overwrite :
123126 # verify that we can overwrite the property
124- obj3 = cls .load_from_checkpoint (raw_checkpoint_path , test_arg = 78 )
127+ obj3 = cls .load_from_checkpoint (raw_checkpoint_path , test_arg = 78 , weights_only = weights_only )
125128 assert obj3 .hparams .test_arg == 78
126129
127130 return raw_checkpoint_path
@@ -176,7 +179,8 @@ def test_omega_conf_hparams(tmp_path, cls):
176179 assert isinstance (obj .hparams , Container )
177180
178181 # run standard test suite
179- raw_checkpoint_path = _run_standard_hparams_test (tmp_path , model , cls , datamodule = datamodule )
182+ # weights_only=False as omegaconf.DictConfig is not an allowed global by default
183+ raw_checkpoint_path = _run_standard_hparams_test (tmp_path , model , cls , datamodule = datamodule , weights_only = False )
180184 obj2 = cls .load_from_checkpoint (raw_checkpoint_path , weights_only = False )
181185
182186 assert isinstance (obj2 .hparams , Container )
0 commit comments