Skip to content

chore: bump PyTorch version in dependencies & CI #21043

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .azure/gpu-tests-fabric.yml
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,13 @@ jobs:
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.10-torch2.1-cuda12.1.1"
PACKAGE_NAME: "fabric"
"Fabric | latest":
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.7-cuda12.6.3"
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.8-cuda12.6.3"
PACKAGE_NAME: "fabric"
#"Fabric | future":
# image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.7-cuda12.6.3"
# PACKAGE_NAME: "fabric"
"Lightning | latest":
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.7-cuda12.6.3"
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.8-cuda12.6.3"
PACKAGE_NAME: "lightning"
workspace:
clean: all
Expand Down
4 changes: 2 additions & 2 deletions .azure/gpu-tests-pytorch.yml
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,13 @@ jobs:
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.10-torch2.1-cuda12.1.1"
PACKAGE_NAME: "pytorch"
"PyTorch | latest":
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.7-cuda12.6.3"
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.8-cuda12.6.3"
PACKAGE_NAME: "pytorch"
#"PyTorch | future":
# image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.7-cuda12.6.3"
# PACKAGE_NAME: "pytorch"
"Lightning | latest":
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.7-cuda12.6.3"
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.8-cuda12.6.3"
PACKAGE_NAME: "lightning"
pool: lit-rtx-3090
variables:
Expand Down
12 changes: 6 additions & 6 deletions .github/workflows/ci-tests-fabric.yml
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,13 @@ jobs:
- { os: "ubuntu-22.04", pkg-name: "fabric", pytorch-version: "2.1", requires: "oldest" }
- { os: "windows-2022", pkg-name: "fabric", pytorch-version: "2.1", requires: "oldest" }
# "fabric" installs the standalone package
- { os: "macOS-14", pkg-name: "fabric", python-version: "3.10", pytorch-version: "2.5" }
- { os: "ubuntu-22.04", pkg-name: "fabric", python-version: "3.10", pytorch-version: "2.5" }
- { os: "windows-2022", pkg-name: "fabric", python-version: "3.10", pytorch-version: "2.5" }
- { os: "macOS-14", pkg-name: "fabric", python-version: "3.10", pytorch-version: "2.7" }
- { os: "ubuntu-22.04", pkg-name: "fabric", python-version: "3.10", pytorch-version: "2.7" }
- { os: "windows-2022", pkg-name: "fabric", python-version: "3.10", pytorch-version: "2.7" }
# adding recently cut Torch 2.7 - FUTURE
- { os: "macOS-14", pkg-name: "fabric", python-version: "3.12", pytorch-version: "2.7" }
- { os: "ubuntu-22.04", pkg-name: "fabric", python-version: "3.12", pytorch-version: "2.7" }
- { os: "windows-2022", pkg-name: "fabric", python-version: "3.12", pytorch-version: "2.7" }
- { os: "macOS-14", pkg-name: "fabric", python-version: "3.12", pytorch-version: "2.8" }
- { os: "ubuntu-22.04", pkg-name: "fabric", python-version: "3.12", pytorch-version: "2.8" }
- { os: "windows-2022", pkg-name: "fabric", python-version: "3.12", pytorch-version: "2.8" }
timeout-minutes: 25 # because of building grpcio on Mac
env:
PACKAGE_NAME: ${{ matrix.pkg-name }}
Expand Down
12 changes: 6 additions & 6 deletions .github/workflows/ci-tests-pytorch.yml
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,13 @@ jobs:
- { os: "ubuntu-22.04", pkg-name: "pytorch", pytorch-version: "2.1", requires: "oldest" }
- { os: "windows-2022", pkg-name: "pytorch", pytorch-version: "2.1", requires: "oldest" }
# "pytorch" installs the standalone package
- { os: "macOS-14", pkg-name: "pytorch", python-version: "3.10", pytorch-version: "2.5" }
- { os: "ubuntu-22.04", pkg-name: "pytorch", python-version: "3.10", pytorch-version: "2.5" }
- { os: "windows-2022", pkg-name: "pytorch", python-version: "3.10", pytorch-version: "2.5" }
- { os: "macOS-14", pkg-name: "pytorch", python-version: "3.10", pytorch-version: "2.7" }
- { os: "ubuntu-22.04", pkg-name: "pytorch", python-version: "3.10", pytorch-version: "2.7" }
- { os: "windows-2022", pkg-name: "pytorch", python-version: "3.10", pytorch-version: "2.7" }
# adding recently cut Torch 2.7 - FUTURE
- { os: "macOS-14", pkg-name: "pytorch", python-version: "3.12", pytorch-version: "2.7" }
- { os: "ubuntu-22.04", pkg-name: "pytorch", python-version: "3.12", pytorch-version: "2.7" }
- { os: "windows-2022", pkg-name: "pytorch", python-version: "3.12", pytorch-version: "2.7" }
- { os: "macOS-14", pkg-name: "pytorch", python-version: "3.12", pytorch-version: "2.8" }
- { os: "ubuntu-22.04", pkg-name: "pytorch", python-version: "3.12", pytorch-version: "2.8" }
- { os: "windows-2022", pkg-name: "pytorch", python-version: "3.12", pytorch-version: "2.8" }
timeout-minutes: 50
env:
PACKAGE_NAME: ${{ matrix.pkg-name }}
Expand Down
2 changes: 1 addition & 1 deletion requirements/fabric/examples.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

torchvision >=0.16.0, <0.23.0
torchvision >=0.16.0, <0.24.0
torchmetrics >=0.10.0, <1.9.0
2 changes: 1 addition & 1 deletion requirements/pytorch/base.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

torch >=2.1.0, <=2.8.0
torch >=2.1.0, <2.9.0
tqdm >=4.57.0, <4.68.0
PyYAML >5.4, <6.1.0
fsspec[http] >=2022.5.0, <2025.8.0
Expand Down
2 changes: 1 addition & 1 deletion requirements/pytorch/examples.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

requests <2.33.0
torchvision >=0.16.0, <0.23.0
torchvision >=0.16.0, <0.24.0
ipython[all] <8.19.0
torchmetrics >=0.10.0, <1.9.0
4 changes: 4 additions & 0 deletions src/lightning/fabric/utilities/testing/_runif.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def _runif_reasons(
standalone: bool = False,
deepspeed: bool = False,
dynamo: bool = False,
linux_only: bool = False,
) -> tuple[list[str], dict[str, bool]]:
"""Construct reasons for pytest skipif.

Expand Down Expand Up @@ -123,4 +124,7 @@ def _runif_reasons(
if not is_dynamo_supported():
reasons.append("torch.dynamo")

if linux_only and sys.platform != "linux":
reasons.append("only linux")

return reasons, kwargs
6 changes: 4 additions & 2 deletions src/lightning/pytorch/utilities/testing/_runif.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from lightning_utilities.core.imports import RequirementCache

from lightning.fabric.utilities.testing import _runif_reasons as fabric_run_if
from lightning.fabric.utilities.testing import _runif_reasons as _fabric_run_if
from lightning.pytorch.accelerators.cpu import _PSUTIL_AVAILABLE
from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE
from lightning.pytorch.core.module import _ONNX_AVAILABLE
Expand All @@ -42,6 +42,7 @@ def _runif_reasons(
psutil: bool = False,
sklearn: bool = False,
onnx: bool = False,
linux_only: bool = False,
) -> tuple[list[str], dict[str, bool]]:
"""Construct reasons for pytest skipif.

Expand All @@ -67,7 +68,7 @@ def _runif_reasons(

"""

reasons, kwargs = fabric_run_if(
reasons, kwargs = _fabric_run_if(
min_cuda_gpus=min_cuda_gpus,
min_torch=min_torch,
max_torch=max_torch,
Expand All @@ -79,6 +80,7 @@ def _runif_reasons(
standalone=standalone,
deepspeed=deepspeed,
dynamo=dynamo,
linux_only=linux_only,
)

if rich and not _RICH_AVAILABLE:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ def forward(self, x):
@pytest.mark.parametrize(
("accelerator", "precision", "expected_dtype"),
[
("cpu", "16-mixed", torch.bfloat16),
("cpu", "bf16-mixed", torch.bfloat16),
pytest.param("cpu", "16-mixed", torch.bfloat16, marks=RunIf(skip_windows=True)),
pytest.param("cpu", "bf16-mixed", torch.bfloat16, marks=RunIf(skip_windows=True)),
pytest.param("cuda", "16-mixed", torch.float16, marks=RunIf(min_cuda_gpus=2)),
pytest.param("cuda", "bf16-mixed", torch.bfloat16, marks=RunIf(min_cuda_gpus=2, bf16_cuda=True)),
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ def __init__(self):
self.register_buffer("buffer", torch.ones(3))


@pytest.mark.parametrize("strategy", ["ddp_spawn", pytest.param("ddp_fork", marks=RunIf(skip_windows=True))])
@RunIf(skip_windows=True)
@pytest.mark.parametrize("strategy", ["ddp_spawn", "ddp_fork"])
def test_memory_sharing_disabled(strategy):
"""Test that the multiprocessing launcher disables memory sharing on model parameters and buffers to avoid race
conditions on model updates."""
Expand Down
2 changes: 1 addition & 1 deletion tests/tests_fabric/strategies/test_ddp_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
@pytest.mark.parametrize(
"accelerator",
[
"cpu",
pytest.param("cpu", marks=RunIf(skip_windows=True)),
pytest.param("cuda", marks=RunIf(min_cuda_gpus=2)),
],
)
Expand Down
3 changes: 2 additions & 1 deletion tests/tests_fabric/utilities/test_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,10 @@ def test_collective_operations(devices, process):


@pytest.mark.skipif(
RequirementCache("torch<2.4") and RequirementCache("numpy>=2.0"),
RequirementCache("numpy>=2.0"),
reason="torch.distributed not compatible with numpy>=2.0",
)
@RunIf(min_torch="2.4", skip_windows=True)
@pytest.mark.flaky(reruns=3) # flaky with "process 0 terminated with signal SIGABRT" (GLOO)
def test_is_shared_filesystem(tmp_path, monkeypatch):
# In the non-distributed case, every location is interpreted as 'shared'
Expand Down
164 changes: 19 additions & 145 deletions tests/tests_fabric/utilities/test_spike.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import contextlib
import sys

import pytest
import torch

from lightning.fabric import Fabric
from lightning.fabric.utilities.spike import _TORCHMETRICS_GREATER_EQUAL_1_0_0, SpikeDetection, TrainingSpikeException
from tests_fabric.helpers.runif import RunIf


def spike_detection_test(fabric, global_rank_spike, spike_value, should_raise):
Expand All @@ -32,6 +32,8 @@ def spike_detection_test(fabric, global_rank_spike, spike_value, should_raise):
@pytest.mark.flaky(max_runs=3)
@pytest.mark.parametrize(
("global_rank_spike", "num_devices", "spike_value", "finite_only"),
# NOTE FOR ALL FOLLOWING TESTS:
# adding run on linux only because multiprocessing on other platforms takes forever
[
pytest.param(0, 1, None, True),
pytest.param(0, 1, None, False),
Expand All @@ -41,150 +43,22 @@ def spike_detection_test(fabric, global_rank_spike, spike_value, should_raise):
pytest.param(0, 1, float("-inf"), False),
pytest.param(0, 1, float("NaN"), True),
pytest.param(0, 1, float("NaN"), False),
pytest.param(
0,
2,
None,
True,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
pytest.param(
0,
2,
None,
False,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
pytest.param(
1,
2,
None,
True,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
pytest.param(
1,
2,
None,
False,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
pytest.param(
0,
2,
float("inf"),
True,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
pytest.param(
0,
2,
float("inf"),
False,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
pytest.param(
1,
2,
float("inf"),
True,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
pytest.param(
1,
2,
float("inf"),
False,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
pytest.param(
0,
2,
float("-inf"),
True,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
pytest.param(
0,
2,
float("-inf"),
False,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
pytest.param(
1,
2,
float("-inf"),
True,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
pytest.param(
1,
2,
float("-inf"),
False,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
pytest.param(
0,
2,
float("NaN"),
True,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
pytest.param(
0,
2,
float("NaN"),
False,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
pytest.param(
1,
2,
float("NaN"),
True,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
pytest.param(
1,
2,
float("NaN"),
False,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
pytest.param(0, 2, None, True, marks=RunIf(linux_only=True)),
pytest.param(0, 2, None, False, marks=RunIf(linux_only=True)),
pytest.param(1, 2, None, True, marks=RunIf(linux_only=True)),
pytest.param(1, 2, None, False, marks=RunIf(linux_only=True)),
pytest.param(0, 2, float("inf"), True, marks=RunIf(linux_only=True)),
pytest.param(0, 2, float("inf"), False, marks=RunIf(linux_only=True)),
pytest.param(1, 2, float("inf"), True, marks=RunIf(linux_only=True)),
pytest.param(1, 2, float("inf"), False, marks=RunIf(linux_only=True)),
pytest.param(0, 2, float("-inf"), True, marks=RunIf(linux_only=True)),
pytest.param(0, 2, float("-inf"), False, marks=RunIf(linux_only=True)),
pytest.param(1, 2, float("-inf"), True, marks=RunIf(linux_only=True)),
pytest.param(1, 2, float("-inf"), False, marks=RunIf(linux_only=True)),
pytest.param(0, 2, float("NaN"), True, marks=RunIf(linux_only=True)),
pytest.param(0, 2, float("NaN"), False, marks=RunIf(linux_only=True)),
pytest.param(1, 2, float("NaN"), True, marks=RunIf(linux_only=True)),
pytest.param(1, 2, float("NaN"), False, marks=RunIf(linux_only=True)),
],
)
@pytest.mark.skipif(not _TORCHMETRICS_GREATER_EQUAL_1_0_0, reason="requires torchmetrics>=1.0.0")
Expand Down
Loading
Loading