Skip to content

Commit 2d801f2

Browse files
Bordapre-commit-ci[bot]
authored andcommitted
fix(tests): update tests after torch 2.4.1 (#20302)
* update * test_loggers_pickle_all * more... * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> (cherry picked from commit d1ca3c6)
1 parent ba4a4ef commit 2d801f2

File tree

10 files changed

+25
-20
lines changed

10 files changed

+25
-20
lines changed

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,6 @@ ignore = [
7676
"S108",
7777
"E203", # conflicts with black
7878
]
79-
ignore-init-module-imports = true
8079

8180
[tool.ruff.lint.per-file-ignores]
8281
".actions/*" = ["S101", "S310"]

requirements/typing.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
mypy==1.11.0
2-
torch==2.4.0
2+
torch==2.4.1
33

44
types-Markdown
55
types-PyYAML

src/lightning/fabric/utilities/imports.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,9 @@
3131

3232
_TORCH_GREATER_EQUAL_2_2 = compare_version("torch", operator.ge, "2.2.0")
3333
_TORCH_GREATER_EQUAL_2_3 = compare_version("torch", operator.ge, "2.3.0")
34+
_TORCH_EQUAL_2_4_0 = compare_version("torch", operator.eq, "2.4.0")
3435
_TORCH_GREATER_EQUAL_2_4 = compare_version("torch", operator.ge, "2.4.0")
36+
_TORCH_GREATER_EQUAL_2_4_1 = compare_version("torch", operator.ge, "2.4.1")
3537

3638
_PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)
3739

tests/tests_pytorch/callbacks/test_early_stopping.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import cloudpickle
2424
import pytest
2525
import torch
26-
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
26+
from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0
2727
from lightning.pytorch import Trainer, seed_everything
2828
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
2929
from lightning.pytorch.demos.boring_classes import BoringModel
@@ -193,12 +193,12 @@ def test_pickling():
193193
early_stopping = EarlyStopping(monitor="foo")
194194

195195
early_stopping_pickled = pickle.dumps(early_stopping)
196-
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext():
196+
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
197197
early_stopping_loaded = pickle.loads(early_stopping_pickled)
198198
assert vars(early_stopping) == vars(early_stopping_loaded)
199199

200200
early_stopping_pickled = cloudpickle.dumps(early_stopping)
201-
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext():
201+
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
202202
early_stopping_loaded = cloudpickle.loads(early_stopping_pickled)
203203
assert vars(early_stopping) == vars(early_stopping_loaded)
204204

tests/tests_pytorch/checkpointing/test_model_checkpoint.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
import yaml
3333
from jsonargparse import ArgumentParser
3434
from lightning.fabric.utilities.cloud_io import _load as pl_load
35-
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
35+
from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0
3636
from lightning.pytorch import Trainer, seed_everything
3737
from lightning.pytorch.callbacks import ModelCheckpoint
3838
from lightning.pytorch.demos.boring_classes import BoringModel
@@ -352,12 +352,12 @@ def test_pickling(tmp_path):
352352
ckpt = ModelCheckpoint(dirpath=tmp_path)
353353

354354
ckpt_pickled = pickle.dumps(ckpt)
355-
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext():
355+
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
356356
ckpt_loaded = pickle.loads(ckpt_pickled)
357357
assert vars(ckpt) == vars(ckpt_loaded)
358358

359359
ckpt_pickled = cloudpickle.dumps(ckpt)
360-
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext():
360+
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
361361
ckpt_loaded = cloudpickle.loads(ckpt_pickled)
362362
assert vars(ckpt) == vars(ckpt_loaded)
363363

tests/tests_pytorch/core/test_metric_result_integration.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import lightning.pytorch as pl
2020
import pytest
2121
import torch
22-
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
22+
from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0
2323
from lightning.fabric.utilities.warnings import PossibleUserWarning
2424
from lightning.pytorch import Trainer
2525
from lightning.pytorch.callbacks import OnExceptionCheckpoint
@@ -254,7 +254,7 @@ def lightning_log(fx, *args, **kwargs):
254254
}
255255

256256
# make sure can be pickled
257-
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext():
257+
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
258258
pickle.loads(pickle.dumps(result))
259259
# make sure can be torch.loaded
260260
filepath = str(tmp_path / "result")

tests/tests_pytorch/helpers/test_datasets.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import cloudpickle
1818
import pytest
1919
import torch
20-
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
20+
from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0
2121

2222
from tests_pytorch import _PATH_DATASETS
2323
from tests_pytorch.helpers.datasets import MNIST, AverageDataset, TrialMNIST
@@ -44,9 +44,9 @@ def test_pickling_dataset_mnist(dataset_cls, args):
4444
mnist = dataset_cls(**args)
4545

4646
mnist_pickled = pickle.dumps(mnist)
47-
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext():
47+
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
4848
pickle.loads(mnist_pickled)
4949

5050
mnist_pickled = cloudpickle.dumps(mnist)
51-
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext():
51+
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
5252
cloudpickle.loads(mnist_pickled)

tests/tests_pytorch/loggers/test_all.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
import pytest
2222
import torch
23-
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
23+
from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0, _TORCH_GREATER_EQUAL_2_4_1
2424
from lightning.pytorch import Callback, Trainer
2525
from lightning.pytorch.demos.boring_classes import BoringModel
2626
from lightning.pytorch.loggers import (
@@ -163,7 +163,7 @@ def test_loggers_pickle_all(tmp_path, monkeypatch, logger_class):
163163
pytest.xfail(f"pickle test requires {logger_class.__class__} dependencies to be installed.")
164164

165165

166-
def _test_loggers_pickle(tmp_path, monkeypatch, logger_class):
166+
def _test_loggers_pickle(tmp_path, monkeypatch, logger_class: Logger):
167167
"""Verify that pickling trainer with logger works."""
168168
_patch_comet_atexit(monkeypatch)
169169

@@ -184,7 +184,11 @@ def _test_loggers_pickle(tmp_path, monkeypatch, logger_class):
184184
trainer = Trainer(max_epochs=1, logger=logger)
185185
pkl_bytes = pickle.dumps(trainer)
186186

187-
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext():
187+
with (
188+
pytest.warns(FutureWarning, match="`weights_only=False`")
189+
if _TORCH_EQUAL_2_4_0 or (_TORCH_GREATER_EQUAL_2_4_1 and logger_class not in (CSVLogger, TensorBoardLogger))
190+
else nullcontext()
191+
):
188192
trainer2 = pickle.loads(pkl_bytes)
189193
trainer2.logger.log_metrics({"acc": 1.0})
190194

tests/tests_pytorch/loggers/test_logger.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import numpy as np
2222
import pytest
2323
import torch
24-
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
24+
from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0
2525
from lightning.fabric.utilities.logger import _convert_params, _sanitize_params
2626
from lightning.pytorch import Trainer
2727
from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel
@@ -124,7 +124,7 @@ def test_multiple_loggers_pickle(tmp_path):
124124

125125
trainer = Trainer(logger=[logger1, logger2])
126126
pkl_bytes = pickle.dumps(trainer)
127-
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext():
127+
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
128128
trainer2 = pickle.loads(pkl_bytes)
129129
for logger in trainer2.loggers:
130130
logger.log_metrics({"acc": 1.0}, 0)

tests/tests_pytorch/loggers/test_wandb.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
import pytest
2121
import yaml
22-
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
22+
from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0
2323
from lightning.pytorch import Trainer
2424
from lightning.pytorch.callbacks import ModelCheckpoint
2525
from lightning.pytorch.cli import LightningCLI
@@ -162,7 +162,7 @@ def name(self):
162162
assert trainer.logger.experiment, "missing experiment"
163163
assert trainer.log_dir == logger.save_dir
164164
pkl_bytes = pickle.dumps(trainer)
165-
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext():
165+
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
166166
trainer2 = pickle.loads(pkl_bytes)
167167

168168
assert os.environ["WANDB_MODE"] == "dryrun"

0 commit comments

Comments
 (0)