5 changes: 5 additions & 0 deletions src/lightning/fabric/accelerators/__init__.py
@@ -16,7 +16,12 @@
from lightning.fabric.accelerators.mps import MPSAccelerator # noqa: F401
from lightning.fabric.accelerators.registry import _AcceleratorRegistry, call_register_accelerators
from lightning.fabric.accelerators.xla import XLAAccelerator # noqa: F401
from lightning.fabric.utilities.imports import _LIGHTNING_XPU_AVAILABLE

_ACCELERATORS_BASE_MODULE = "lightning.fabric.accelerators"
ACCELERATOR_REGISTRY = _AcceleratorRegistry()
call_register_accelerators(ACCELERATOR_REGISTRY, _ACCELERATORS_BASE_MODULE)
if _LIGHTNING_XPU_AVAILABLE and "xpu" not in ACCELERATOR_REGISTRY:
from lightning_xpu.fabric import XPUAccelerator

XPUAccelerator.register_accelerators(ACCELERATOR_REGISTRY)
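
A minimal sketch (not part of this diff) of how the conditional registration above surfaces to callers; it assumes only what the snippet shows, namely that lightning-xpu registers itself under the "xpu" key when installed:

# Check whether the optional XPU backend was picked up by the registry.
from lightning.fabric.accelerators import ACCELERATOR_REGISTRY

if "xpu" in ACCELERATOR_REGISTRY:
    print("XPU accelerator registered alongside the built-in ones.")
else:
    print("lightning-xpu is not installed; only the built-in accelerators are available.")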
11 changes: 9 additions & 2 deletions src/lightning/fabric/cli.py
@@ -24,12 +24,15 @@
from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT_STR, _PRECISION_INPUT_STR_ALIAS
from lightning.fabric.strategies import STRATEGY_REGISTRY
from lightning.fabric.utilities.device_parser import _parse_gpu_ids
from lightning.fabric.utilities.imports import _LIGHTNING_XPU_AVAILABLE

_log = logging.getLogger(__name__)

_CLICK_AVAILABLE = RequirementCache("click")

_SUPPORTED_ACCELERATORS = ("cpu", "gpu", "cuda", "mps", "tpu")
_SUPPORTED_ACCELERATORS = ["cpu", "gpu", "cuda", "mps", "tpu"]
if _LIGHTNING_XPU_AVAILABLE:
_SUPPORTED_ACCELERATORS.append("xpu")


def _get_supported_strategies() -> List[str]:
@@ -146,13 +149,17 @@ def _set_env_variables(args: Namespace) -> None:
def _get_num_processes(accelerator: str, devices: str) -> int:
"""Parse the `devices` argument to determine how many processes need to be launched on the current machine."""
if accelerator == "gpu":
parsed_devices = _parse_gpu_ids(devices, include_cuda=True, include_mps=True)
parsed_devices = _parse_gpu_ids(devices, include_cuda=True, include_mps=True, include_xpu=True)
elif accelerator == "cuda":
parsed_devices = CUDAAccelerator.parse_devices(devices)
elif accelerator == "mps":
parsed_devices = MPSAccelerator.parse_devices(devices)
elif accelerator == "tpu":
raise ValueError("Launching processes for TPU through the CLI is not supported.")
elif accelerator == "xpu":
from lightning_xpu.fabric import XPUAccelerator

parsed_devices = XPUAccelerator.parse_devices(devices)
else:
return CPUAccelerator.parse_devices(devices)
return len(parsed_devices) if parsed_devices is not None else 0
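A minimal sketch (not part of this diff) of how the new "xpu" branch of `_get_num_processes` would be exercised; it assumes lightning-xpu is installed and that `XPUAccelerator.parse_devices` follows the same convention as the other backends (one entry per requested device):

# Two Intel GPUs requested on the local machine -> two processes to launch.
from lightning.fabric.cli import _get_num_processes

num_procs = _get_num_processes(accelerator="xpu", devices="2")
print(num_procs)  # expected: 2 on a machine with at least two XPU devices
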
37 changes: 33 additions & 4 deletions src/lightning/fabric/connector.py
@@ -63,7 +63,7 @@
from lightning.fabric.strategies.fsdp import _FSDP_ALIASES, FSDPStrategy
from lightning.fabric.utilities import rank_zero_info, rank_zero_warn
from lightning.fabric.utilities.device_parser import _determine_root_gpu_device
from lightning.fabric.utilities.imports import _IS_INTERACTIVE
from lightning.fabric.utilities.imports import _IS_INTERACTIVE, _LIGHTNING_XPU_AVAILABLE

_PLUGIN = Union[Precision, ClusterEnvironment, CheckpointIO]
_PLUGIN_INPUT = Union[_PLUGIN, str]
@@ -288,6 +288,13 @@ def _check_config_and_set_final_flags(
f" but accelerator set to {self._accelerator_flag}, please choose one device type"
)
self._accelerator_flag = "cuda"
if self._strategy_flag.parallel_devices[0].type == "xpu":
if self._accelerator_flag and self._accelerator_flag not in ("auto", "xpu", "gpu"):
raise ValueError(
f"GPU parallel_devices set through {self._strategy_flag.__class__.__name__} class,"
f" but accelerator set to {self._accelerator_flag}, please choose one device type"
)
self._accelerator_flag = "xpu"
self._parallel_devices = self._strategy_flag.parallel_devices

def _check_device_config_and_set_final_flags(self, devices: Union[List[int], str, int], num_nodes: int) -> None:
@@ -313,6 +320,12 @@ def _choose_auto_accelerator(self) -> str:
return "mps"
if CUDAAccelerator.is_available():
return "cuda"
if _LIGHTNING_XPU_AVAILABLE:
from lightning_xpu.fabric import XPUAccelerator

if XPUAccelerator.is_available():
return "xpu"

return "cpu"

@staticmethod
@@ -321,6 +334,11 @@ def _choose_gpu_accelerator_backend() -> str:
return "mps"
if CUDAAccelerator.is_available():
return "cuda"
if _LIGHTNING_XPU_AVAILABLE:
from lightning_xpu.fabric import XPUAccelerator

if XPUAccelerator.is_available():
return "xpu"
raise RuntimeError("No supported gpu backend found!")

def _set_parallel_devices_and_init_accelerator(self) -> None:
@@ -378,8 +396,15 @@ def _choose_strategy(self) -> Union[Strategy, str]:
return "ddp"
if len(self._parallel_devices) <= 1:
# TODO: Change this once gpu accelerator was renamed to cuda accelerator
if isinstance(self._accelerator_flag, (CUDAAccelerator, MPSAccelerator)) or (
isinstance(self._accelerator_flag, str) and self._accelerator_flag in ("cuda", "gpu", "mps")
supported_accelerators = [CUDAAccelerator, MPSAccelerator]
supported_accelerators_str = ["cuda", "gpu", "mps"]
if _LIGHTNING_XPU_AVAILABLE:
from lightning_xpu.fabric import XPUAccelerator

supported_accelerators.append(XPUAccelerator)
supported_accelerators_str.append("xpu")
if isinstance(self._accelerator_flag, tuple(supported_accelerators)) or (
isinstance(self._accelerator_flag, str) and self._accelerator_flag in tuple(supported_accelerators_str)
):
device = _determine_root_gpu_device(self._parallel_devices)
else:
@@ -462,7 +487,11 @@ def _check_and_init_precision(self) -> Precision:
if self._precision_input == "16-mixed"
else "Using bfloat16 Automatic Mixed Precision (AMP)"
)
device = "cpu" if self._accelerator_flag == "cpu" else "cuda"
device = "cuda"
if self._accelerator_flag == "cpu":
device = "cpu"
elif self._accelerator_flag == "xpu":
device = "xpu"

if isinstance(self.strategy, FSDPStrategy):
return FSDPPrecision(precision=self._precision_input, device=device) # type: ignore[arg-type]
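A standalone sketch (not part of this diff) of the fallback order visible in `_choose_auto_accelerator` and `_choose_gpu_accelerator_backend` above; checks earlier in those methods fall outside the shown hunks, so only the mps -> cuda -> xpu -> cpu tail is illustrated:

# Illustration only: mirrors the priority of the availability checks, not the connector's real API.
def choose_accelerator(mps_available: bool, cuda_available: bool, xpu_available: bool) -> str:
    if mps_available:
        return "mps"
    if cuda_available:
        return "cuda"
    if xpu_available:
        return "xpu"
    return "cpu"

assert choose_accelerator(False, False, True) == "xpu"   # Intel GPU is picked only when CUDA/MPS are absent
assert choose_accelerator(False, True, True) == "cuda"   # CUDA still wins over XPU
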
9 changes: 6 additions & 3 deletions src/lightning/fabric/strategies/ddp.py
@@ -115,9 +115,12 @@ def setup_environment(self) -> None:
def setup_module(self, module: Module) -> DistributedDataParallel:
"""Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
device_ids = self._determine_ddp_device_ids()
# https://pytorch.org/docs/stable/notes/cuda.html#id5
ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext()
with ctx:
if self.root_device.type == "cuda":
# https://pytorch.org/docs/stable/notes/cuda.html#id5
ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext()
with ctx:
return DistributedDataParallel(module=module, device_ids=device_ids, **self._ddp_kwargs)
else:
return DistributedDataParallel(module=module, device_ids=device_ids, **self._ddp_kwargs)

def module_to_device(self, module: Module) -> None:
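A minimal sketch (not part of this diff) of the device-conditional wrapping above: the side-stream trick from the PyTorch CUDA notes only applies to CUDA root devices, so other device types fall through to a plain `DistributedDataParallel` construction. It assumes the process group is already initialized and `module` already lives on `root_device`:

from contextlib import nullcontext

import torch
from torch.nn.parallel import DistributedDataParallel


def wrap_ddp(module: torch.nn.Module, root_device: torch.device, device_ids):
    if root_device.type == "cuda" and device_ids is not None:
        # https://pytorch.org/docs/stable/notes/cuda.html#id5 - construct DDP on a side stream.
        ctx = torch.cuda.stream(torch.cuda.Stream())
    else:
        ctx = nullcontext()
    with ctx:
        return DistributedDataParallel(module=module, device_ids=device_ids)
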
18 changes: 13 additions & 5 deletions src/lightning/fabric/strategies/deepspeed.py
@@ -16,7 +16,7 @@
import logging
import os
import platform
from contextlib import contextmanager
from contextlib import contextmanager, suppress
from itertools import chain
from pathlib import Path
from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, TYPE_CHECKING, Union
@@ -33,10 +33,14 @@
from lightning.fabric.strategies.registry import _StrategyRegistry
from lightning.fabric.strategies.strategy import _Sharded
from lightning.fabric.utilities.distributed import log
from lightning.fabric.utilities.imports import _LIGHTNING_XPU_AVAILABLE
from lightning.fabric.utilities.rank_zero import rank_zero_info, rank_zero_warn
from lightning.fabric.utilities.seed import reset_seed
from lightning.fabric.utilities.types import _PATH

if _LIGHTNING_XPU_AVAILABLE:
from lightning_xpu.fabric import XPUAccelerator

_DEEPSPEED_AVAILABLE = RequirementCache("deepspeed")
if TYPE_CHECKING and _DEEPSPEED_AVAILABLE:
import deepspeed
@@ -215,7 +219,8 @@ def __init__(
contiguous_memory_optimization: Copies partitioned activations so that they are contiguous in memory.
Not supported by all models.

synchronize_checkpoint_boundary: Insert :func:`torch.cuda.synchronize` at each checkpoint boundary.
synchronize_checkpoint_boundary: Insert :func:`torch.cuda.synchronize` or :func:`torch.xpu.synchronize`
at each checkpoint boundary.

load_full_weights: True when loading a single checkpoint file containing the model state dict
when using ZeRO Stage 3. This differs from the DeepSpeed checkpoint which contains shards
@@ -472,6 +477,9 @@ def load_checkpoint(
optimzer_state_requested = bool(len([item for item in state.values() if isinstance(item, Optimizer)]))

torch.cuda.empty_cache()
if _LIGHTNING_XPU_AVAILABLE:
XPUAccelerator.teardown()

_, client_state = engine.load_checkpoint(
path,
tag="checkpoint",
@@ -567,10 +575,10 @@ def _initialize_engine(
return deepspeed_engine, deepspeed_optimizer

def _setup_distributed(self) -> None:
if not isinstance(self.accelerator, CUDAAccelerator):
if not isinstance(self.accelerator, CUDAAccelerator) and not (_LIGHTNING_XPU_AVAILABLE and isinstance(self.accelerator, XPUAccelerator)):
raise RuntimeError(
f"The DeepSpeed strategy is only supported on CUDA GPUs but `{self.accelerator.__class__.__name__}`"
" is used."
"The DeepSpeed strategy is only supported on CUDA/Intel(R) GPUs but"
" `{self.accelerator.__class__.__name__}` is used."
)
reset_seed()
self._set_world_ranks()
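A minimal sketch (not part of this diff) of the optional-dependency pattern used in this file: `XPUAccelerator` only exists at module scope when lightning-xpu is importable, so any `isinstance` check against it should short-circuit on the availability flag first:

from lightning.fabric.accelerators import CUDAAccelerator
from lightning.fabric.utilities.imports import _LIGHTNING_XPU_AVAILABLE

if _LIGHTNING_XPU_AVAILABLE:
    from lightning_xpu.fabric import XPUAccelerator


def is_supported_accelerator(accelerator) -> bool:
    # CUDA is always importable; XPU is only referenced when the extension is installed.
    if isinstance(accelerator, CUDAAccelerator):
        return True
    return bool(_LIGHTNING_XPU_AVAILABLE) and isinstance(accelerator, XPUAccelerator)
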
42 changes: 31 additions & 11 deletions src/lightning/fabric/utilities/device_parser.py
@@ -16,6 +16,7 @@
import lightning.fabric.accelerators as accelerators # avoid circular dependency
from lightning.fabric.plugins.environments.torchelastic import TorchElasticEnvironment
from lightning.fabric.utilities.exceptions import MisconfigurationException
from lightning.fabric.utilities.imports import _LIGHTNING_XPU_AVAILABLE
from lightning.fabric.utilities.types import _DEVICE


@@ -49,6 +50,7 @@ def _parse_gpu_ids(
gpus: Optional[Union[int, str, List[int]]],
include_cuda: bool = False,
include_mps: bool = False,
include_xpu: bool = False,
) -> Optional[List[int]]:
"""
Parses the GPU IDs given in the format as accepted by the
@@ -62,6 +64,7 @@
Any int N > 0 indicates that GPUs [0..N) should be used.
include_cuda: A boolean value indicating whether to include CUDA devices for GPU parsing.
include_mps: A boolean value indicating whether to include MPS devices for GPU parsing.
include_xpu: A boolean value indicating whether to include Intel GPU devices for GPU parsing.

Returns:
A list of GPUs to be used or ``None`` if no GPUs were requested
@@ -71,7 +74,7 @@
If no GPUs are available but the value of gpus variable indicates request for GPUs

.. note::
``include_cuda`` and ``include_mps`` default to ``False`` so that you only
``include_cuda``, ``include_mps`` and ``include_xpu`` default to ``False`` so that you only
have to specify which device type to use and all other devices are not disabled.
"""
# Check that gpus param is None, Int, String or Sequence of Ints
@@ -84,22 +87,25 @@
# We know the user requested GPUs therefore if some of the
# requested GPUs are not available an exception is thrown.
gpus = _normalize_parse_gpu_string_input(gpus)
gpus = _normalize_parse_gpu_input_to_list(gpus, include_cuda=include_cuda, include_mps=include_mps)
gpus = _normalize_parse_gpu_input_to_list(
gpus, include_cuda=include_cuda, include_mps=include_mps, include_xpu=include_xpu
)
if not gpus:
raise MisconfigurationException("GPUs requested but none are available.")

if (
TorchElasticEnvironment.detect()
and len(gpus) != 1
and len(_get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps)) == 1
and len(_get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps, include_xpu=include_xpu))
== 1
):
# Omit sanity check on torchelastic because by default it shows one visible GPU per process
return gpus

# Check that GPUs are unique. Duplicate GPUs are not supported by the backend.
_check_unique(gpus)

return _sanitize_gpu_ids(gpus, include_cuda=include_cuda, include_mps=include_mps)
return _sanitize_gpu_ids(gpus, include_cuda=include_cuda, include_mps=include_mps, include_xpu=include_xpu)


def _normalize_parse_gpu_string_input(s: Union[int, str, List[int]]) -> Union[int, List[int]]:
@@ -112,7 +118,9 @@ def _normalize_parse_gpu_string_input(s: Union[int, str, List[int]]) -> Union[int, List[int]]:
return int(s.strip())


def _sanitize_gpu_ids(gpus: List[int], include_cuda: bool = False, include_mps: bool = False) -> List[int]:
def _sanitize_gpu_ids(
gpus: List[int], include_cuda: bool = False, include_mps: bool = False, include_xpu: bool = False
) -> List[int]:
"""Checks that each of the GPUs in the list is actually available. Raises a MisconfigurationException if any of
the GPUs is not available.

@@ -126,9 +134,11 @@ def _sanitize_gpu_ids(gpus: List[int], include_cuda: bool = False, include_mps:
MisconfigurationException:
If machine has fewer available GPUs than requested.
"""
if sum((include_cuda, include_mps)) == 0:
if sum((include_cuda, include_mps, include_xpu)) == 0:
raise ValueError("At least one gpu type should be specified!")
all_available_gpus = _get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps)
all_available_gpus = _get_all_available_gpus(
include_cuda=include_cuda, include_mps=include_mps, include_xpu=include_xpu
)
for gpu in gpus:
if gpu not in all_available_gpus:
raise MisconfigurationException(
@@ -138,7 +148,10 @@


def _normalize_parse_gpu_input_to_list(
gpus: Union[int, List[int], Tuple[int, ...]], include_cuda: bool, include_mps: bool
gpus: Union[int, List[int], Tuple[int, ...]],
include_cuda: bool,
include_mps: bool,
include_xpu: bool,
) -> Optional[List[int]]:
assert gpus is not None
if isinstance(gpus, (MutableSequence, tuple)):
@@ -148,19 +161,26 @@ def _normalize_parse_gpu_input_to_list(
if not gpus: # gpus==0
return None
if gpus == -1:
return _get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps)
return _get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps, include_xpu=include_xpu)

return list(range(gpus))


def _get_all_available_gpus(include_cuda: bool = False, include_mps: bool = False) -> List[int]:
def _get_all_available_gpus(
include_cuda: bool = False, include_mps: bool = False, include_xpu: bool = False
) -> List[int]:
"""
Returns:
A list of all available GPUs
"""
cuda_gpus = accelerators.cuda._get_all_visible_cuda_devices() if include_cuda else []
mps_gpus = accelerators.mps._get_all_available_mps_gpus() if include_mps else []
return cuda_gpus + mps_gpus
xpu_gpus = []
if _LIGHTNING_XPU_AVAILABLE:
import lightning_xpu.fabric as accelerator_xpu

xpu_gpus += accelerator_xpu._get_all_visible_xpu_devices() if include_xpu else []
return cuda_gpus + mps_gpus + xpu_gpus


def _check_unique(device_ids: List[int]) -> None:
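A minimal sketch (not part of this diff) of the extended parsing helpers; the exact ids returned depend on the devices visible on the machine, and the example assumes at least two Intel GPUs are present:

from lightning.fabric.utilities.device_parser import _parse_gpu_ids

# "-1" means "all available GPUs" across the enabled backends.
all_gpus = _parse_gpu_ids("-1", include_cuda=True, include_xpu=True)

# An explicit list is validated against the visible devices and must not contain duplicates.
two_gpus = _parse_gpu_ids([0, 1], include_xpu=True)

print(all_gpus, two_gpus)
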
9 changes: 7 additions & 2 deletions src/lightning/fabric/utilities/distributed.py
@@ -217,7 +217,7 @@ def _init_dist_connection(

Args:
cluster_environment: ``ClusterEnvironment`` instance
torch_distributed_backend: Backend to use (includes `nccl` and `gloo`)
torch_distributed_backend: Backend to use (includes `nccl`, `gloo` and `ccl`)
global_rank: Rank of the current process
world_size: Number of processes in the group
kwargs: Kwargs for ``init_process_group``
@@ -248,7 +248,12 @@


def _get_default_process_group_backend_for_device(device: torch.device) -> str:
return "nccl" if device.type == "cuda" else "gloo"
if device.type == "cuda":
return "nccl"
elif device.type == "xpu":
return "ccl"
else:
return "gloo"


class _DatasetSamplerWrapper(Dataset):
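A minimal sketch (not part of this diff) of the backend selection added above: CUDA devices keep NCCL, Intel XPU devices default to oneCCL ("ccl"), and everything else falls back to Gloo:

import torch

from lightning.fabric.utilities.distributed import _get_default_process_group_backend_for_device

print(_get_default_process_group_backend_for_device(torch.device("cuda", 0)))  # nccl
print(_get_default_process_group_backend_for_device(torch.device("cpu")))      # gloo
# With the Intel extensions installed, torch.device("xpu", 0) would map to "ccl".
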
10 changes: 6 additions & 4 deletions src/lightning/fabric/utilities/imports.py
@@ -16,7 +16,7 @@
import platform
import sys

from lightning_utilities.core.imports import compare_version
from lightning_utilities.core.imports import compare_version, RequirementCache

_IS_WINDOWS = platform.system() == "Windows"

@@ -25,8 +25,10 @@
# 2. The inspection mode via `python -i`: https://stackoverflow.com/a/6879085/1162383
_IS_INTERACTIVE = hasattr(sys, "ps1") or bool(sys.flags.interactive)

_TORCH_GREATER_EQUAL_1_12 = compare_version("torch", operator.ge, "1.12.0")
_TORCH_GREATER_EQUAL_1_13 = compare_version("torch", operator.ge, "1.13.0")
_TORCH_GREATER_EQUAL_2_0 = compare_version("torch", operator.ge, "2.0.0")
_TORCH_GREATER_EQUAL_1_12 = compare_version("torch", operator.ge, "1.12.0", use_base_version=True)
_TORCH_GREATER_EQUAL_1_13 = compare_version("torch", operator.ge, "1.13.0", use_base_version=True)
_TORCH_GREATER_EQUAL_2_0 = compare_version("torch", operator.ge, "2.0.0", use_base_version=True)
_TORCH_GREATER_EQUAL_2_1 = compare_version("torch", operator.ge, "2.1.0", use_base_version=True)
_TORCH_EQUAL_2_0 = _TORCH_GREATER_EQUAL_2_0 and not _TORCH_GREATER_EQUAL_2_1

_LIGHTNING_XPU_AVAILABLE = RequirementCache("lightning-xpu")
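
A minimal sketch (not part of this diff) of the two lightning_utilities helpers used above: `RequirementCache` is truthy only when the named distribution is installed, and `use_base_version=True` makes the torch comparisons ignore pre-release/dev suffixes (e.g. "2.1.0.dev20230601" compares as "2.1.0"):

import operator

from lightning_utilities.core.imports import compare_version, RequirementCache

if RequirementCache("lightning-xpu"):
    print("lightning-xpu is installed")

print(compare_version("torch", operator.ge, "2.0.0", use_base_version=True))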