
[WIP] #21354

Commits
42 commits
c0a70aa
forward xla impl
justusschock Nov 10, 2025
708741c
forward logger implementation
justusschock Nov 10, 2025
f6e1c90
forward logger implementation: mlflow
justusschock Nov 10, 2025
a4ae0be
update neptune logger
justusschock Nov 10, 2025
8b61cb0
forward kubeflow implementation
justusschock Nov 10, 2025
aeae617
forward lsf env
justusschock Nov 10, 2025
65a85c2
move torchelastic
justusschock Nov 10, 2025
23deaf8
update xla env
justusschock Nov 10, 2025
73e2203
forward bitsandbytes
justusschock Nov 10, 2025
7dbab8e
forward deepspeed precision
justusschock Nov 10, 2025
61f01c6
forward transformer engine
justusschock Nov 10, 2025
2db8b8e
forward XLA precision
justusschock Nov 10, 2025
c63d855
forward deepspeed strategy fabric
justusschock Nov 10, 2025
467b935
integrate xla strategies
justusschock Nov 10, 2025
ed8f618
update pytorch deepspeed precision
justusschock Nov 10, 2025
e68c226
forward trainer xla single device
justusschock Nov 10, 2025
ebe167e
XLA ddp trainer
justusschock Nov 10, 2025
3aa1981
update
justusschock Nov 11, 2025
975c098
update
justusschock Nov 11, 2025
d9d4b69
update
justusschock Nov 11, 2025
60a5a75
update
justusschock Nov 11, 2025
f91447a
update
justusschock Nov 11, 2025
b834816
update
justusschock Nov 11, 2025
ba32712
update
justusschock Nov 11, 2025
0fc321e
update
justusschock Nov 11, 2025
4b6075c
update
justusschock Nov 11, 2025
555f531
update
justusschock Nov 11, 2025
bc33500
update
justusschock Nov 11, 2025
304f82b
update
justusschock Nov 11, 2025
8d5ac23
update
justusschock Nov 11, 2025
20f9049
update
justusschock Nov 11, 2025
644f354
update fabric tests
justusschock Nov 11, 2025
dab4fb5
fabric tests
justusschock Nov 12, 2025
a10808d
tests
justusschock Nov 12, 2025
01ec0fa
update version
justusschock Nov 12, 2025
a178f4c
update
justusschock Nov 12, 2025
c7d494a
update
justusschock Nov 12, 2025
aae4ca6
update
justusschock Nov 12, 2025
3078156
update
justusschock Nov 12, 2025
dab8afa
update
justusschock Nov 12, 2025
a3cf88c
update
justusschock Nov 12, 2025
2d208dc
fix doc issue
deependujha Nov 14, 2025
5 changes: 1 addition & 4 deletions docs/source-pytorch/conf.py
@@ -604,10 +604,7 @@ def package_list_from_file(file):
from lightning.pytorch.cli import _JSONARGPARSE_SIGNATURES_AVAILABLE as _JSONARGPARSE_AVAILABLE
from lightning.pytorch.utilities.imports import _TORCHVISION_AVAILABLE
from lightning.fabric.loggers.tensorboard import _TENSORBOARD_AVAILABLE, _TENSORBOARDX_AVAILABLE
from lightning.pytorch.loggers.neptune import _NEPTUNE_AVAILABLE
from lightning.pytorch.loggers.comet import _COMET_AVAILABLE
from lightning.pytorch.loggers.mlflow import _MLFLOW_AVAILABLE
from lightning.pytorch.loggers.wandb import _WANDB_AVAILABLE
from lightning.fabric.utilities.imports import _COMET_AVAILABLE, _MLFLOW_AVAILABLE, _NEPTUNE_AVAILABLE, _WANDB_AVAILABLE
"""
coverage_skip_undoc_in_source = True

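The conf.py hunk above swaps the four per-logger availability imports for a single consolidated import from lightning.fabric.utilities.imports. As a rough illustration (not taken from this PR), such flags are typically thin RequirementCache checks; the requirement strings below are assumptions:

```python
# Hypothetical sketch of the consolidated flags; the real definitions in
# lightning.fabric.utilities.imports may use different requirement strings.
from lightning_utilities.core.imports import RequirementCache

_COMET_AVAILABLE = RequirementCache("comet-ml")
_MLFLOW_AVAILABLE = RequirementCache("mlflow")
_NEPTUNE_AVAILABLE = RequirementCache("neptune")
_WANDB_AVAILABLE = RequirementCache("wandb")
```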
1 change: 1 addition & 0 deletions requirements/fabric/base.txt
@@ -6,3 +6,4 @@ fsspec[http] >=2022.5.0, <2025.11.0
packaging >=20.0, <=25.0
typing-extensions >4.5.0, <4.16.0
lightning-utilities >=0.10.0, <0.16.0
pytorch-lightning-enterprise >=0.0.1dev4
1 change: 1 addition & 0 deletions requirements/pytorch/base.txt
@@ -9,3 +9,4 @@ torchmetrics >0.7.0, <1.9.0
packaging >=20.0, <=25.0
typing-extensions >4.5.0, <4.16.0
lightning-utilities >=0.10.0, <0.16.0
pytorch-lightning-enterprise >=0.0.1dev4
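Both requirements files gain pytorch-lightning-enterprise >=0.0.1dev4 as a base dependency. A quick, optional sanity check (a sketch, not part of the PR) to confirm the new requirement resolves in a given environment:

```python
# Environment sanity check; purely illustrative.
from importlib.metadata import PackageNotFoundError, version

try:
    print("pytorch-lightning-enterprise", version("pytorch-lightning-enterprise"))
except PackageNotFoundError:
    print("pytorch-lightning-enterprise is not installed")
```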
123 changes: 27 additions & 96 deletions src/lightning/fabric/accelerators/xla.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import warnings
from typing import Any, Union

import torch
@@ -20,7 +21,11 @@

from lightning.fabric.accelerators.accelerator import Accelerator
from lightning.fabric.accelerators.registry import _AcceleratorRegistry
from lightning.fabric.utilities.device_parser import _check_data_type
from lightning.fabric.utilities.imports import _raise_enterprise_not_available

_XLA_AVAILABLE = RequirementCache("torch_xla>=1.13", "torch_xla")
_XLA_GREATER_EQUAL_2_1 = RequirementCache("torch_xla>=2.1")
_XLA_GREATER_EQUAL_2_5 = RequirementCache("torch_xla>=2.5")


class XLAAccelerator(Accelerator):
Expand All @@ -31,38 +36,38 @@ class XLAAccelerator(Accelerator):
"""

def __init__(self, *args: Any, **kwargs: Any) -> None:
if not _XLA_AVAILABLE:
raise ModuleNotFoundError(str(_XLA_AVAILABLE))
if not _using_pjrt():
raise RuntimeError("The XLA XRT runtime is not supported anymore.")
_raise_enterprise_not_available()
super().__init__(*args, **kwargs)

from pytorch_lightning_enterprise.accelerators.xla import XLAAccelerator as EnterpriseXLAAccelerator

self.accelerator_impl = EnterpriseXLAAccelerator(*args, **kwargs)

@override
def setup_device(self, device: torch.device) -> None:
pass
return self.accelerator_impl.setup_device(device)

@override
def teardown(self) -> None:
pass
return self.accelerator_impl.teardown()

@staticmethod
@override
def parse_devices(devices: Union[int, str, list[int]]) -> Union[int, list[int]]:
"""Accelerator device parsing logic."""
return _parse_tpu_devices(devices)
_raise_enterprise_not_available()
from pytorch_lightning_enterprise.accelerators.xla import XLAAccelerator as EnterpriseXLAAccelerator

return EnterpriseXLAAccelerator.parse_devices(devices)

@staticmethod
@override
def get_parallel_devices(devices: Union[int, list[int]]) -> list[torch.device]:
"""Gets parallel devices for the Accelerator."""
devices = _parse_tpu_devices(devices)
if isinstance(devices, int):
return [torch.device("xla", i) for i in range(devices)]
# list of devices is not supported, just a specific index, fine to access [0]
return [torch.device("xla", devices[0])]
# we cannot create `xla_device` here because processes have not been spawned yet (this is called in the
# accelerator connector init). However, there doesn't seem to be a problem with instantiating `torch.device`.
# it will be replaced with `xla_device` (also a torch.device`, but with extra logic) in the strategy
_raise_enterprise_not_available()
from pytorch_lightning_enterprise.accelerators.xla import XLAAccelerator as EnterpriseXLAAccelerator

return EnterpriseXLAAccelerator.get_parallel_devices(devices)

@staticmethod
@override
Expand All @@ -71,16 +76,10 @@ def get_parallel_devices(devices: Union[int, list[int]]) -> list[torch.device]:
@functools.lru_cache(maxsize=1)
def auto_device_count() -> int:
"""Get the devices when set to auto."""
if not _XLA_AVAILABLE:
return 0
if _XLA_GREATER_EQUAL_2_1:
from torch_xla._internal import tpu

return tpu.num_available_devices()
from torch_xla.experimental import tpu
_raise_enterprise_not_available()
from pytorch_lightning_enterprise.accelerators.xla import XLAAccelerator as EnterpriseXLAAccelerator

device_count_on_version = {2: 8, 3: 8, 4: 4}
return device_count_on_version.get(tpu.version(), 8)
return EnterpriseXLAAccelerator.auto_device_count()

@staticmethod
@override
@@ -92,6 +91,9 @@ def is_available() -> bool:
# XLA may raise these exceptions if it's not properly configured. This needs to be avoided for the cases
# when `torch_xla` is imported but not used
return False
except ModuleNotFoundError as e:
warnings.warn(str(e))
return False

@staticmethod
@override
@@ -106,74 +108,3 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
cls,
description=cls.__name__,
)


# PJRT support requires this minimum version
_XLA_AVAILABLE = RequirementCache("torch_xla>=1.13", "torch_xla")
_XLA_GREATER_EQUAL_2_1 = RequirementCache("torch_xla>=2.1")
_XLA_GREATER_EQUAL_2_5 = RequirementCache("torch_xla>=2.5")


def _using_pjrt() -> bool:
# `using_pjrt` is removed in torch_xla 2.5
if _XLA_GREATER_EQUAL_2_5:
from torch_xla import runtime as xr

return xr.device_type() is not None
# delete me when torch_xla 2.2 is the min supported version, where XRT support has been dropped.
if _XLA_GREATER_EQUAL_2_1:
from torch_xla import runtime as xr

return xr.using_pjrt()

from torch_xla.experimental import pjrt

return pjrt.using_pjrt()


def _parse_tpu_devices(devices: Union[int, str, list[int]]) -> Union[int, list[int]]:
"""Parses the TPU devices given in the format as accepted by the
:class:`~lightning.pytorch.trainer.trainer.Trainer` and :class:`~lightning.fabric.Fabric`.

Args:
devices: An int of 1 or string '1' indicates that 1 core with multi-processing should be used
An int 8 or string '8' indicates that all 8 cores with multi-processing should be used
A single element list of int or string can be used to indicate the specific TPU core to use.

Returns:
A list of tpu cores to be used.

"""
_check_data_type(devices)
if isinstance(devices, str):
devices = _parse_tpu_devices_str(devices)
_check_tpu_devices_valid(devices)
return devices


def _check_tpu_devices_valid(devices: object) -> None:
device_count = XLAAccelerator.auto_device_count()
if (
# support number of devices
isinstance(devices, int)
and devices in {1, device_count}
# support picking a specific device
or isinstance(devices, (list, tuple))
and len(devices) == 1
and 0 <= devices[0] <= device_count - 1
):
return
raise ValueError(
f"`devices` can only be 'auto', 1, {device_count} or [<0-{device_count - 1}>] for TPUs. Got {devices!r}"
)


def _parse_tpu_devices_str(devices: str) -> Union[int, list[int]]:
devices = devices.strip()
try:
return int(devices)
except ValueError:
try:
return [int(x.strip()) for x in devices.split(",") if len(x) > 0]
except ValueError:
raise ValueError(f"Could not parse the selected TPU devices: {devices!r}")
28 changes: 18 additions & 10 deletions src/lightning/fabric/plugins/environments/kubeflow.py
@@ -13,11 +13,11 @@
# limitations under the License.

import logging
import os

from typing_extensions import override

from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
from lightning.fabric.utilities.imports import _raise_enterprise_not_available

log = logging.getLogger(__name__)

Expand All @@ -33,20 +33,28 @@ class KubeflowEnvironment(ClusterEnvironment):

"""

def __init__(self) -> None:
_raise_enterprise_not_available()
from pytorch_lightning_enterprise.plugins.environments.kubeflow import (
KubeflowEnvironment as EnterpriseKubeflowEnvironment,
)

self.kubeflow_impl = EnterpriseKubeflowEnvironment()

@property
@override
def creates_processes_externally(self) -> bool:
return True
return self.kubeflow_impl.creates_processes_externally

@property
@override
def main_address(self) -> str:
return os.environ["MASTER_ADDR"]
return self.kubeflow_impl.main_address

@property
@override
def main_port(self) -> int:
return int(os.environ["MASTER_PORT"])
return self.kubeflow_impl.main_port

@staticmethod
@override
@@ -55,24 +63,24 @@ def detect() -> bool:

@override
def world_size(self) -> int:
return int(os.environ["WORLD_SIZE"])
return self.kubeflow_impl.world_size()

@override
def set_world_size(self, size: int) -> None:
log.debug("KubeflowEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")
return self.kubeflow_impl.set_world_size(size)

@override
def global_rank(self) -> int:
return int(os.environ["RANK"])
return self.kubeflow_impl.global_rank()

@override
def set_global_rank(self, rank: int) -> None:
log.debug("KubeflowEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.")
return self.kubeflow_impl.set_global_rank(rank)

@override
def local_rank(self) -> int:
return 0
return self.kubeflow_impl.local_rank()

@override
def node_rank(self) -> int:
return self.global_rank()
return self.kubeflow_impl.node_rank()
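KubeflowEnvironment keeps its public interface but now resolves every property and method through the enterprise implementation. A minimal usage sketch, assuming pytorch-lightning-enterprise is installed and that the enterprise class still reads the same environment variables the removed code used (MASTER_ADDR, MASTER_PORT, WORLD_SIZE, RANK):

```python
# Usage sketch; KubeflowEnvironment.__init__ raises if the enterprise package is missing.
import os

from lightning.fabric.plugins.environments.kubeflow import KubeflowEnvironment

# Values a Kubeflow PyTorchJob would normally inject (set here only for illustration).
os.environ.setdefault("MASTER_ADDR", "10.0.0.1")
os.environ.setdefault("MASTER_PORT", "23456")
os.environ.setdefault("WORLD_SIZE", "4")
os.environ.setdefault("RANK", "0")

env = KubeflowEnvironment()
print(env.main_address, env.main_port, env.world_size(), env.global_rank())
```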