Skip to content

Commit 2c2ab9e

Browse files
committed
weights_only=False for omegaconf hparams test
1 parent a4c9efe commit 2c2ab9e

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

tests/tests_pytorch/models/test_hparams.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from argparse import Namespace
2020
from dataclasses import dataclass, field
2121
from enum import Enum
22+
from typing import Optional
2223
from unittest import mock
2324

2425
import cloudpickle
@@ -94,7 +95,9 @@ def __init__(self, hparams, *my_args, **my_kwargs):
9495
# -------------------------
9596
# STANDARD TESTS
9697
# -------------------------
97-
def _run_standard_hparams_test(tmp_path, model, cls, datamodule=None, try_overwrite=False):
98+
def _run_standard_hparams_test(
99+
tmp_path, model, cls, datamodule=None, try_overwrite=False, weights_only: Optional[bool] = None
100+
):
98101
"""Tests for the existence of an arg 'test_arg=14'."""
99102
obj = datamodule if issubclass(cls, LightningDataModule) else model
100103

@@ -108,20 +111,20 @@ def _run_standard_hparams_test(tmp_path, model, cls, datamodule=None, try_overwr
108111

109112
# make sure the raw checkpoint saved the properties
110113
raw_checkpoint_path = _raw_checkpoint_path(trainer)
111-
raw_checkpoint = torch.load(raw_checkpoint_path)
114+
raw_checkpoint = torch.load(raw_checkpoint_path, weights_only=weights_only)
112115

113116
assert cls.CHECKPOINT_HYPER_PARAMS_KEY in raw_checkpoint
114117
assert raw_checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY]["test_arg"] == 14
115118

116119
# verify that model loads correctly
117-
obj2 = cls.load_from_checkpoint(raw_checkpoint_path)
120+
obj2 = cls.load_from_checkpoint(raw_checkpoint_path, weights_only=weights_only)
118121
assert obj2.hparams.test_arg == 14
119122

120123
assert isinstance(obj2.hparams, hparam_type)
121124

122125
if try_overwrite:
123126
# verify that we can overwrite the property
124-
obj3 = cls.load_from_checkpoint(raw_checkpoint_path, test_arg=78)
127+
obj3 = cls.load_from_checkpoint(raw_checkpoint_path, test_arg=78, weights_only=weights_only)
125128
assert obj3.hparams.test_arg == 78
126129

127130
return raw_checkpoint_path
@@ -176,7 +179,8 @@ def test_omega_conf_hparams(tmp_path, cls):
176179
assert isinstance(obj.hparams, Container)
177180

178181
# run standard test suite
179-
raw_checkpoint_path = _run_standard_hparams_test(tmp_path, model, cls, datamodule=datamodule)
182+
# weights_only=False as omegaconf.DictConfig is not an allowed global by default
183+
raw_checkpoint_path = _run_standard_hparams_test(tmp_path, model, cls, datamodule=datamodule, weights_only=False)
180184
obj2 = cls.load_from_checkpoint(raw_checkpoint_path, weights_only=False)
181185

182186
assert isinstance(obj2.hparams, Container)

0 commit comments

Comments
 (0)