Commit 90af5cb

Merge branch 'master' into fix/21357-modelparallel-checkpoint
2 parents 646e01b + 25c9922 commit 90af5cb

65 files changed (+4184 / -1096 lines)

.github/workflows/docs-build.yml

Lines changed: 1 addition & 4 deletions
@@ -125,10 +125,7 @@ jobs:
         working-directory: ./docs/source-${{ matrix.pkg-name }}
         # allow failing link check and doctest if you run with dispatch
         continue-on-error: ${{ (matrix.target == 'doctest' || matrix.target == 'linkcheck') && github.event_name == 'workflow_dispatch' }}
-        run: |
-          # temp fix: https://github.com/Lightning-AI/pytorch-lightning/actions/runs/19440502586/job/55622388642?pr=21354#step:11:4596
-          uv pip install -U fastapi
-          make ${{ matrix.target }} --debug --jobs $(nproc) SPHINXOPTS="$BUILD_SPHINX_OPTS"
+        run: make ${{ matrix.target }} --debug --jobs $(nproc) SPHINXOPTS="$BUILD_SPHINX_OPTS"

       - name: Keep artifact
         if: github.event_name == 'pull_request'

docs/source-pytorch/conf.py

Lines changed: 4 additions & 1 deletion
@@ -604,7 +604,10 @@ def package_list_from_file(file):
 from lightning.pytorch.cli import _JSONARGPARSE_SIGNATURES_AVAILABLE as _JSONARGPARSE_AVAILABLE
 from lightning.pytorch.utilities.imports import _TORCHVISION_AVAILABLE
 from lightning.fabric.loggers.tensorboard import _TENSORBOARD_AVAILABLE, _TENSORBOARDX_AVAILABLE
-from lightning.fabric.utilities.imports import _COMET_AVAILABLE, _MLFLOW_AVAILABLE, _NEPTUNE_AVAILABLE, _WANDB_AVAILABLE
+from lightning.pytorch.loggers.neptune import _NEPTUNE_AVAILABLE
+from lightning.pytorch.loggers.comet import _COMET_AVAILABLE
+from lightning.pytorch.loggers.mlflow import _MLFLOW_AVAILABLE
+from lightning.pytorch.loggers.wandb import _WANDB_AVAILABLE
 """
 coverage_skip_undoc_in_source = True

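For context, the names imported above are module-level availability flags. A minimal sketch (not part of this commit) of how such a flag is typically consumed, assuming `lightning` is installed; `WandbLogger` and the project name are used purely for illustration:

# Illustrative sketch only; the guard pattern is the point, not the logger choice.
from lightning.pytorch.loggers.wandb import _WANDB_AVAILABLE

if _WANDB_AVAILABLE:  # truthy only when the optional `wandb` package is importable
    from lightning.pytorch.loggers import WandbLogger

    logger = WandbLogger(project="docs-example")  # hypothetical project name
else:
    logger = None  # skip the optional logger when the dependency is missing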
requirements/fabric/base.txt

Lines changed: 0 additions & 1 deletion
@@ -6,4 +6,3 @@ fsspec[http] >=2022.5.0, <2025.11.0
 packaging >=20.0, <=25.0
 typing-extensions >4.5.0, <4.16.0
 lightning-utilities >=0.10.0, <0.16.0
-pytorch-lightning-enterprise >=2.6.0

requirements/pytorch/base.txt

Lines changed: 0 additions & 1 deletion
@@ -9,4 +9,3 @@ torchmetrics >0.7.0, <1.9.0
 packaging >=20.0, <=25.0
 typing-extensions >4.5.0, <4.16.0
 lightning-utilities >=0.10.0, <0.16.0
-pytorch-lightning-enterprise >=2.6.0

src/lightning/fabric/accelerators/xla.py

Lines changed: 96 additions & 27 deletions
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import functools
-import warnings
 from typing import Any, Union

 import torch
@@ -21,11 +20,7 @@

 from lightning.fabric.accelerators.accelerator import Accelerator
 from lightning.fabric.accelerators.registry import _AcceleratorRegistry
-from lightning.fabric.utilities.imports import _raise_enterprise_not_available
-
-_XLA_AVAILABLE = RequirementCache("torch_xla>=1.13", "torch_xla")
-_XLA_GREATER_EQUAL_2_1 = RequirementCache("torch_xla>=2.1")
-_XLA_GREATER_EQUAL_2_5 = RequirementCache("torch_xla>=2.5")
+from lightning.fabric.utilities.device_parser import _check_data_type


 class XLAAccelerator(Accelerator):
@@ -36,38 +31,38 @@ class XLAAccelerator(Accelerator):
     """

     def __init__(self, *args: Any, **kwargs: Any) -> None:
-        _raise_enterprise_not_available()
+        if not _XLA_AVAILABLE:
+            raise ModuleNotFoundError(str(_XLA_AVAILABLE))
+        if not _using_pjrt():
+            raise RuntimeError("The XLA XRT runtime is not supported anymore.")
         super().__init__(*args, **kwargs)

-        from pytorch_lightning_enterprise.accelerators.xla import XLAAccelerator as EnterpriseXLAAccelerator
-
-        self.accelerator_impl = EnterpriseXLAAccelerator(*args, **kwargs)
-
     @override
     def setup_device(self, device: torch.device) -> None:
-        return self.accelerator_impl.setup_device(device)
+        pass

     @override
     def teardown(self) -> None:
-        return self.accelerator_impl.teardown()
+        pass

     @staticmethod
     @override
     def parse_devices(devices: Union[int, str, list[int]]) -> Union[int, list[int]]:
         """Accelerator device parsing logic."""
-        _raise_enterprise_not_available()
-        from pytorch_lightning_enterprise.accelerators.xla import XLAAccelerator as EnterpriseXLAAccelerator
-
-        return EnterpriseXLAAccelerator.parse_devices(devices)
+        return _parse_tpu_devices(devices)

     @staticmethod
     @override
     def get_parallel_devices(devices: Union[int, list[int]]) -> list[torch.device]:
         """Gets parallel devices for the Accelerator."""
-        _raise_enterprise_not_available()
-        from pytorch_lightning_enterprise.accelerators.xla import XLAAccelerator as EnterpriseXLAAccelerator
-
-        return EnterpriseXLAAccelerator.get_parallel_devices(devices)
+        devices = _parse_tpu_devices(devices)
+        if isinstance(devices, int):
+            return [torch.device("xla", i) for i in range(devices)]
+        # list of devices is not supported, just a specific index, fine to access [0]
+        return [torch.device("xla", devices[0])]
+        # we cannot create `xla_device` here because processes have not been spawned yet (this is called in the
+        # accelerator connector init). However, there doesn't seem to be a problem with instantiating `torch.device`.
+        # it will be replaced with `xla_device` (also a torch.device`, but with extra logic) in the strategy

     @staticmethod
     @override
@@ -76,10 +71,16 @@ def get_parallel_devices(devices: Union[int, list[int]]) -> list[torch.device]:
     @functools.lru_cache(maxsize=1)
     def auto_device_count() -> int:
         """Get the devices when set to auto."""
-        _raise_enterprise_not_available()
-        from pytorch_lightning_enterprise.accelerators.xla import XLAAccelerator as EnterpriseXLAAccelerator
+        if not _XLA_AVAILABLE:
+            return 0
+        if _XLA_GREATER_EQUAL_2_1:
+            from torch_xla._internal import tpu
+
+            return tpu.num_available_devices()
+        from torch_xla.experimental import tpu

-        return EnterpriseXLAAccelerator.auto_device_count()
+        device_count_on_version = {2: 8, 3: 8, 4: 4}
+        return device_count_on_version.get(tpu.version(), 8)

     @staticmethod
     @override
@@ -91,9 +92,6 @@ def is_available() -> bool:
             # XLA may raise these exceptions if it's not properly configured. This needs to be avoided for the cases
             # when `torch_xla` is imported but not used
             return False
-        except ModuleNotFoundError as e:
-            warnings.warn(str(e))
-            return False

     @staticmethod
     @override
@@ -108,3 +106,74 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No
             cls,
             description=cls.__name__,
         )
+
+
+# PJRT support requires this minimum version
+_XLA_AVAILABLE = RequirementCache("torch_xla>=1.13", "torch_xla")
+_XLA_GREATER_EQUAL_2_1 = RequirementCache("torch_xla>=2.1")
+_XLA_GREATER_EQUAL_2_5 = RequirementCache("torch_xla>=2.5")
+
+
+def _using_pjrt() -> bool:
+    # `using_pjrt` is removed in torch_xla 2.5
+    if _XLA_GREATER_EQUAL_2_5:
+        from torch_xla import runtime as xr
+
+        return xr.device_type() is not None
+    # delete me when torch_xla 2.2 is the min supported version, where XRT support has been dropped.
+    if _XLA_GREATER_EQUAL_2_1:
+        from torch_xla import runtime as xr
+
+        return xr.using_pjrt()
+
+    from torch_xla.experimental import pjrt
+
+    return pjrt.using_pjrt()
+
+
+def _parse_tpu_devices(devices: Union[int, str, list[int]]) -> Union[int, list[int]]:
+    """Parses the TPU devices given in the format as accepted by the
+    :class:`~lightning.pytorch.trainer.trainer.Trainer` and :class:`~lightning.fabric.Fabric`.
+
+    Args:
+        devices: An int of 1 or string '1' indicates that 1 core with multi-processing should be used
+            An int 8 or string '8' indicates that all 8 cores with multi-processing should be used
+            A single element list of int or string can be used to indicate the specific TPU core to use.
+
+    Returns:
+        A list of tpu cores to be used.
+
+    """
+    _check_data_type(devices)
+    if isinstance(devices, str):
+        devices = _parse_tpu_devices_str(devices)
+    _check_tpu_devices_valid(devices)
+    return devices
+
+
+def _check_tpu_devices_valid(devices: object) -> None:
+    device_count = XLAAccelerator.auto_device_count()
+    if (
+        # support number of devices
+        isinstance(devices, int)
+        and devices in {1, device_count}
+        # support picking a specific device
+        or isinstance(devices, (list, tuple))
+        and len(devices) == 1
+        and 0 <= devices[0] <= device_count - 1
+    ):
+        return
+    raise ValueError(
+        f"`devices` can only be 'auto', 1, {device_count} or [<0-{device_count - 1}>] for TPUs. Got {devices!r}"
+    )
+
+
+def _parse_tpu_devices_str(devices: str) -> Union[int, list[int]]:
+    devices = devices.strip()
+    try:
+        return int(devices)
+    except ValueError:
+        try:
+            return [int(x.strip()) for x in devices.split(",") if len(x) > 0]
+        except ValueError:
+            raise ValueError(f"Could not parse the selected TPU devices: {devices!r}")
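To make the new parsing behaviour concrete, here is a small self-contained sketch that mirrors `_parse_tpu_devices_str` and `_check_tpu_devices_valid` from the hunk above. It is not the library code: the helper name is made up, and a hard-coded `device_count=8` stands in for `XLAAccelerator.auto_device_count()` so the snippet runs without `torch_xla` installed.

# Illustrative re-implementation of the string-parsing and validation logic above.
from typing import Union


def parse_tpu_devices_str(devices: str, device_count: int = 8) -> Union[int, list[int]]:
    devices = devices.strip()
    try:
        parsed: Union[int, list[int]] = int(devices)  # "1" or "8" -> use that many cores
    except ValueError:
        parsed = [int(x.strip()) for x in devices.split(",") if len(x) > 0]  # e.g. "3," -> [3]
    # validation mirrors `_check_tpu_devices_valid`: all cores, one core, or one specific index
    if isinstance(parsed, int) and parsed in {1, device_count}:
        return parsed
    if isinstance(parsed, list) and len(parsed) == 1 and 0 <= parsed[0] <= device_count - 1:
        return parsed
    raise ValueError(f"`devices` can only be 1, {device_count} or a single index. Got {devices!r}")


print(parse_tpu_devices_str("8"))   # -> 8   (all cores, multi-processing)
print(parse_tpu_devices_str("1"))   # -> 1   (a single core)
print(parse_tpu_devices_str("3,"))  # -> [3] (pin to a specific core index)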

src/lightning/fabric/plugins/environments/kubeflow.py

Lines changed: 10 additions & 18 deletions
@@ -13,11 +13,11 @@
 # limitations under the License.

 import logging
+import os

 from typing_extensions import override

 from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
-from lightning.fabric.utilities.imports import _raise_enterprise_not_available

 log = logging.getLogger(__name__)

@@ -33,28 +33,20 @@ class KubeflowEnvironment(ClusterEnvironment):
     """

-    def __init__(self) -> None:
-        _raise_enterprise_not_available()
-        from pytorch_lightning_enterprise.plugins.environments.kubeflow import (
-            KubeflowEnvironment as EnterpriseKubeflowEnvironment,
-        )
-
-        self.kubeflow_impl = EnterpriseKubeflowEnvironment()
-
     @property
     @override
     def creates_processes_externally(self) -> bool:
-        return self.kubeflow_impl.creates_processes_externally
+        return True

     @property
     @override
     def main_address(self) -> str:
-        return self.kubeflow_impl.main_address
+        return os.environ["MASTER_ADDR"]

     @property
     @override
     def main_port(self) -> int:
-        return self.kubeflow_impl.main_port
+        return int(os.environ["MASTER_PORT"])

     @staticmethod
     @override
@@ -63,24 +55,24 @@ def detect() -> bool:

     @override
     def world_size(self) -> int:
-        return self.kubeflow_impl.world_size()
+        return int(os.environ["WORLD_SIZE"])

     @override
     def set_world_size(self, size: int) -> None:
-        return self.kubeflow_impl.set_world_size(size)
+        log.debug("KubeflowEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")

     @override
     def global_rank(self) -> int:
-        return self.kubeflow_impl.global_rank()
+        return int(os.environ["RANK"])

     @override
     def set_global_rank(self, rank: int) -> None:
-        return self.kubeflow_impl.set_global_rank(rank)
+        log.debug("KubeflowEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.")

     @override
     def local_rank(self) -> int:
-        return self.kubeflow_impl.local_rank()
+        return 0

     @override
     def node_rank(self) -> int:
-        return self.kubeflow_impl.node_rank()
+        return self.global_rank()
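As a quick sanity check of the rewritten environment, the sketch below (not from the commit, and assuming `lightning` is installed with the import path shown) populates the PyTorchJob-style rendezvous variables by hand and reads them back through the class:

# Hypothetical local check: fill in the env vars the class now reads directly.
import os

os.environ.update({
    "MASTER_ADDR": "10.0.0.1",  # example address
    "MASTER_PORT": "23456",     # example port
    "WORLD_SIZE": "4",
    "RANK": "2",
})

from lightning.fabric.plugins.environments import KubeflowEnvironment

env = KubeflowEnvironment()
print(env.main_address, env.main_port)      # 10.0.0.1 23456
print(env.world_size(), env.global_rank())  # 4 2
print(env.local_rank(), env.node_rank())    # 0 2 (local rank is fixed at 0; node rank mirrors global rank)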
