19
19
from argparse import Namespace
20
20
from dataclasses import dataclass , field
21
21
from enum import Enum
22
+ from typing import Optional
22
23
from unittest import mock
23
24
24
25
import cloudpickle
@@ -94,7 +95,9 @@ def __init__(self, hparams, *my_args, **my_kwargs):
94
95
# -------------------------
95
96
# STANDARD TESTS
96
97
# -------------------------
97
- def _run_standard_hparams_test (tmp_path , model , cls , datamodule = None , try_overwrite = False ):
98
+ def _run_standard_hparams_test (
99
+ tmp_path , model , cls , datamodule = None , try_overwrite = False , weights_only : Optional [bool ] = None
100
+ ):
98
101
"""Tests for the existence of an arg 'test_arg=14'."""
99
102
obj = datamodule if issubclass (cls , LightningDataModule ) else model
100
103
@@ -108,20 +111,20 @@ def _run_standard_hparams_test(tmp_path, model, cls, datamodule=None, try_overwr
108
111
109
112
# make sure the raw checkpoint saved the properties
110
113
raw_checkpoint_path = _raw_checkpoint_path (trainer )
111
- raw_checkpoint = torch .load (raw_checkpoint_path )
114
+ raw_checkpoint = torch .load (raw_checkpoint_path , weights_only = weights_only )
112
115
113
116
assert cls .CHECKPOINT_HYPER_PARAMS_KEY in raw_checkpoint
114
117
assert raw_checkpoint [cls .CHECKPOINT_HYPER_PARAMS_KEY ]["test_arg" ] == 14
115
118
116
119
# verify that model loads correctly
117
- obj2 = cls .load_from_checkpoint (raw_checkpoint_path )
120
+ obj2 = cls .load_from_checkpoint (raw_checkpoint_path , weights_only = weights_only )
118
121
assert obj2 .hparams .test_arg == 14
119
122
120
123
assert isinstance (obj2 .hparams , hparam_type )
121
124
122
125
if try_overwrite :
123
126
# verify that we can overwrite the property
124
- obj3 = cls .load_from_checkpoint (raw_checkpoint_path , test_arg = 78 )
127
+ obj3 = cls .load_from_checkpoint (raw_checkpoint_path , test_arg = 78 , weights_only = weights_only )
125
128
assert obj3 .hparams .test_arg == 78
126
129
127
130
return raw_checkpoint_path
@@ -176,7 +179,8 @@ def test_omega_conf_hparams(tmp_path, cls):
176
179
assert isinstance (obj .hparams , Container )
177
180
178
181
# run standard test suite
179
- raw_checkpoint_path = _run_standard_hparams_test (tmp_path , model , cls , datamodule = datamodule )
182
+ # weights_only=False as omegaconf.DictConfig is not an allowed global by default
183
+ raw_checkpoint_path = _run_standard_hparams_test (tmp_path , model , cls , datamodule = datamodule , weights_only = False )
180
184
obj2 = cls .load_from_checkpoint (raw_checkpoint_path , weights_only = False )
181
185
182
186
assert isinstance (obj2 .hparams , Container )
0 commit comments