Merged
5 changes: 1 addition & 4 deletions .github/workflows/docs-build.yml
@@ -125,10 +125,7 @@ jobs:
working-directory: ./docs/source-${{ matrix.pkg-name }}
# allow failing link check and doctest if you run with dispatch
continue-on-error: ${{ (matrix.target == 'doctest' || matrix.target == 'linkcheck') && github.event_name == 'workflow_dispatch' }}
run: |
# temp fix: https://github.com/Lightning-AI/pytorch-lightning/actions/runs/19440502586/job/55622388642?pr=21354#step:11:4596
uv pip install -U fastapi
make ${{ matrix.target }} --debug --jobs $(nproc) SPHINXOPTS="$BUILD_SPHINX_OPTS"
run: make ${{ matrix.target }} --debug --jobs $(nproc) SPHINXOPTS="$BUILD_SPHINX_OPTS"

- name: Keep artifact
if: github.event_name == 'pull_request'
5 changes: 4 additions & 1 deletion docs/source-pytorch/conf.py
@@ -604,7 +604,10 @@ def package_list_from_file(file):
from lightning.pytorch.cli import _JSONARGPARSE_SIGNATURES_AVAILABLE as _JSONARGPARSE_AVAILABLE
from lightning.pytorch.utilities.imports import _TORCHVISION_AVAILABLE
from lightning.fabric.loggers.tensorboard import _TENSORBOARD_AVAILABLE, _TENSORBOARDX_AVAILABLE
from lightning.fabric.utilities.imports import _COMET_AVAILABLE, _MLFLOW_AVAILABLE, _NEPTUNE_AVAILABLE, _WANDB_AVAILABLE
from lightning.pytorch.loggers.neptune import _NEPTUNE_AVAILABLE
from lightning.pytorch.loggers.comet import _COMET_AVAILABLE
from lightning.pytorch.loggers.mlflow import _MLFLOW_AVAILABLE
from lightning.pytorch.loggers.wandb import _WANDB_AVAILABLE
"""
coverage_skip_undoc_in_source = True
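For context (not part of the diff): each `_*_AVAILABLE` name imported above is a requirement-cache style flag that is truthy only when the corresponding optional logger package is importable. A hedged sketch of how such flags are typically consumed in a Sphinx conf.py (the `autodoc_mock_imports` handling below is illustrative, not lifted from this file):

from lightning.pytorch.loggers.wandb import _WANDB_AVAILABLE  # truthy only if `wandb` is installed

autodoc_mock_imports = []  # illustrative; the real conf.py manages its own mock list
if not _WANDB_AVAILABLE:
    # let Sphinx import modules that reference wandb without the package being installed
    autodoc_mock_imports.append("wandb")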

1 change: 0 additions & 1 deletion requirements/fabric/base.txt
@@ -6,4 +6,3 @@ fsspec[http] >=2022.5.0, <2025.11.0
packaging >=20.0, <=25.0
typing-extensions >4.5.0, <4.16.0
lightning-utilities >=0.10.0, <0.16.0
pytorch-lightning-enterprise >=2.6.0
1 change: 0 additions & 1 deletion requirements/pytorch/base.txt
@@ -9,4 +9,3 @@ torchmetrics >0.7.0, <1.9.0
packaging >=20.0, <=25.0
typing-extensions >4.5.0, <4.16.0
lightning-utilities >=0.10.0, <0.16.0
pytorch-lightning-enterprise >=2.6.0
123 changes: 96 additions & 27 deletions src/lightning/fabric/accelerators/xla.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import warnings
from typing import Any, Union

import torch
@@ -21,11 +20,7 @@

from lightning.fabric.accelerators.accelerator import Accelerator
from lightning.fabric.accelerators.registry import _AcceleratorRegistry
from lightning.fabric.utilities.imports import _raise_enterprise_not_available

_XLA_AVAILABLE = RequirementCache("torch_xla>=1.13", "torch_xla")
_XLA_GREATER_EQUAL_2_1 = RequirementCache("torch_xla>=2.1")
_XLA_GREATER_EQUAL_2_5 = RequirementCache("torch_xla>=2.5")
from lightning.fabric.utilities.device_parser import _check_data_type


class XLAAccelerator(Accelerator):
@@ -36,38 +31,38 @@ class XLAAccelerator(Accelerator):
"""

def __init__(self, *args: Any, **kwargs: Any) -> None:
_raise_enterprise_not_available()
if not _XLA_AVAILABLE:
raise ModuleNotFoundError(str(_XLA_AVAILABLE))
if not _using_pjrt():
raise RuntimeError("The XLA XRT runtime is not supported anymore.")
super().__init__(*args, **kwargs)

from pytorch_lightning_enterprise.accelerators.xla import XLAAccelerator as EnterpriseXLAAccelerator

self.accelerator_impl = EnterpriseXLAAccelerator(*args, **kwargs)

@override
def setup_device(self, device: torch.device) -> None:
return self.accelerator_impl.setup_device(device)
pass

@override
def teardown(self) -> None:
return self.accelerator_impl.teardown()
pass

@staticmethod
@override
def parse_devices(devices: Union[int, str, list[int]]) -> Union[int, list[int]]:
"""Accelerator device parsing logic."""
_raise_enterprise_not_available()
from pytorch_lightning_enterprise.accelerators.xla import XLAAccelerator as EnterpriseXLAAccelerator

return EnterpriseXLAAccelerator.parse_devices(devices)
return _parse_tpu_devices(devices)

@staticmethod
@override
def get_parallel_devices(devices: Union[int, list[int]]) -> list[torch.device]:
"""Gets parallel devices for the Accelerator."""
_raise_enterprise_not_available()
from pytorch_lightning_enterprise.accelerators.xla import XLAAccelerator as EnterpriseXLAAccelerator

return EnterpriseXLAAccelerator.get_parallel_devices(devices)
devices = _parse_tpu_devices(devices)
if isinstance(devices, int):
return [torch.device("xla", i) for i in range(devices)]
# list of devices is not supported, just a specific index, fine to access [0]
return [torch.device("xla", devices[0])]
# we cannot create `xla_device` here because processes have not been spawned yet (this is called in the
# accelerator connector init). However, there doesn't seem to be a problem with instantiating `torch.device`.
# it will be replaced with `xla_device` (also a `torch.device`, but with extra logic) in the strategy

@staticmethod
@override
@@ -76,10 +71,16 @@ def get_parallel_devices(devices: Union[int, list[int]]) -> list[torch.device]:
@functools.lru_cache(maxsize=1)
def auto_device_count() -> int:
"""Get the devices when set to auto."""
_raise_enterprise_not_available()
from pytorch_lightning_enterprise.accelerators.xla import XLAAccelerator as EnterpriseXLAAccelerator
if not _XLA_AVAILABLE:
return 0
if _XLA_GREATER_EQUAL_2_1:
from torch_xla._internal import tpu

return tpu.num_available_devices()
from torch_xla.experimental import tpu

return EnterpriseXLAAccelerator.auto_device_count()
device_count_on_version = {2: 8, 3: 8, 4: 4}
return device_count_on_version.get(tpu.version(), 8)

@staticmethod
@override
@@ -91,9 +92,6 @@ def is_available() -> bool:
# XLA may raise these exceptions if it's not properly configured. This needs to be avoided for the cases
# when `torch_xla` is imported but not used
return False
except ModuleNotFoundError as e:
warnings.warn(str(e))
return False

@staticmethod
@override
@@ -108,3 +106,74 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
cls,
description=cls.__name__,
)


# PJRT support requires this minimum version
_XLA_AVAILABLE = RequirementCache("torch_xla>=1.13", "torch_xla")
_XLA_GREATER_EQUAL_2_1 = RequirementCache("torch_xla>=2.1")
_XLA_GREATER_EQUAL_2_5 = RequirementCache("torch_xla>=2.5")


def _using_pjrt() -> bool:
# `using_pjrt` is removed in torch_xla 2.5
if _XLA_GREATER_EQUAL_2_5:
from torch_xla import runtime as xr

return xr.device_type() is not None
# delete me when torch_xla 2.2 is the min supported version, where XRT support has been dropped.
if _XLA_GREATER_EQUAL_2_1:
from torch_xla import runtime as xr

return xr.using_pjrt()

from torch_xla.experimental import pjrt

return pjrt.using_pjrt()


def _parse_tpu_devices(devices: Union[int, str, list[int]]) -> Union[int, list[int]]:
"""Parses the TPU devices given in the format as accepted by the
:class:`~lightning.pytorch.trainer.trainer.Trainer` and :class:`~lightning.fabric.Fabric`.

Args:
devices: An int of 1 or string '1' indicates that 1 core with multi-processing should be used
An int 8 or string '8' indicates that all 8 cores with multi-processing should be used
A single element list of int or string can be used to indicate the specific TPU core to use.

Returns:
A list of tpu cores to be used.

"""
_check_data_type(devices)
if isinstance(devices, str):
devices = _parse_tpu_devices_str(devices)
_check_tpu_devices_valid(devices)
return devices


def _check_tpu_devices_valid(devices: object) -> None:
device_count = XLAAccelerator.auto_device_count()
if (
# support number of devices
isinstance(devices, int)
and devices in {1, device_count}
# support picking a specific device
or isinstance(devices, (list, tuple))
and len(devices) == 1
and 0 <= devices[0] <= device_count - 1
):
return
raise ValueError(
f"`devices` can only be 'auto', 1, {device_count} or [<0-{device_count - 1}>] for TPUs. Got {devices!r}"
)


def _parse_tpu_devices_str(devices: str) -> Union[int, list[int]]:
devices = devices.strip()
try:
return int(devices)
except ValueError:
try:
return [int(x.strip()) for x in devices.split(",") if len(x) > 0]
except ValueError:
raise ValueError(f"Could not parse the selected TPU devices: {devices!r}")
28 changes: 10 additions & 18 deletions src/lightning/fabric/plugins/environments/kubeflow.py
@@ -13,11 +13,11 @@
# limitations under the License.

import logging
import os

from typing_extensions import override

from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
from lightning.fabric.utilities.imports import _raise_enterprise_not_available

log = logging.getLogger(__name__)

@@ -33,28 +33,20 @@ class KubeflowEnvironment(ClusterEnvironment):

"""

def __init__(self) -> None:
_raise_enterprise_not_available()
from pytorch_lightning_enterprise.plugins.environments.kubeflow import (
KubeflowEnvironment as EnterpriseKubeflowEnvironment,
)

self.kubeflow_impl = EnterpriseKubeflowEnvironment()

@property
@override
def creates_processes_externally(self) -> bool:
return self.kubeflow_impl.creates_processes_externally
return True

@property
@override
def main_address(self) -> str:
return self.kubeflow_impl.main_address
return os.environ["MASTER_ADDR"]

@property
@override
def main_port(self) -> int:
return self.kubeflow_impl.main_port
return int(os.environ["MASTER_PORT"])

@staticmethod
@override
@@ -63,24 +55,24 @@ def detect() -> bool:

@override
def world_size(self) -> int:
return self.kubeflow_impl.world_size()
return int(os.environ["WORLD_SIZE"])

@override
def set_world_size(self, size: int) -> None:
return self.kubeflow_impl.set_world_size(size)
log.debug("KubeflowEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")

@override
def global_rank(self) -> int:
return self.kubeflow_impl.global_rank()
return int(os.environ["RANK"])

@override
def set_global_rank(self, rank: int) -> None:
return self.kubeflow_impl.set_global_rank(rank)
log.debug("KubeflowEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.")

@override
def local_rank(self) -> int:
return self.kubeflow_impl.local_rank()
return 0

@override
def node_rank(self) -> int:
return self.kubeflow_impl.node_rank()
return self.global_rank()
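For reference, a hedged usage sketch of the restored environment (values are illustrative; in a real cluster the Kubeflow PyTorchJob controller injects these variables):

import os

from lightning.fabric.plugins.environments import KubeflowEnvironment

# Illustrative values; normally set by the PyTorchJob controller, one process per pod.
os.environ.update({"MASTER_ADDR": "10.0.0.1", "MASTER_PORT": "29500", "WORLD_SIZE": "4", "RANK": "2"})

env = KubeflowEnvironment()
env.main_address    # "10.0.0.1"
env.main_port       # 29500
env.world_size()    # 4
env.global_rank()   # 2
env.local_rank()    # 0 (this environment assumes a single process per pod)
env.node_rank()     # 2 (delegates to global_rank())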