diff --git a/src/lightning/fabric/accelerators/__init__.py b/src/lightning/fabric/accelerators/__init__.py index 35d2cc6eb17b0..54b6ac16dc992 100644 --- a/src/lightning/fabric/accelerators/__init__.py +++ b/src/lightning/fabric/accelerators/__init__.py @@ -16,7 +16,12 @@ from lightning.fabric.accelerators.mps import MPSAccelerator # noqa: F401 from lightning.fabric.accelerators.registry import _AcceleratorRegistry, call_register_accelerators from lightning.fabric.accelerators.xla import XLAAccelerator # noqa: F401 +from lightning.fabric.utilities.imports import _LIGHTNING_XPU_AVAILABLE _ACCELERATORS_BASE_MODULE = "lightning.fabric.accelerators" ACCELERATOR_REGISTRY = _AcceleratorRegistry() call_register_accelerators(ACCELERATOR_REGISTRY, _ACCELERATORS_BASE_MODULE) +if _LIGHTNING_XPU_AVAILABLE and "xpu" not in ACCELERATOR_REGISTRY: + from lightning_xpu.fabric import XPUAccelerator + + XPUAccelerator.register_accelerators(ACCELERATOR_REGISTRY) diff --git a/src/lightning/fabric/cli.py b/src/lightning/fabric/cli.py index abcb8f195abcb..bebf197581ba8 100644 --- a/src/lightning/fabric/cli.py +++ b/src/lightning/fabric/cli.py @@ -24,12 +24,15 @@ from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT_STR, _PRECISION_INPUT_STR_ALIAS from lightning.fabric.strategies import STRATEGY_REGISTRY from lightning.fabric.utilities.device_parser import _parse_gpu_ids +from lightning.fabric.utilities.imports import _LIGHTNING_XPU_AVAILABLE _log = logging.getLogger(__name__) _CLICK_AVAILABLE = RequirementCache("click") -_SUPPORTED_ACCELERATORS = ("cpu", "gpu", "cuda", "mps", "tpu") +_SUPPORTED_ACCELERATORS = ["cpu", "gpu", "cuda", "mps", "tpu"] +if _LIGHTNING_XPU_AVAILABLE: + _SUPPORTED_ACCELERATORS.append("xpu") def _get_supported_strategies() -> List[str]: @@ -146,13 +149,17 @@ def _set_env_variables(args: Namespace) -> None: def _get_num_processes(accelerator: str, devices: str) -> int: """Parse the `devices` argument to determine how many processes need to be launched on the current machine.""" if accelerator == "gpu": - parsed_devices = _parse_gpu_ids(devices, include_cuda=True, include_mps=True) + parsed_devices = _parse_gpu_ids(devices, include_cuda=True, include_mps=True, include_xpu=True) elif accelerator == "cuda": parsed_devices = CUDAAccelerator.parse_devices(devices) elif accelerator == "mps": parsed_devices = MPSAccelerator.parse_devices(devices) elif accelerator == "tpu": raise ValueError("Launching processes for TPU through the CLI is not supported.") + elif accelerator == "xpu": + from lightning_xpu.fabric import XPUAccelerator + + parsed_devices = XPUAccelerator.parse_devices(devices) else: return CPUAccelerator.parse_devices(devices) return len(parsed_devices) if parsed_devices is not None else 0 diff --git a/src/lightning/fabric/connector.py b/src/lightning/fabric/connector.py index f99418185712f..931e996344e20 100644 --- a/src/lightning/fabric/connector.py +++ b/src/lightning/fabric/connector.py @@ -63,7 +63,7 @@ from lightning.fabric.strategies.fsdp import _FSDP_ALIASES, FSDPStrategy from lightning.fabric.utilities import rank_zero_info, rank_zero_warn from lightning.fabric.utilities.device_parser import _determine_root_gpu_device -from lightning.fabric.utilities.imports import _IS_INTERACTIVE +from lightning.fabric.utilities.imports import _IS_INTERACTIVE, _LIGHTNING_XPU_AVAILABLE _PLUGIN = Union[Precision, ClusterEnvironment, CheckpointIO] _PLUGIN_INPUT = Union[_PLUGIN, str] @@ -288,6 +288,13 @@ def _check_config_and_set_final_flags( f" but accelerator set to {self._accelerator_flag}, please choose one device type" ) self._accelerator_flag = "cuda" + if self._strategy_flag.parallel_devices[0].type == "xpu": + if self._accelerator_flag and self._accelerator_flag not in ("auto", "xpu", "gpu"): + raise ValueError( + f"GPU parallel_devices set through {self._strategy_flag.__class__.__name__} class," + f" but accelerator set to {self._accelerator_flag}, please choose one device type" + ) + self._accelerator_flag = "xpu" self._parallel_devices = self._strategy_flag.parallel_devices def _check_device_config_and_set_final_flags(self, devices: Union[List[int], str, int], num_nodes: int) -> None: @@ -313,6 +320,12 @@ def _choose_auto_accelerator(self) -> str: return "mps" if CUDAAccelerator.is_available(): return "cuda" + if _LIGHTNING_XPU_AVAILABLE: + from lightning_xpu.fabric import XPUAccelerator + + if XPUAccelerator.is_available(): + return "xpu" + return "cpu" @staticmethod @@ -321,6 +334,11 @@ def _choose_gpu_accelerator_backend() -> str: return "mps" if CUDAAccelerator.is_available(): return "cuda" + if _LIGHTNING_XPU_AVAILABLE: + from lightning_xpu.fabric import XPUAccelerator + + if XPUAccelerator.is_available(): + return "xpu" raise RuntimeError("No supported gpu backend found!") def _set_parallel_devices_and_init_accelerator(self) -> None: @@ -378,8 +396,15 @@ def _choose_strategy(self) -> Union[Strategy, str]: return "ddp" if len(self._parallel_devices) <= 1: # TODO: Change this once gpu accelerator was renamed to cuda accelerator - if isinstance(self._accelerator_flag, (CUDAAccelerator, MPSAccelerator)) or ( - isinstance(self._accelerator_flag, str) and self._accelerator_flag in ("cuda", "gpu", "mps") + supported_accelerators = [CUDAAccelerator, MPSAccelerator] + supported_accelerators_str = ["cuda", "gpu", "mps"] + if _LIGHTNING_XPU_AVAILABLE: + from lightning_xpu.fabric import XPUAccelerator + + supported_accelerators.append(XPUAccelerator) + supported_accelerators_str.append("xpu") + if isinstance(self._accelerator_flag, tuple(supported_accelerators)) or ( + isinstance(self._accelerator_flag, str) and self._accelerator_flag in tuple(supported_accelerators_str) ): device = _determine_root_gpu_device(self._parallel_devices) else: @@ -462,7 +487,11 @@ def _check_and_init_precision(self) -> Precision: if self._precision_input == "16-mixed" else "Using bfloat16 Automatic Mixed Precision (AMP)" ) - device = "cpu" if self._accelerator_flag == "cpu" else "cuda" + device = "cuda" + if self._accelerator_flag == "cpu": + device = "cpu" + elif self._accelerator_flag == "xpu": + device = "xpu" if isinstance(self.strategy, FSDPStrategy): return FSDPPrecision(precision=self._precision_input, device=device) # type: ignore[arg-type] diff --git a/src/lightning/fabric/strategies/ddp.py b/src/lightning/fabric/strategies/ddp.py index c277fc386380d..b8c639a8b731f 100644 --- a/src/lightning/fabric/strategies/ddp.py +++ b/src/lightning/fabric/strategies/ddp.py @@ -115,9 +115,12 @@ def setup_environment(self) -> None: def setup_module(self, module: Module) -> DistributedDataParallel: """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module.""" device_ids = self._determine_ddp_device_ids() - # https://pytorch.org/docs/stable/notes/cuda.html#id5 - ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext() - with ctx: + if self.root_device.type == "cuda": + # https://pytorch.org/docs/stable/notes/cuda.html#id5 + ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext() + with ctx: + return DistributedDataParallel(module=module, device_ids=device_ids, **self._ddp_kwargs) + else: return DistributedDataParallel(module=module, device_ids=device_ids, **self._ddp_kwargs) def module_to_device(self, module: Module) -> None: diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py index 45453eca0a4cf..a4adbf14c3dae 100644 --- a/src/lightning/fabric/strategies/deepspeed.py +++ b/src/lightning/fabric/strategies/deepspeed.py @@ -16,7 +16,7 @@ import logging import os import platform -from contextlib import contextmanager +from contextlib import contextmanager, suppress from itertools import chain from pathlib import Path from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, TYPE_CHECKING, Union @@ -33,10 +33,14 @@ from lightning.fabric.strategies.registry import _StrategyRegistry from lightning.fabric.strategies.strategy import _Sharded from lightning.fabric.utilities.distributed import log +from lightning.fabric.utilities.imports import _LIGHTNING_XPU_AVAILABLE from lightning.fabric.utilities.rank_zero import rank_zero_info, rank_zero_warn from lightning.fabric.utilities.seed import reset_seed from lightning.fabric.utilities.types import _PATH +if _LIGHTNING_XPU_AVAILABLE: + from lightning_xpu.fabric import XPUAccelerator + _DEEPSPEED_AVAILABLE = RequirementCache("deepspeed") if TYPE_CHECKING and _DEEPSPEED_AVAILABLE: import deepspeed @@ -215,7 +219,8 @@ def __init__( contiguous_memory_optimization: Copies partitioned activations so that they are contiguous in memory. Not supported by all models. - synchronize_checkpoint_boundary: Insert :func:`torch.cuda.synchronize` at each checkpoint boundary. + synchronize_checkpoint_boundary: Insert :func:`torch.cuda.synchronize` or :func:`torch.xpu.synchronize` + at each checkpoint boundary. load_full_weights: True when loading a single checkpoint file containing the model state dict when using ZeRO Stage 3. This differs from the DeepSpeed checkpoint which contains shards @@ -472,6 +477,9 @@ def load_checkpoint( optimzer_state_requested = bool(len([item for item in state.values() if isinstance(item, Optimizer)])) torch.cuda.empty_cache() + if _LIGHTNING_XPU_AVAILABLE: + XPUAccelerator.teardown() + _, client_state = engine.load_checkpoint( path, tag="checkpoint", @@ -567,10 +575,10 @@ def _initialize_engine( return deepspeed_engine, deepspeed_optimizer def _setup_distributed(self) -> None: - if not isinstance(self.accelerator, CUDAAccelerator): + if not isinstance(self.accelerator, CUDAAccelerator) and not isinstance(self.accelerator, XPUAccelerator): raise RuntimeError( - f"The DeepSpeed strategy is only supported on CUDA GPUs but `{self.accelerator.__class__.__name__}`" - " is used." + "The DeepSpeed strategy is only supported on CUDA/Intel(R) GPUs but" + " `{self.accelerator.__class__.__name__}` is used." ) reset_seed() self._set_world_ranks() diff --git a/src/lightning/fabric/utilities/device_parser.py b/src/lightning/fabric/utilities/device_parser.py index 65e363cb06d65..eca550d344e21 100644 --- a/src/lightning/fabric/utilities/device_parser.py +++ b/src/lightning/fabric/utilities/device_parser.py @@ -16,6 +16,7 @@ import lightning.fabric.accelerators as accelerators # avoid circular dependency from lightning.fabric.plugins.environments.torchelastic import TorchElasticEnvironment from lightning.fabric.utilities.exceptions import MisconfigurationException +from lightning.fabric.utilities.imports import _LIGHTNING_XPU_AVAILABLE from lightning.fabric.utilities.types import _DEVICE @@ -49,6 +50,7 @@ def _parse_gpu_ids( gpus: Optional[Union[int, str, List[int]]], include_cuda: bool = False, include_mps: bool = False, + include_xpu: bool = False, ) -> Optional[List[int]]: """ Parses the GPU IDs given in the format as accepted by the @@ -62,6 +64,7 @@ def _parse_gpu_ids( Any int N > 0 indicates that GPUs [0..N) should be used. include_cuda: A boolean value indicating whether to include CUDA devices for GPU parsing. include_mps: A boolean value indicating whether to include MPS devices for GPU parsing. + include_xpu: A boolean value indicating whether to include Intel GPU devices for GPU parsing. Returns: A list of GPUs to be used or ``None`` if no GPUs were requested @@ -71,7 +74,7 @@ def _parse_gpu_ids( If no GPUs are available but the value of gpus variable indicates request for GPUs .. note:: - ``include_cuda`` and ``include_mps`` default to ``False`` so that you only + ``include_cuda``, ``include_mps`` and ``include_xpu`` default to ``False`` so that you only have to specify which device type to use and all other devices are not disabled. """ # Check that gpus param is None, Int, String or Sequence of Ints @@ -84,14 +87,17 @@ def _parse_gpu_ids( # We know the user requested GPUs therefore if some of the # requested GPUs are not available an exception is thrown. gpus = _normalize_parse_gpu_string_input(gpus) - gpus = _normalize_parse_gpu_input_to_list(gpus, include_cuda=include_cuda, include_mps=include_mps) + gpus = _normalize_parse_gpu_input_to_list( + gpus, include_cuda=include_cuda, include_mps=include_mps, include_xpu=include_xpu + ) if not gpus: raise MisconfigurationException("GPUs requested but none are available.") if ( TorchElasticEnvironment.detect() and len(gpus) != 1 - and len(_get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps)) == 1 + and len(_get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps, include_xpu=include_xpu)) + == 1 ): # Omit sanity check on torchelastic because by default it shows one visible GPU per process return gpus @@ -99,7 +105,7 @@ def _parse_gpu_ids( # Check that GPUs are unique. Duplicate GPUs are not supported by the backend. _check_unique(gpus) - return _sanitize_gpu_ids(gpus, include_cuda=include_cuda, include_mps=include_mps) + return _sanitize_gpu_ids(gpus, include_cuda=include_cuda, include_mps=include_mps, include_xpu=include_xpu) def _normalize_parse_gpu_string_input(s: Union[int, str, List[int]]) -> Union[int, List[int]]: @@ -112,7 +118,9 @@ def _normalize_parse_gpu_string_input(s: Union[int, str, List[int]]) -> Union[in return int(s.strip()) -def _sanitize_gpu_ids(gpus: List[int], include_cuda: bool = False, include_mps: bool = False) -> List[int]: +def _sanitize_gpu_ids( + gpus: List[int], include_cuda: bool = False, include_mps: bool = False, include_xpu: bool = False +) -> List[int]: """Checks that each of the GPUs in the list is actually available. Raises a MisconfigurationException if any of the GPUs is not available. @@ -126,9 +134,11 @@ def _sanitize_gpu_ids(gpus: List[int], include_cuda: bool = False, include_mps: MisconfigurationException: If machine has fewer available GPUs than requested. """ - if sum((include_cuda, include_mps)) == 0: + if sum((include_cuda, include_mps, include_xpu)) == 0: raise ValueError("At least one gpu type should be specified!") - all_available_gpus = _get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps) + all_available_gpus = _get_all_available_gpus( + include_cuda=include_cuda, include_mps=include_mps, include_xpu=include_xpu + ) for gpu in gpus: if gpu not in all_available_gpus: raise MisconfigurationException( @@ -138,7 +148,10 @@ def _sanitize_gpu_ids(gpus: List[int], include_cuda: bool = False, include_mps: def _normalize_parse_gpu_input_to_list( - gpus: Union[int, List[int], Tuple[int, ...]], include_cuda: bool, include_mps: bool + gpus: Union[int, List[int], Tuple[int, ...]], + include_cuda: bool, + include_mps: bool, + include_xpu: bool, ) -> Optional[List[int]]: assert gpus is not None if isinstance(gpus, (MutableSequence, tuple)): @@ -148,19 +161,26 @@ def _normalize_parse_gpu_input_to_list( if not gpus: # gpus==0 return None if gpus == -1: - return _get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps) + return _get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps, include_xpu=include_xpu) return list(range(gpus)) -def _get_all_available_gpus(include_cuda: bool = False, include_mps: bool = False) -> List[int]: +def _get_all_available_gpus( + include_cuda: bool = False, include_mps: bool = False, include_xpu: bool = False +) -> List[int]: """ Returns: A list of all available GPUs """ cuda_gpus = accelerators.cuda._get_all_visible_cuda_devices() if include_cuda else [] mps_gpus = accelerators.mps._get_all_available_mps_gpus() if include_mps else [] - return cuda_gpus + mps_gpus + xpu_gpus = [] + if _LIGHTNING_XPU_AVAILABLE: + import lightning_xpu.fabric as accelerator_xpu + + xpu_gpus += accelerator_xpu._get_all_visible_xpu_devices() if include_xpu else [] + return cuda_gpus + mps_gpus + xpu_gpus def _check_unique(device_ids: List[int]) -> None: diff --git a/src/lightning/fabric/utilities/distributed.py b/src/lightning/fabric/utilities/distributed.py index c92519bc1c49e..4dc7b01392594 100644 --- a/src/lightning/fabric/utilities/distributed.py +++ b/src/lightning/fabric/utilities/distributed.py @@ -217,7 +217,7 @@ def _init_dist_connection( Args: cluster_environment: ``ClusterEnvironment`` instance - torch_distributed_backend: Backend to use (includes `nccl` and `gloo`) + torch_distributed_backend: Backend to use (includes `nccl`, `gloo` and `ccl`) global_rank: Rank of the current process world_size: Number of processes in the group kwargs: Kwargs for ``init_process_group`` @@ -248,7 +248,12 @@ def _init_dist_connection( def _get_default_process_group_backend_for_device(device: torch.device) -> str: - return "nccl" if device.type == "cuda" else "gloo" + if device.type == "cuda": + return "nccl" + elif device.type == "xpu": + return "ccl" + else: + return "gloo" class _DatasetSamplerWrapper(Dataset): diff --git a/src/lightning/fabric/utilities/imports.py b/src/lightning/fabric/utilities/imports.py index d4a1cb029ea36..56a8dfdb07fce 100644 --- a/src/lightning/fabric/utilities/imports.py +++ b/src/lightning/fabric/utilities/imports.py @@ -16,7 +16,7 @@ import platform import sys -from lightning_utilities.core.imports import compare_version +from lightning_utilities.core.imports import compare_version, RequirementCache _IS_WINDOWS = platform.system() == "Windows" @@ -25,8 +25,10 @@ # 2. The inspection mode via `python -i`: https://stackoverflow.com/a/6879085/1162383 _IS_INTERACTIVE = hasattr(sys, "ps1") or bool(sys.flags.interactive) -_TORCH_GREATER_EQUAL_1_12 = compare_version("torch", operator.ge, "1.12.0") -_TORCH_GREATER_EQUAL_1_13 = compare_version("torch", operator.ge, "1.13.0") -_TORCH_GREATER_EQUAL_2_0 = compare_version("torch", operator.ge, "2.0.0") +_TORCH_GREATER_EQUAL_1_12 = compare_version("torch", operator.ge, "1.12.0", use_base_version=True) +_TORCH_GREATER_EQUAL_1_13 = compare_version("torch", operator.ge, "1.13.0", use_base_version=True) +_TORCH_GREATER_EQUAL_2_0 = compare_version("torch", operator.ge, "2.0.0", use_base_version=True) _TORCH_GREATER_EQUAL_2_1 = compare_version("torch", operator.ge, "2.1.0", use_base_version=True) _TORCH_EQUAL_2_0 = _TORCH_GREATER_EQUAL_2_0 and not _TORCH_GREATER_EQUAL_2_1 + +_LIGHTNING_XPU_AVAILABLE = RequirementCache("lightning-xpu") diff --git a/src/lightning/pytorch/_graveyard/xpu.py b/src/lightning/pytorch/_graveyard/xpu.py new file mode 100644 index 0000000000000..6d5a6af85559d --- /dev/null +++ b/src/lightning/pytorch/_graveyard/xpu.py @@ -0,0 +1,34 @@ +import sys +from typing import Any + +import lightning.pytorch as pl + + +def _patch_sys_modules() -> None: + self = sys.modules[__name__] + sys.modules["lightning.pytorch.accelerators.xpu"] = self + + +class XPUAccelerator: + auto_device_count = ... + get_parallel_devices = ... + is_available = ... + parse_devices = ... + setup_device = ... + teardown = ... + + def __init__(self, *_: Any, **__: Any) -> None: + raise NotImplementedError( + "The `XPUAccelerator` class has been moved to an external package." + " Install the extension package as `pip install lightning-xpu`" + " and import with `from lightning_xpu.pytorch import XPUAccelerator`." + " Please see: https://github.com/Lightning-AI/lightning-XPU for more details." + ) + + +def _patch_classes() -> None: + setattr(pl.accelerators, "XPUAccelerator", XPUAccelerator) + + +_patch_sys_modules() +_patch_classes() diff --git a/src/lightning/pytorch/strategies/ddp.py b/src/lightning/pytorch/strategies/ddp.py index 164deb1e64c73..c60566a79caf4 100644 --- a/src/lightning/pytorch/strategies/ddp.py +++ b/src/lightning/pytorch/strategies/ddp.py @@ -184,9 +184,12 @@ def _setup_model(self, model: Module) -> DistributedDataParallel: """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module.""" device_ids = self.determine_ddp_device_ids() log.debug(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}") - # https://pytorch.org/docs/stable/notes/cuda.html#id5 - ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext() - with ctx: + if self.root_device.type == "cuda": + # https://pytorch.org/docs/stable/notes/cuda.html#id5 + ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext() + with ctx: + return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs) + else: return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs) def setup_distributed(self) -> None: diff --git a/src/lightning/pytorch/strategies/deepspeed.py b/src/lightning/pytorch/strategies/deepspeed.py index 853bbd1b32fd7..c1e1ba1026fab 100644 --- a/src/lightning/pytorch/strategies/deepspeed.py +++ b/src/lightning/pytorch/strategies/deepspeed.py @@ -240,7 +240,8 @@ def __init__( contiguous_memory_optimization: Copies partitioned activations so that they are contiguous in memory. Not supported by all models. - synchronize_checkpoint_boundary: Insert :func:`torch.cuda.synchronize` at each checkpoint boundary. + synchronize_checkpoint_boundary: Insert :func:`torch.cuda.synchronize` or :func:`torch.xpu.synchronize` + at each checkpoint boundary. load_full_weights: True when loading a single checkpoint file containing the model state dict when using ZeRO Stage 3. This differs from the DeepSpeed checkpoint which contains shards diff --git a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py index 42f46cd75047d..f3c683ce9dd18 100644 --- a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py +++ b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py @@ -66,6 +66,7 @@ _LIGHTNING_COLOSSALAI_AVAILABLE, _LIGHTNING_GRAPHCORE_AVAILABLE, _LIGHTNING_HABANA_AVAILABLE, + _LIGHTNING_XPU_AVAILABLE, ) from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn @@ -348,6 +349,11 @@ def _choose_auto_accelerator(self) -> str: if HPUAccelerator.is_available(): return "hpu" + if _LIGHTNING_XPU_AVAILABLE: + from lightning_xpu.pytorch import XPUAccelerator + + if XPUAccelerator.is_available(): + return "xpu" if MPSAccelerator.is_available(): return "mps" if CUDAAccelerator.is_available(): @@ -360,6 +366,11 @@ def _choose_gpu_accelerator_backend() -> str: return "mps" if CUDAAccelerator.is_available(): return "cuda" + if _LIGHTNING_XPU_AVAILABLE: + from lightning_xpu.pytorch import XPUAccelerator + + if XPUAccelerator.is_available(): + return "xpu" raise MisconfigurationException("No supported gpu backend found!") def _set_parallel_devices_and_init_accelerator(self) -> None: @@ -436,6 +447,12 @@ def _choose_strategy(self) -> Union[Strategy, str]: from lightning_habana import SingleHPUStrategy return SingleHPUStrategy(device=torch.device("hpu")) + if self._accelerator_flag == "xpu" and not _LIGHTNING_XPU_AVAILABLE: + raise ImportError( + "You have asked for XPU but you miss install related integration." + " Please run `pip install lightning-xpu` or see for further instructions" + " in https://github.com/Lightning-AI/lightning-XPU/." + ) if self._accelerator_flag == "tpu" or isinstance(self._accelerator_flag, XLAAccelerator): if self._parallel_devices and len(self._parallel_devices) > 1: return XLAStrategy.strategy_name @@ -705,6 +722,13 @@ def _register_external_accelerators_and_strategies() -> None: if "hpu_single" not in StrategyRegistry: SingleHPUStrategy.register_strategies(StrategyRegistry) + if _LIGHTNING_XPU_AVAILABLE: + from lightning_xpu.pytorch import XPUAccelerator + + # TODO: Prevent registering multiple times + if "xpu" not in AcceleratorRegistry: + XPUAccelerator.register_accelerators(AcceleratorRegistry) + if _LIGHTNING_GRAPHCORE_AVAILABLE: from lightning_graphcore import IPUAccelerator, IPUStrategy diff --git a/src/lightning/pytorch/trainer/setup.py b/src/lightning/pytorch/trainer/setup.py index 36f9c27e70983..65c3789b30a0f 100644 --- a/src/lightning/pytorch/trainer/setup.py +++ b/src/lightning/pytorch/trainer/setup.py @@ -28,7 +28,11 @@ XLAProfiler, ) from lightning.pytorch.utilities.exceptions import MisconfigurationException -from lightning.pytorch.utilities.imports import _LIGHTNING_GRAPHCORE_AVAILABLE, _LIGHTNING_HABANA_AVAILABLE +from lightning.pytorch.utilities.imports import ( + _LIGHTNING_GRAPHCORE_AVAILABLE, + _LIGHTNING_HABANA_AVAILABLE, + _LIGHTNING_XPU_AVAILABLE, +) from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn @@ -178,11 +182,24 @@ def _log_device_info(trainer: "pl.Trainer") -> None: hpu_available = False rank_zero_info(f"HPU available: {hpu_available}, using: {num_hpus} HPUs") + if _LIGHTNING_XPU_AVAILABLE: + from lightning_xpu.pytorch import XPUAccelerator + + num_xpus = trainer.num_devices if isinstance(trainer.accelerator, XPUAccelerator) else 0 + xpu_available = XPUAccelerator.is_available() + else: + num_xpus = 0 + xpu_available = False + rank_zero_info(f"XPU available: {xpu_available}, using: {num_xpus} XPUs") + if ( CUDAAccelerator.is_available() and not isinstance(trainer.accelerator, CUDAAccelerator) or MPSAccelerator.is_available() and not isinstance(trainer.accelerator, MPSAccelerator) + or _LIGHTNING_XPU_AVAILABLE + and XPUAccelerator.is_available() + and not isinstance(trainer.accelerator, XPUAccelerator) ): rank_zero_warn( "GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.", diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index 74545a15acf44..b6e5c763ce32e 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -147,7 +147,7 @@ def __init__( precision: Double precision (64, '64' or '64-true'), full precision (32, '32' or '32-true'), 16bit mixed precision (16, '16', '16-mixed') or bfloat16 mixed precision ('bf16', 'bf16-mixed'). - Can be used on CPU, GPU, TPUs, HPUs or IPUs. + Can be used on CPU, GPU, TPUs, HPUs, IPUs or XPUs. Default: ``'32-true'``. logger: Logger (or iterable collection of loggers) for experiment tracking. A ``True`` value uses diff --git a/src/lightning/pytorch/utilities/imports.py b/src/lightning/pytorch/utilities/imports.py index a28c2dd276465..7e6310983c569 100644 --- a/src/lightning/pytorch/utilities/imports.py +++ b/src/lightning/pytorch/utilities/imports.py @@ -29,4 +29,5 @@ _LIGHTNING_COLOSSALAI_AVAILABLE = RequirementCache("lightning-colossalai") _LIGHTNING_BAGUA_AVAILABLE = RequirementCache("lightning-bagua") _LIGHTNING_HABANA_AVAILABLE = RequirementCache("lightning-habana") +_LIGHTNING_XPU_AVAILABLE = RequirementCache("lightning-xpu") _LIGHTNING_GRAPHCORE_AVAILABLE = RequirementCache("lightning-graphcore") diff --git a/src/pytorch_lightning/README.md b/src/pytorch_lightning/README.md index 8ffe163c87b33..6f73dd2b838a0 100644 --- a/src/pytorch_lightning/README.md +++ b/src/pytorch_lightning/README.md @@ -62,7 +62,7 @@ Lightning forces the following structure to your code which makes it reusable an - Non-essential research code (logging, etc... this goes in Callbacks). - Data (use PyTorch DataLoaders or organize them into a LightningDataModule). -Once you do this, you can train on multiple-GPUs, TPUs, CPUs, IPUs, HPUs and even in 16-bit precision without changing your code! +Once you do this, you can train on multiple-GPUs, TPUs, CPUs, IPUs, HPUs, XPUs and even in 16-bit precision without changing your code! [Get started in just 15 minutes](https://lightning.ai/docs/pytorch/latest/starter/introduction.html) @@ -70,7 +70,7 @@ ______________________________________________________________________ ## Continuous Integration -Lightning is rigorously tested across multiple CPUs, GPUs, TPUs, IPUs, and HPUs and against major Python and PyTorch versions. +Lightning is rigorously tested across multiple CPUs, GPUs, TPUs, IPUs, HPUs and XPUs and against major Python and PyTorch versions.
Current build statuses