
Commit f259e1a

Merge branch 'master' into update-pyproject-py310
2 parents 5157f73 + 9a10959


65 files changed: +1127 / -4219 lines


.github/workflows/docs-build.yml

Lines changed: 4 additions & 1 deletion

@@ -125,7 +125,10 @@ jobs:
         working-directory: ./docs/source-${{ matrix.pkg-name }}
         # allow failing link check and doctest if you run with dispatch
         continue-on-error: ${{ (matrix.target == 'doctest' || matrix.target == 'linkcheck') && github.event_name == 'workflow_dispatch' }}
-        run: make ${{ matrix.target }} --debug --jobs $(nproc) SPHINXOPTS="$BUILD_SPHINX_OPTS"
+        run: |
+          # temp fix: https://github.com/Lightning-AI/pytorch-lightning/actions/runs/19440502586/job/55622388642?pr=21354#step:11:4596
+          uv pip install -U fastapi
+          make ${{ matrix.target }} --debug --jobs $(nproc) SPHINXOPTS="$BUILD_SPHINX_OPTS"

       - name: Keep artifact
         if: github.event_name == 'pull_request'

docs/source-pytorch/conf.py

Lines changed: 1 addition & 4 deletions

@@ -604,10 +604,7 @@ def package_list_from_file(file):
 from lightning.pytorch.cli import _JSONARGPARSE_SIGNATURES_AVAILABLE as _JSONARGPARSE_AVAILABLE
 from lightning.pytorch.utilities.imports import _TORCHVISION_AVAILABLE
 from lightning.fabric.loggers.tensorboard import _TENSORBOARD_AVAILABLE, _TENSORBOARDX_AVAILABLE
-from lightning.pytorch.loggers.neptune import _NEPTUNE_AVAILABLE
-from lightning.pytorch.loggers.comet import _COMET_AVAILABLE
-from lightning.pytorch.loggers.mlflow import _MLFLOW_AVAILABLE
-from lightning.pytorch.loggers.wandb import _WANDB_AVAILABLE
+from lightning.fabric.utilities.imports import _COMET_AVAILABLE, _MLFLOW_AVAILABLE, _NEPTUNE_AVAILABLE, _WANDB_AVAILABLE
 """
 coverage_skip_undoc_in_source = True

requirements/fabric/base.txt

Lines changed: 1 addition & 0 deletions

@@ -6,3 +6,4 @@ fsspec[http] >=2022.5.0, <2025.11.0
 packaging >=20.0, <=25.0
 typing-extensions >4.5.0, <4.16.0
 lightning-utilities >=0.10.0, <0.16.0
+pytorch-lightning-enterprise >=2.6.0

requirements/pytorch/base.txt

Lines changed: 1 addition & 0 deletions

@@ -9,3 +9,4 @@ torchmetrics >0.7.0, <1.9.0
 packaging >=20.0, <=25.0
 typing-extensions >4.5.0, <4.16.0
 lightning-utilities >=0.10.0, <0.16.0
+pytorch-lightning-enterprise >=2.6.0
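
Both base requirement files now pull in pytorch-lightning-enterprise. A minimal sketch of how code could probe for the new dependency at runtime, reusing the RequirementCache helper that appears in the xla.py diff below; the import path lightning_utilities.core.imports and the exact check are assumptions, not part of this commit:

# Illustrative sketch only: is the new enterprise dependency importable?
from lightning_utilities.core.imports import RequirementCache  # assumed import path

_ENTERPRISE_AVAILABLE = RequirementCache("pytorch-lightning-enterprise>=2.6.0")

if not _ENTERPRISE_AVAILABLE:
    # str() of a falsy RequirementCache describes the missing requirement,
    # mirroring the removed `raise ModuleNotFoundError(str(_XLA_AVAILABLE))` pattern below.
    raise ModuleNotFoundError(str(_ENTERPRISE_AVAILABLE))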

src/lightning/fabric/accelerators/xla.py

Lines changed: 26 additions & 96 deletions

@@ -20,7 +20,11 @@

 from lightning.fabric.accelerators.accelerator import Accelerator
 from lightning.fabric.accelerators.registry import _AcceleratorRegistry
-from lightning.fabric.utilities.device_parser import _check_data_type
+from lightning.fabric.utilities.imports import _raise_enterprise_not_available
+
+_XLA_AVAILABLE = RequirementCache("torch_xla>=1.13", "torch_xla")
+_XLA_GREATER_EQUAL_2_1 = RequirementCache("torch_xla>=2.1")
+_XLA_GREATER_EQUAL_2_5 = RequirementCache("torch_xla>=2.5")


 class XLAAccelerator(Accelerator):

@@ -31,38 +35,38 @@ class XLAAccelerator(Accelerator):
     """

     def __init__(self, *args: Any, **kwargs: Any) -> None:
-        if not _XLA_AVAILABLE:
-            raise ModuleNotFoundError(str(_XLA_AVAILABLE))
-        if not _using_pjrt():
-            raise RuntimeError("The XLA XRT runtime is not supported anymore.")
+        _raise_enterprise_not_available()
         super().__init__(*args, **kwargs)

+        from pytorch_lightning_enterprise.accelerators.xla import XLAAccelerator as EnterpriseXLAAccelerator
+
+        self.accelerator_impl = EnterpriseXLAAccelerator(*args, **kwargs)
+
     @override
     def setup_device(self, device: torch.device) -> None:
-        pass
+        return self.accelerator_impl.setup_device(device)

     @override
     def teardown(self) -> None:
-        pass
+        return self.accelerator_impl.teardown()

     @staticmethod
     @override
     def parse_devices(devices: int | str | list[int]) -> int | list[int]:
         """Accelerator device parsing logic."""
-        return _parse_tpu_devices(devices)
+        _raise_enterprise_not_available()
+        from pytorch_lightning_enterprise.accelerators.xla import XLAAccelerator as EnterpriseXLAAccelerator
+
+        return EnterpriseXLAAccelerator.parse_devices(devices)

     @staticmethod
     @override
     def get_parallel_devices(devices: int | list[int]) -> list[torch.device]:
         """Gets parallel devices for the Accelerator."""
-        devices = _parse_tpu_devices(devices)
-        if isinstance(devices, int):
-            return [torch.device("xla", i) for i in range(devices)]
-        # list of devices is not supported, just a specific index, fine to access [0]
-        return [torch.device("xla", devices[0])]
-        # we cannot create `xla_device` here because processes have not been spawned yet (this is called in the
-        # accelerator connector init). However, there doesn't seem to be a problem with instantiating `torch.device`.
-        # it will be replaced with `xla_device` (also a torch.device`, but with extra logic) in the strategy
+        _raise_enterprise_not_available()
+        from pytorch_lightning_enterprise.accelerators.xla import XLAAccelerator as EnterpriseXLAAccelerator
+
+        return EnterpriseXLAAccelerator.get_parallel_devices(devices)

     @staticmethod
     @override

@@ -71,16 +75,10 @@ def get_parallel_devices(devices: int | list[int]) -> list[torch.device]:
     @functools.lru_cache(maxsize=1)
     def auto_device_count() -> int:
         """Get the devices when set to auto."""
-        if not _XLA_AVAILABLE:
-            return 0
-        if _XLA_GREATER_EQUAL_2_1:
-            from torch_xla._internal import tpu
-
-            return tpu.num_available_devices()
-        from torch_xla.experimental import tpu
+        _raise_enterprise_not_available()
+        from pytorch_lightning_enterprise.accelerators.xla import XLAAccelerator as EnterpriseXLAAccelerator

-        device_count_on_version = {2: 8, 3: 8, 4: 4}
-        return device_count_on_version.get(tpu.version(), 8)
+        return EnterpriseXLAAccelerator.auto_device_count()

     @staticmethod
     @override

@@ -92,6 +90,9 @@ def is_available() -> bool:
             # XLA may raise these exceptions if it's not properly configured. This needs to be avoided for the cases
             # when `torch_xla` is imported but not used
             return False
+        except ModuleNotFoundError as e:
+            warnings.warn(str(e))
+            return False

     @staticmethod
     @override

@@ -106,74 +107,3 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No
             cls,
             description=cls.__name__,
         )
-
-
-# PJRT support requires this minimum version
-_XLA_AVAILABLE = RequirementCache("torch_xla>=1.13", "torch_xla")
-_XLA_GREATER_EQUAL_2_1 = RequirementCache("torch_xla>=2.1")
-_XLA_GREATER_EQUAL_2_5 = RequirementCache("torch_xla>=2.5")
-
-
-def _using_pjrt() -> bool:
-    # `using_pjrt` is removed in torch_xla 2.5
-    if _XLA_GREATER_EQUAL_2_5:
-        from torch_xla import runtime as xr
-
-        return xr.device_type() is not None
-    # delete me when torch_xla 2.2 is the min supported version, where XRT support has been dropped.
-    if _XLA_GREATER_EQUAL_2_1:
-        from torch_xla import runtime as xr
-
-        return xr.using_pjrt()
-
-    from torch_xla.experimental import pjrt
-
-    return pjrt.using_pjrt()
-
-
-def _parse_tpu_devices(devices: int | str | list[int]) -> int | list[int]:
-    """Parses the TPU devices given in the format as accepted by the
-    :class:`~lightning.pytorch.trainer.trainer.Trainer` and :class:`~lightning.fabric.Fabric`.
-
-    Args:
-        devices: An int of 1 or string '1' indicates that 1 core with multi-processing should be used
-            An int 8 or string '8' indicates that all 8 cores with multi-processing should be used
-            A single element list of int or string can be used to indicate the specific TPU core to use.
-
-    Returns:
-        A list of tpu cores to be used.
-
-    """
-    _check_data_type(devices)
-    if isinstance(devices, str):
-        devices = _parse_tpu_devices_str(devices)
-    _check_tpu_devices_valid(devices)
-    return devices
-
-
-def _check_tpu_devices_valid(devices: object) -> None:
-    device_count = XLAAccelerator.auto_device_count()
-    if (
-        # support number of devices
-        isinstance(devices, int)
-        and devices in {1, device_count}
-        # support picking a specific device
-        or isinstance(devices, (list, tuple))
-        and len(devices) == 1
-        and 0 <= devices[0] <= device_count - 1
-    ):
-        return
-    raise ValueError(
-        f"`devices` can only be 'auto', 1, {device_count} or [<0-{device_count - 1}>] for TPUs. Got {devices!r}"
-    )
-
-
-def _parse_tpu_devices_str(devices: str) -> int | list[int]:
-    devices = devices.strip()
-    try:
-        return int(devices)
-    except ValueError:
-        try:
-            return [int(x.strip()) for x in devices.split(",") if len(x) > 0]
-        except ValueError:
-            raise ValueError(f"Could not parse the selected TPU devices: {devices!r}")

src/lightning/fabric/plugins/environments/kubeflow.py

Lines changed: 18 additions & 10 deletions

@@ -13,11 +13,11 @@
 # limitations under the License.

 import logging
-import os

 from typing_extensions import override

 from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
+from lightning.fabric.utilities.imports import _raise_enterprise_not_available

 log = logging.getLogger(__name__)


@@ -33,20 +33,28 @@ class KubeflowEnvironment(ClusterEnvironment):

     """

+    def __init__(self) -> None:
+        _raise_enterprise_not_available()
+        from pytorch_lightning_enterprise.plugins.environments.kubeflow import (
+            KubeflowEnvironment as EnterpriseKubeflowEnvironment,
+        )
+
+        self.kubeflow_impl = EnterpriseKubeflowEnvironment()
+
     @property
     @override
     def creates_processes_externally(self) -> bool:
-        return True
+        return self.kubeflow_impl.creates_processes_externally

     @property
     @override
     def main_address(self) -> str:
-        return os.environ["MASTER_ADDR"]
+        return self.kubeflow_impl.main_address

     @property
     @override
     def main_port(self) -> int:
-        return int(os.environ["MASTER_PORT"])
+        return self.kubeflow_impl.main_port

     @staticmethod
     @override

@@ -55,24 +63,24 @@ def detect() -> bool:

     @override
     def world_size(self) -> int:
-        return int(os.environ["WORLD_SIZE"])
+        return self.kubeflow_impl.world_size()

     @override
     def set_world_size(self, size: int) -> None:
-        log.debug("KubeflowEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")
+        return self.kubeflow_impl.set_world_size(size)

     @override
     def global_rank(self) -> int:
-        return int(os.environ["RANK"])
+        return self.kubeflow_impl.global_rank()

     @override
     def set_global_rank(self, rank: int) -> None:
-        log.debug("KubeflowEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.")
+        return self.kubeflow_impl.set_global_rank(rank)

     @override
     def local_rank(self) -> int:
-        return 0
+        return self.kubeflow_impl.local_rank()

     @override
     def node_rank(self) -> int:
-        return self.global_rank()
+        return self.kubeflow_impl.node_rank()
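
The cluster environment follows the same delegation pattern: KubeflowEnvironment no longer reads MASTER_ADDR, MASTER_PORT, WORLD_SIZE and RANK itself but forwards every property and method to the enterprise implementation constructed in __init__. A minimal usage sketch, assuming pytorch-lightning-enterprise is installed and that its KubeflowEnvironment still reads the same environment variables the removed code did:

# Sketch only, not part of the commit.
import os

# Mirror the variables the removed code read directly; whether the enterprise
# implementation uses the same ones is an assumption.
os.environ.setdefault("MASTER_ADDR", "10.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
os.environ.setdefault("WORLD_SIZE", "4")
os.environ.setdefault("RANK", "1")

from lightning.fabric.plugins.environments.kubeflow import KubeflowEnvironment

env = KubeflowEnvironment()  # raises if pytorch-lightning-enterprise is missing
print(env.main_address, env.main_port)      # e.g. 10.0.0.1 29500
print(env.world_size(), env.global_rank())  # e.g. 4 1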
