Skip to content

Commit 4bfc519

Browse files
committed
Revert "try drop pickle warning"
This reverts commit 80a089d.
1 parent 80a089d commit 4bfc519

File tree

7 files changed

+24
-10
lines changed

7 files changed

+24
-10
lines changed

tests/tests_pytorch/callbacks/test_early_stopping.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,11 +193,13 @@ def test_pickling():
193193
early_stopping = EarlyStopping(monitor="foo")
194194

195195
early_stopping_pickled = pickle.dumps(early_stopping)
196-
early_stopping_loaded = pickle.loads(early_stopping_pickled)
196+
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
197+
early_stopping_loaded = pickle.loads(early_stopping_pickled)
197198
assert vars(early_stopping) == vars(early_stopping_loaded)
198199

199200
early_stopping_pickled = cloudpickle.dumps(early_stopping)
200-
early_stopping_loaded = cloudpickle.loads(early_stopping_pickled)
201+
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
202+
early_stopping_loaded = cloudpickle.loads(early_stopping_pickled)
201203
assert vars(early_stopping) == vars(early_stopping_loaded)
202204

203205

tests/tests_pytorch/checkpointing/test_model_checkpoint.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -352,11 +352,13 @@ def test_pickling(tmp_path):
352352
ckpt = ModelCheckpoint(dirpath=tmp_path)
353353

354354
ckpt_pickled = pickle.dumps(ckpt)
355-
ckpt_loaded = pickle.loads(ckpt_pickled)
355+
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
356+
ckpt_loaded = pickle.loads(ckpt_pickled)
356357
assert vars(ckpt) == vars(ckpt_loaded)
357358

358359
ckpt_pickled = cloudpickle.dumps(ckpt)
359-
ckpt_loaded = cloudpickle.loads(ckpt_pickled)
360+
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
361+
ckpt_loaded = cloudpickle.loads(ckpt_pickled)
360362
assert vars(ckpt) == vars(ckpt_loaded)
361363

362364

tests/tests_pytorch/core/test_metric_result_integration.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,8 @@ def lightning_log(fx, *args, **kwargs):
254254
}
255255

256256
# make sure can be pickled
257-
pickle.loads(pickle.dumps(result))
257+
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
258+
pickle.loads(pickle.dumps(result))
258259
# make sure can be torch.loaded
259260
filepath = str(tmp_path / "result")
260261
torch.save(result, filepath)

tests/tests_pytorch/helpers/test_datasets.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,9 @@ def test_pickling_dataset_mnist(dataset_cls, args):
4444
mnist = dataset_cls(**args)
4545

4646
mnist_pickled = pickle.dumps(mnist)
47-
pickle.loads(mnist_pickled)
47+
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
48+
pickle.loads(mnist_pickled)
4849

4950
mnist_pickled = cloudpickle.dumps(mnist)
50-
cloudpickle.loads(mnist_pickled)
51+
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
52+
cloudpickle.loads(mnist_pickled)

tests/tests_pytorch/loggers/test_all.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,12 @@ def _test_loggers_pickle(tmp_path, monkeypatch, logger_class: Logger):
184184
trainer = Trainer(max_epochs=1, logger=logger)
185185
pkl_bytes = pickle.dumps(trainer)
186186

187-
trainer2 = pickle.loads(pkl_bytes)
187+
with (
188+
pytest.warns(FutureWarning, match="`weights_only=False`")
189+
if _TORCH_EQUAL_2_4_0 or (_TORCH_GREATER_EQUAL_2_4_1 and logger_class not in (CSVLogger, TensorBoardLogger))
190+
else nullcontext()
191+
):
192+
trainer2 = pickle.loads(pkl_bytes)
188193
trainer2.logger.log_metrics({"acc": 1.0})
189194

190195
# make sure we restored properly

tests/tests_pytorch/loggers/test_logger.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,8 @@ def test_multiple_loggers_pickle(tmp_path):
124124

125125
trainer = Trainer(logger=[logger1, logger2])
126126
pkl_bytes = pickle.dumps(trainer)
127-
trainer2 = pickle.loads(pkl_bytes)
127+
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
128+
trainer2 = pickle.loads(pkl_bytes)
128129
for logger in trainer2.loggers:
129130
logger.log_metrics({"acc": 1.0}, 0)
130131

tests/tests_pytorch/loggers/test_wandb.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,8 @@ def name(self):
162162
assert trainer.logger.experiment, "missing experiment"
163163
assert trainer.log_dir == logger.save_dir
164164
pkl_bytes = pickle.dumps(trainer)
165-
trainer2 = pickle.loads(pkl_bytes)
165+
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
166+
trainer2 = pickle.loads(pkl_bytes)
166167

167168
assert os.environ["WANDB_MODE"] == "dryrun"
168169
assert trainer2.logger.__class__.__name__ == WandbLogger.__name__

0 commit comments

Comments
 (0)