
Commit 07d2c4e

Merge branch 'master' into dependabot-pip-requirements-torch-2.9.0
2 parents 53265fb + 8f702b3 commit 07d2c4e

88 files changed (+1736, -4303 lines)


.github/CODEOWNERS

Lines changed: 4 additions & 4 deletions

@@ -5,19 +5,19 @@
 # the repo. Unless a later match takes precedence,
 # @global-owner1 and @global-owner2 will be requested for
 # review when someone opens a pull request.
-* @lantiga @borda @tchaton @justusschock @ethanwharris
+* @lantiga @tchaton @justusschock @ethanwharris

 # Docs
-/.github/*.md @williamfalcon @lantiga @borda
+/.github/*.md @williamfalcon @lantiga
 /docs/source-fabric/index.rst @williamfalcon @lantiga
 /docs/source-pytorch/index.rst @williamfalcon @lantiga
 /docs/source-pytorch/levels @williamfalcon @lantiga

 /.github/CODEOWNERS @williamfalcon
 /SECURITY.md @williamfalcon @lantiga
 /README.md @williamfalcon @lantiga
-/src/pytorch_lightning/__about__.py @williamfalcon @lantiga @borda
-/src/lightning_fabric/__about__.py @williamfalcon @lantiga @borda
+/src/pytorch_lightning/__about__.py @williamfalcon @lantiga
+/src/lightning_fabric/__about__.py @williamfalcon @lantiga

 /src/lightning/fabric/loggers @williamfalcon
 /src/lightning/pytorch/loggers @williamfalcon

.github/CONTRIBUTING.md

Lines changed: 2 additions & 0 deletions

@@ -212,6 +212,8 @@ We welcome any useful contribution! For your convenience here's a recommended wo
   - [Test README](https://github.com/Lightning-AI/pytorch-lightning/blob/master/tests/README.md)
   - [CI/CD README](https://github.com/Lightning-AI/pytorch-lightning/tree/master/.github/workflows#readme)

+1. Once you have a PR opened (and thereby a PR number), please update the respective changelog for [fabric](https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/fabric/CHANGELOG.md) or [pytorch](https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/pytorch/CHANGELOG.md) subpackage depending on where you made your changes.
+
 1. When you feel ready for integrating your work, mark your PR "Ready for review".

    - Your code should be readable and follow the project's design principles.

.github/workflows/docs-build.yml

Lines changed: 4 additions & 1 deletion

@@ -125,7 +125,10 @@ jobs:
         working-directory: ./docs/source-${{ matrix.pkg-name }}
         # allow failing link check and doctest if you run with dispatch
         continue-on-error: ${{ (matrix.target == 'doctest' || matrix.target == 'linkcheck') && github.event_name == 'workflow_dispatch' }}
-        run: make ${{ matrix.target }} --debug --jobs $(nproc) SPHINXOPTS="$BUILD_SPHINX_OPTS"
+        run: |
+          # temp fix: https://github.com/Lightning-AI/pytorch-lightning/actions/runs/19440502586/job/55622388642?pr=21354#step:11:4596
+          uv pip install -U fastapi
+          make ${{ matrix.target }} --debug --jobs $(nproc) SPHINXOPTS="$BUILD_SPHINX_OPTS"

       - name: Keep artifact
         if: github.event_name == 'pull_request'

Makefile

Lines changed: 1 addition & 1 deletion

@@ -10,7 +10,7 @@ export PACKAGE_NAME=pytorch

 # In Lightning Studio, the `lightning` package comes pre-installed.
 # Uninstall it first to ensure the editable install works correctly.
-setup:
+setup: update
 	uv pip uninstall lightning pytorch-lightning lightning-fabric || true
 	uv pip install -r requirements.txt \
 		-r requirements/pytorch/base.txt \

docs/source-pytorch/community/governance.rst

Lines changed: 2 additions & 2 deletions

@@ -19,15 +19,15 @@ Role: All final decisions related to Lightning.
 Maintainers
 -----------
 - Luca Antiga (`lantiga <https://github.com/lantiga>`_)
-- Jirka Borovec (`Borda <https://github.com/Borda>`_)
+- Ethan Harris (`ethanwharris <https://github.com/ethanwharris>`_) (Torchbearer founder)
 - Justus Schock (`justusschock <https://github.com/justusschock>`_)


 Emeritus Maintainers
 --------------------
-- Ethan Harris (`ethanwharris <https://github.com/ethanwharris>`_) (Torchbearer founder)
 - Nicki Skafte (`SkafteNicki <https://github.com/SkafteNicki>`_)
 - Thomas Chaton (`tchaton <https://github.com/tchaton>`_)
+- Jirka Borovec (`Borda <https://github.com/Borda>`_)


 Alumni

docs/source-pytorch/conf.py

Lines changed: 1 addition & 4 deletions

@@ -604,10 +604,7 @@ def package_list_from_file(file):
 from lightning.pytorch.cli import _JSONARGPARSE_SIGNATURES_AVAILABLE as _JSONARGPARSE_AVAILABLE
 from lightning.pytorch.utilities.imports import _TORCHVISION_AVAILABLE
 from lightning.fabric.loggers.tensorboard import _TENSORBOARD_AVAILABLE, _TENSORBOARDX_AVAILABLE
-from lightning.pytorch.loggers.neptune import _NEPTUNE_AVAILABLE
-from lightning.pytorch.loggers.comet import _COMET_AVAILABLE
-from lightning.pytorch.loggers.mlflow import _MLFLOW_AVAILABLE
-from lightning.pytorch.loggers.wandb import _WANDB_AVAILABLE
+from lightning.fabric.utilities.imports import _COMET_AVAILABLE, _MLFLOW_AVAILABLE, _NEPTUNE_AVAILABLE, _WANDB_AVAILABLE
 """
 coverage_skip_undoc_in_source = True
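
The availability flags consolidated in this hunk (like the `RequirementCache` constants added to `xla.py` further down) follow a common pattern: a requirement-check object that is truthy only when the dependency is importable at a satisfying version, and whose string form explains what is missing. A minimal sketch of that pattern, assuming `lightning-utilities` is installed and using `wandb` purely as an illustrative requirement:

    # Sketch only: mirrors how flags such as _WANDB_AVAILABLE are typically defined.
    from lightning_utilities.core.imports import RequirementCache

    _WANDB_AVAILABLE = RequirementCache("wandb")  # truthy only if `wandb` is installed

    if _WANDB_AVAILABLE:
        import wandb  # safe to import; the requirement check passed
    else:
        # str() of the cache describes the missing package or version mismatch.
        print(str(_WANDB_AVAILABLE))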

requirements/fabric/base.txt

Lines changed: 1 addition & 0 deletions

@@ -6,3 +6,4 @@ fsspec[http] >=2022.5.0, <2025.11.0
 packaging >=20.0, <=25.0
 typing-extensions >4.5.0, <4.16.0
 lightning-utilities >=0.10.0, <0.16.0
+pytorch-lightning-enterprise >=2.6.0

requirements/pytorch/base.txt

Lines changed: 1 addition & 0 deletions

@@ -9,3 +9,4 @@ torchmetrics >0.7.0, <1.9.0
 packaging >=20.0, <=25.0
 typing-extensions >4.5.0, <4.16.0
 lightning-utilities >=0.10.0, <0.16.0
+pytorch-lightning-enterprise >=2.6.0

src/lightning/fabric/CHANGELOG.md

Lines changed: 7 additions & 13 deletions

@@ -4,18 +4,7 @@ All notable changes to this project will be documented in this file.

 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

-
-## [unreleased] - YYYY-MM-DD
-
-### Added
-
--
-
-
-### Removed
-
--
-
+## [2.6.0] - 2025-11-21

 ### Changed

@@ -25,7 +14,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Fixed

--
+- Fixed issue in detecting MPIEnvironment with partial mpi4py installation ([#21353](https://github.com/Lightning-AI/pytorch-lightning/pull/21353))
+
+- Learning rate scheduler is stepped at the end of epoch when `on_train_batch_start` returns -1 ([#21296](https://github.com/Lightning-AI/pytorch-lightning/issues/21296)).
+
+
+- Fixed FSDP mixed precision semantics and added user warning ([#21361](https://github.com/Lightning-AI/pytorch-lightning/pull/21361))


 ---

src/lightning/fabric/accelerators/xla.py

Lines changed: 27 additions & 96 deletions

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import functools
+import warnings
 from typing import Any, Union

 import torch
@@ -20,7 +21,11 @@

 from lightning.fabric.accelerators.accelerator import Accelerator
 from lightning.fabric.accelerators.registry import _AcceleratorRegistry
-from lightning.fabric.utilities.device_parser import _check_data_type
+from lightning.fabric.utilities.imports import _raise_enterprise_not_available
+
+_XLA_AVAILABLE = RequirementCache("torch_xla>=1.13", "torch_xla")
+_XLA_GREATER_EQUAL_2_1 = RequirementCache("torch_xla>=2.1")
+_XLA_GREATER_EQUAL_2_5 = RequirementCache("torch_xla>=2.5")


 class XLAAccelerator(Accelerator):
@@ -31,38 +36,38 @@ class XLAAccelerator(Accelerator):
     """

     def __init__(self, *args: Any, **kwargs: Any) -> None:
-        if not _XLA_AVAILABLE:
-            raise ModuleNotFoundError(str(_XLA_AVAILABLE))
-        if not _using_pjrt():
-            raise RuntimeError("The XLA XRT runtime is not supported anymore.")
+        _raise_enterprise_not_available()
         super().__init__(*args, **kwargs)

+        from pytorch_lightning_enterprise.accelerators.xla import XLAAccelerator as EnterpriseXLAAccelerator
+
+        self.accelerator_impl = EnterpriseXLAAccelerator(*args, **kwargs)
+
     @override
     def setup_device(self, device: torch.device) -> None:
-        pass
+        return self.accelerator_impl.setup_device(device)

     @override
     def teardown(self) -> None:
-        pass
+        return self.accelerator_impl.teardown()

     @staticmethod
     @override
     def parse_devices(devices: Union[int, str, list[int]]) -> Union[int, list[int]]:
         """Accelerator device parsing logic."""
-        return _parse_tpu_devices(devices)
+        _raise_enterprise_not_available()
+        from pytorch_lightning_enterprise.accelerators.xla import XLAAccelerator as EnterpriseXLAAccelerator
+
+        return EnterpriseXLAAccelerator.parse_devices(devices)

     @staticmethod
     @override
     def get_parallel_devices(devices: Union[int, list[int]]) -> list[torch.device]:
         """Gets parallel devices for the Accelerator."""
-        devices = _parse_tpu_devices(devices)
-        if isinstance(devices, int):
-            return [torch.device("xla", i) for i in range(devices)]
-        # list of devices is not supported, just a specific index, fine to access [0]
-        return [torch.device("xla", devices[0])]
-        # we cannot create `xla_device` here because processes have not been spawned yet (this is called in the
-        # accelerator connector init). However, there doesn't seem to be a problem with instantiating `torch.device`.
-        # it will be replaced with `xla_device` (also a torch.device`, but with extra logic) in the strategy
+        _raise_enterprise_not_available()
+        from pytorch_lightning_enterprise.accelerators.xla import XLAAccelerator as EnterpriseXLAAccelerator
+
+        return EnterpriseXLAAccelerator.get_parallel_devices(devices)

     @staticmethod
     @override
@@ -71,16 +76,10 @@ def get_parallel_devices(devices: Union[int, list[int]]) -> list[torch.device]:
     @functools.lru_cache(maxsize=1)
     def auto_device_count() -> int:
         """Get the devices when set to auto."""
-        if not _XLA_AVAILABLE:
-            return 0
-        if _XLA_GREATER_EQUAL_2_1:
-            from torch_xla._internal import tpu
-
-            return tpu.num_available_devices()
-        from torch_xla.experimental import tpu
+        _raise_enterprise_not_available()
+        from pytorch_lightning_enterprise.accelerators.xla import XLAAccelerator as EnterpriseXLAAccelerator

-        device_count_on_version = {2: 8, 3: 8, 4: 4}
-        return device_count_on_version.get(tpu.version(), 8)
+        return EnterpriseXLAAccelerator.auto_device_count()

     @staticmethod
     @override
@@ -92,6 +91,9 @@ def is_available() -> bool:
             # XLA may raise these exceptions if it's not properly configured. This needs to be avoided for the cases
             # when `torch_xla` is imported but not used
             return False
+        except ModuleNotFoundError as e:
+            warnings.warn(str(e))
+            return False

     @staticmethod
     @override
@@ -106,74 +108,3 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No
             cls,
             description=cls.__name__,
         )
-
-
-# PJRT support requires this minimum version
-_XLA_AVAILABLE = RequirementCache("torch_xla>=1.13", "torch_xla")
-_XLA_GREATER_EQUAL_2_1 = RequirementCache("torch_xla>=2.1")
-_XLA_GREATER_EQUAL_2_5 = RequirementCache("torch_xla>=2.5")
-
-
-def _using_pjrt() -> bool:
-    # `using_pjrt` is removed in torch_xla 2.5
-    if _XLA_GREATER_EQUAL_2_5:
-        from torch_xla import runtime as xr
-
-        return xr.device_type() is not None
-    # delete me when torch_xla 2.2 is the min supported version, where XRT support has been dropped.
-    if _XLA_GREATER_EQUAL_2_1:
-        from torch_xla import runtime as xr
-
-        return xr.using_pjrt()
-
-    from torch_xla.experimental import pjrt
-
-    return pjrt.using_pjrt()
-
-
-def _parse_tpu_devices(devices: Union[int, str, list[int]]) -> Union[int, list[int]]:
-    """Parses the TPU devices given in the format as accepted by the
-    :class:`~lightning.pytorch.trainer.trainer.Trainer` and :class:`~lightning.fabric.Fabric`.
-
-    Args:
-        devices: An int of 1 or string '1' indicates that 1 core with multi-processing should be used
-            An int 8 or string '8' indicates that all 8 cores with multi-processing should be used
-            A single element list of int or string can be used to indicate the specific TPU core to use.
-
-    Returns:
-        A list of tpu cores to be used.
-
-    """
-    _check_data_type(devices)
-    if isinstance(devices, str):
-        devices = _parse_tpu_devices_str(devices)
-    _check_tpu_devices_valid(devices)
-    return devices
-
-
-def _check_tpu_devices_valid(devices: object) -> None:
-    device_count = XLAAccelerator.auto_device_count()
-    if (
-        # support number of devices
-        isinstance(devices, int)
-        and devices in {1, device_count}
-        # support picking a specific device
-        or isinstance(devices, (list, tuple))
-        and len(devices) == 1
-        and 0 <= devices[0] <= device_count - 1
-    ):
-        return
-    raise ValueError(
-        f"`devices` can only be 'auto', 1, {device_count} or [<0-{device_count - 1}>] for TPUs. Got {devices!r}"
-    )
-
-
-def _parse_tpu_devices_str(devices: str) -> Union[int, list[int]]:
-    devices = devices.strip()
-    try:
-        return int(devices)
-    except ValueError:
-        try:
-            return [int(x.strip()) for x in devices.split(",") if len(x) > 0]
-        except ValueError:
-            raise ValueError(f"Could not parse the selected TPU devices: {devices!r}")
