diff --git a/docs/source-fabric/fundamentals/launch.rst b/docs/source-fabric/fundamentals/launch.rst index f8c0deecf4e25..50c933a7a8b54 100644 --- a/docs/source-fabric/fundamentals/launch.rst +++ b/docs/source-fabric/fundamentals/launch.rst @@ -93,8 +93,9 @@ This is essentially the same as running ``python path/to/your/script.py``, but i itself and are expected to be parsed there. Options: - --accelerator [cpu|gpu|cuda|mps|tpu] + --accelerator [cpu|gpu|cuda|mps|tpu|xpu] The hardware accelerator to run on. + Install Lightning-XPU to enable ``xpu``. --strategy [ddp|dp|deepspeed] Strategy for how to run across multiple devices. --devices TEXT Number of devices to run on (``int``), which diff --git a/docs/source-pytorch/common/index.rst b/docs/source-pytorch/common/index.rst index 738e971aec532..8e6331580ea45 100644 --- a/docs/source-pytorch/common/index.rst +++ b/docs/source-pytorch/common/index.rst @@ -17,6 +17,7 @@ ../advanced/model_parallel Train on single or multiple GPUs <../accelerators/gpu> Train on single or multiple HPUs <../integrations/hpu/index> + Train on single or multiple XPUs <../integrations/xpu/index> Train on single or multiple TPUs <../accelerators/tpu> Train on MPS <../accelerators/mps> Use a pretrained model <../advanced/pretrained> @@ -167,6 +168,13 @@ How-to Guides :col_css: col-md-4 :height: 180 +.. displayitem:: + :header: Train on single or multiple XPUs + :description: Train models faster with XPU accelerators + :button_link: ../integrations/xpu/index.html + :col_css: col-md-4 + :height: 180 + .. displayitem:: :header: Train on single or multiple TPUs :description: TTrain models faster with TPU accelerators diff --git a/docs/source-pytorch/common_usecases.rst b/docs/source-pytorch/common_usecases.rst index 2891d264d885d..7721a7c4a4836 100644 --- a/docs/source-pytorch/common_usecases.rst +++ b/docs/source-pytorch/common_usecases.rst @@ -133,6 +133,13 @@ Customize and extend Lightning for things like custom hardware or distributed st :button_link: integrations/hpu/index.html :height: 100 +.. displayitem:: + :header: Train on single or multiple XPUs + :description: Train models faster with XPUs. + :col_css: col-md-12 + :button_link: integrations/xpu/index.html + :height: 100 + .. displayitem:: :header: Train on single or multiple TPUs :description: Train models faster with TPUs. 
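The CLI flag and doc entries above assume the optional ``lightning-xpu`` package; once it is installed, the ``xpu`` choice is registered with Fabric's accelerator registry and can also be selected directly in code. A minimal sketch (the device count is illustrative only):

.. code-block:: python

    from lightning.fabric import Fabric

    # ``accelerator="xpu"`` is only registered when the optional lightning-xpu
    # integration is installed; use ``accelerator="auto"`` to fall back to
    # CPU/CUDA/MPS on machines without Intel GPUs.
    fabric = Fabric(accelerator="xpu", devices=2)
    fabric.launch()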
diff --git a/docs/source-pytorch/conf.py b/docs/source-pytorch/conf.py index 0baba58ee0a38..f713648528cd5 100644 --- a/docs/source-pytorch/conf.py +++ b/docs/source-pytorch/conf.py @@ -93,6 +93,11 @@ def _load_py_module(name: str, location: str) -> ModuleType: target_dir="docs/source-pytorch/integrations/hpu", checkout="refs/tags/1.4.0", ) +assist_local.AssistantCLI.pull_docs_files( + gh_user_repo="Lightning-AI/lightning-XPU", + target_dir="docs/source-pytorch/integrations/xpu", + checkout="tags/1.0.0", +) # Copy strategies docs as single pages assist_local.AssistantCLI.pull_docs_files( @@ -355,6 +360,7 @@ def _load_py_module(name: str, location: str) -> ModuleType: "PIL": ("https://pillow.readthedocs.io/en/stable/", None), "torchmetrics": ("https://lightning.ai/docs/torchmetrics/stable/", None), "lightning_habana": ("https://lightning-ai.github.io/lightning-Habana/", None), + "intel-xpu": ("https://lightning-ai.github.io/lightning-XPU/", None), "tensorboardX": ("https://tensorboardx.readthedocs.io/en/stable/", None), # needed for referencing Fabric from lightning scope "lightning.fabric": ("https://lightning.ai/docs/fabric/stable/", None), diff --git a/docs/source-pytorch/extensions/accelerator.rst b/docs/source-pytorch/extensions/accelerator.rst index 93dc467b02921..c21588c63a136 100644 --- a/docs/source-pytorch/extensions/accelerator.rst +++ b/docs/source-pytorch/extensions/accelerator.rst @@ -11,6 +11,7 @@ Currently there are accelerators for: - :doc:`GPU <../accelerators/gpu>` - :doc:`TPU <../accelerators/tpu>` - :doc:`HPU <../integrations/hpu/index>` +- :doc:`XPU <../integrations/xpu/index>` - :doc:`MPS <../accelerators/mps>` The Accelerator is part of the Strategy which manages communication across multiple devices (distributed communication). @@ -31,16 +32,16 @@ Create a Custom Accelerator .. warning:: This is an :ref:`experimental ` feature. Here is how you create a new Accelerator. -Let's pretend we want to integrate the fictional XPU accelerator and we have access to its hardware through a library -``xpulib``. +Let's pretend we want to integrate the fictional YPU accelerator and we have access to its hardware through a library +``ypulib``. .. code-block:: python - import xpulib + import ypulib - class XPUAccelerator(Accelerator): - """Support for a hypothetical XPU, optimized for large-scale machine learning.""" + class YPUAccelerator(Accelerator): + """Support for a hypothetical YPU, optimized for large-scale machine learning.""" @staticmethod def parse_devices(devices: Any) -> Any: @@ -51,29 +52,29 @@ Let's pretend we want to integrate the fictional XPU accelerator and we have acc @staticmethod def get_parallel_devices(devices: Any) -> Any: # Here, convert the device indices to actual device objects - return [torch.device("xpu", idx) for idx in devices] + return [torch.device("ypu", idx) for idx in devices] @staticmethod def auto_device_count() -> int: # Return a value for auto-device selection when `Trainer(devices="auto")` - return xpulib.available_devices() + return ypulib.available_devices() @staticmethod def is_available() -> bool: - return xpulib.is_available() + return ypulib.is_available() def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]: # Return optional device statistics for loggers return {} -Finally, add the XPUAccelerator to the Trainer: +Finally, add the YPUAccelerator to the Trainer: .. 
code-block:: python from lightning.pytorch import Trainer - accelerator = XPUAccelerator() + accelerator = YPUAccelerator() trainer = Trainer(accelerator=accelerator, devices=2) @@ -89,28 +90,28 @@ If you wish to switch to a custom accelerator from the CLI without code changes, .. code-block:: python - class XPUAccelerator(Accelerator): + class YPUAccelerator(Accelerator): ... @classmethod def register_accelerators(cls, accelerator_registry): accelerator_registry.register( - "xpu", + "ypu", cls, - description=f"XPU Accelerator - optimized for large-scale machine learning.", + description=f"YPU Accelerator - optimized for large-scale machine learning.", ) Now, this is possible: .. code-block:: python - trainer = Trainer(accelerator="xpu") + trainer = Trainer(accelerator="ypu") Or if you are using the Lightning CLI, for example: .. code-block:: bash - python train.py fit --trainer.accelerator=xpu --trainer.devices=2 + python train.py fit --trainer.accelerator=ypu --trainer.devices=2 ---------- diff --git a/docs/source-pytorch/glossary/index.rst b/docs/source-pytorch/glossary/index.rst index 6b5e4b12b307f..f2230fdf98039 100644 --- a/docs/source-pytorch/glossary/index.rst +++ b/docs/source-pytorch/glossary/index.rst @@ -21,6 +21,7 @@ GPU <../accelerators/gpu> Half precision <../common/precision> HPU <../integrations/hpu/index> + XPU <../integrations/xpu/index> Inference <../deploy/production_intermediate> Lightning CLI <../cli/lightning_cli> LightningDataModule <../data/datamodule> @@ -186,6 +187,13 @@ Glossary :button_link: ../integrations/hpu/index.html :height: 100 +.. displayitem:: + :header: XPU + :description: Intel® Graphics Cards for faster training + :col_css: col-md-12 + :button_link: ../integrations/xpu/index.html + :height: 100 + .. displayitem:: :header: Inference :description: Making predictions by applying a trained model to unlabeled examples diff --git a/docs/source-pytorch/integrations/xpu/index.rst b/docs/source-pytorch/integrations/xpu/index.rst new file mode 100644 index 0000000000000..3fb22d6e36541 --- /dev/null +++ b/docs/source-pytorch/integrations/xpu/index.rst @@ -0,0 +1,40 @@ +.. _xpu: + +Accelerator: XPU training +========================= + +.. raw:: html + +
+
+ +.. Add callout items below this line + +.. displayitem:: + :header: Basic + :description: Learn the basics of single and multi-XPU core training. + :col_css: col-md-4 + :button_link: basic.html + :height: 150 + :tag: basic + +.. displayitem:: + :header: Intermediate + :description: Enable state-of-the-art scaling with advanced mixed-precision settings. + :col_css: col-md-4 + :button_link: intermediate.html + :height: 150 + :tag: intermediate + +.. displayitem:: + :header: Advanced + :description: Explore state-of-the-art scaling with additional advanced configurations. + :col_css: col-md-4 + :button_link: advanced.html + :height: 150 + :tag: advanced + +.. raw:: html + +
+
diff --git a/docs/source-pytorch/levels/advanced_level_23.rst b/docs/source-pytorch/levels/advanced_level_23.rst new file mode 100644 index 0000000000000..895f4c538398c --- /dev/null +++ b/docs/source-pytorch/levels/advanced_level_23.rst @@ -0,0 +1,37 @@ +:orphan: + +###################### +Level 19: Explore XPUs +###################### + +Explore Intel® Graphics Cards (XPU) for model scaling. + +---- + +.. raw:: html + +
+
+ + +.. Add callout items below this line + +.. displayitem:: + :header: Train models on XPUs + :description: Learn the basics of single and multi-XPU core training. + :col_css: col-md-6 + :button_link: ../integrations/xpu/basic.html + :height: 150 + :tag: basic + +.. displayitem:: + :header: Optimize model training on XPUs + :description: Enable state-of-the-art scaling with advanced mixed-precision settings. + :col_css: col-md-6 + :button_link: ../integrations/xpu/intermediate.html + :height: 150 + :tag: intermediate + +.. raw:: html + +
+
diff --git a/examples/pytorch/xpu/mnist_sample.py b/examples/pytorch/xpu/mnist_sample.py new file mode 100644 index 0000000000000..edc6396d35c45 --- /dev/null +++ b/examples/pytorch/xpu/mnist_sample.py @@ -0,0 +1,69 @@ +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +from lightning.pytorch import LightningModule +from lightning.pytorch.cli import LightningCLI +from lightning.pytorch.demos.mnist_datamodule import MNISTDataModule +from torch.nn import functional as F + + +class LitClassifier(LightningModule): + def __init__(self): + super().__init__() + self.l1 = torch.nn.Linear(28 * 28, 10) + + def forward(self, x): + return torch.relu(self.l1(x.view(x.size(0), -1))) + + def training_step(self, batch, batch_idx): + x, y = batch + return F.cross_entropy(self(x), y) + + def validation_step(self, batch, batch_idx): + x, y = batch + probs = self(x) + acc = self.accuracy(probs, y) + self.log("val_acc", acc) + + def test_step(self, batch, batch_idx): + x, y = batch + logits = self(x) + acc = self.accuracy(logits, y) + self.log("test_acc", acc) + + @staticmethod + def accuracy(logits, y): + return torch.sum(torch.eq(torch.argmax(logits, -1), y).to(torch.float32)) / len(y) + + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=0.02) + + +if __name__ == "__main__": + cli = LightningCLI( + LitClassifier, + MNISTDataModule, + trainer_defaults={ + "accelerator": "gpu", + "devices": 2, + "max_epochs": 1, + }, + run=False, + save_config_kwargs={"overwrite": True}, + ) + + # Run the model ⚡ + cli.trainer.fit(cli.model, datamodule=cli.datamodule) + cli.trainer.validate(cli.model, datamodule=cli.datamodule) + cli.trainer.test(cli.model, datamodule=cli.datamodule) diff --git a/requirements/_integrations/accelerators.txt b/requirements/_integrations/accelerators.txt index 90c72bedb2cdc..db956d90dabd7 100644 --- a/requirements/_integrations/accelerators.txt +++ b/requirements/_integrations/accelerators.txt @@ -1,2 +1,5 @@ # validation accelerator connectors lightning-habana >=1.2.0, <1.3.0 + +# validation XPU connectors +lightning-xpu >=0.1.0 diff --git a/src/lightning/fabric/accelerators/__init__.py b/src/lightning/fabric/accelerators/__init__.py index 3d4b43f75c762..654c928a64ae3 100644 --- a/src/lightning/fabric/accelerators/__init__.py +++ b/src/lightning/fabric/accelerators/__init__.py @@ -22,3 +22,10 @@ ACCELERATOR_REGISTRY = _AcceleratorRegistry() _register_classes(ACCELERATOR_REGISTRY, "register_accelerators", sys.modules[__name__], Accelerator) + +from lightning.fabric.utilities.imports import _lightning_xpu_available + +if _lightning_xpu_available() and "xpu" not in ACCELERATOR_REGISTRY: + from lightning_xpu.fabric import XPUAccelerator + + XPUAccelerator.register_accelerators(ACCELERATOR_REGISTRY) diff --git a/src/lightning/fabric/cli.py b/src/lightning/fabric/cli.py index d8c6fe47b6630..4e8793aecc722 100644 --- a/src/lightning/fabric/cli.py +++ b/src/lightning/fabric/cli.py @@ -29,6 +29,7 @@ from 
lightning.fabric.utilities.consolidate_checkpoint import _process_cli_args from lightning.fabric.utilities.device_parser import _parse_gpu_ids from lightning.fabric.utilities.distributed import _suggested_max_num_threads +from lightning.fabric.utilities.imports import _lightning_xpu_available from lightning.fabric.utilities.load import _load_distributed_checkpoint _log = logging.getLogger(__name__) @@ -36,7 +37,9 @@ _CLICK_AVAILABLE = RequirementCache("click") _LIGHTNING_SDK_AVAILABLE = RequirementCache("lightning_sdk") -_SUPPORTED_ACCELERATORS = ("cpu", "gpu", "cuda", "mps", "tpu") +_SUPPORTED_ACCELERATORS = ["cpu", "gpu", "cuda", "mps", "tpu"] +if _lightning_xpu_available(): + _SUPPORTED_ACCELERATORS.append("xpu") def _get_supported_strategies() -> List[str]: @@ -209,13 +212,17 @@ def _set_env_variables(args: Namespace) -> None: def _get_num_processes(accelerator: str, devices: str) -> int: """Parse the `devices` argument to determine how many processes need to be launched on the current machine.""" if accelerator == "gpu": - parsed_devices = _parse_gpu_ids(devices, include_cuda=True, include_mps=True) + parsed_devices = _parse_gpu_ids(devices, include_cuda=True, include_mps=True, include_xpu=True) elif accelerator == "cuda": parsed_devices = CUDAAccelerator.parse_devices(devices) elif accelerator == "mps": parsed_devices = MPSAccelerator.parse_devices(devices) elif accelerator == "tpu": raise ValueError("Launching processes for TPU through the CLI is not supported.") + elif accelerator == "xpu": + from lightning_xpu.fabric import XPUAccelerator + + parsed_devices = XPUAccelerator.parse_devices(devices) else: return CPUAccelerator.parse_devices(devices) return len(parsed_devices) if parsed_devices is not None else 0 diff --git a/src/lightning/fabric/connector.py b/src/lightning/fabric/connector.py index f677893100351..e08dbbaa8c10b 100644 --- a/src/lightning/fabric/connector.py +++ b/src/lightning/fabric/connector.py @@ -65,7 +65,7 @@ from lightning.fabric.strategies.model_parallel import ModelParallelStrategy from lightning.fabric.utilities import rank_zero_info, rank_zero_warn from lightning.fabric.utilities.device_parser import _determine_root_gpu_device -from lightning.fabric.utilities.imports import _IS_INTERACTIVE +from lightning.fabric.utilities.imports import _IS_INTERACTIVE, _lightning_xpu_available _PLUGIN_INPUT = Union[Precision, ClusterEnvironment, CheckpointIO] @@ -293,6 +293,13 @@ def _check_config_and_set_final_flags( f" but accelerator set to {self._accelerator_flag}, please choose one device type" ) self._accelerator_flag = "cuda" + if self._strategy_flag.parallel_devices[0].type == "xpu": + if self._accelerator_flag and self._accelerator_flag not in ("auto", "xpu", "gpu"): + raise ValueError( + f"GPU parallel_devices set through {self._strategy_flag.__class__.__name__} class," + f" but accelerator set to {self._accelerator_flag}, please choose one device type" + ) + self._accelerator_flag = "xpu" self._parallel_devices = self._strategy_flag.parallel_devices def _check_device_config_and_set_final_flags(self, devices: Union[List[int], str, int], num_nodes: int) -> None: @@ -321,6 +328,12 @@ def _choose_auto_accelerator(self) -> str: return "mps" if CUDAAccelerator.is_available(): return "cuda" + if _lightning_xpu_available(): + from lightning_xpu.fabric import XPUAccelerator + + if XPUAccelerator.is_available(): + return "xpu" + return "cpu" @staticmethod @@ -329,6 +342,11 @@ def _choose_gpu_accelerator_backend() -> str: return "mps" if CUDAAccelerator.is_available(): 
return "cuda" + if _lightning_xpu_available(): + from lightning_xpu.fabric import XPUAccelerator + + if XPUAccelerator.is_available(): + return "xpu" raise RuntimeError("No supported gpu backend found!") def _set_parallel_devices_and_init_accelerator(self) -> None: @@ -399,8 +417,15 @@ def _choose_strategy(self) -> Union[Strategy, str]: if self._num_nodes_flag > 1: return "ddp" if len(self._parallel_devices) <= 1: - if isinstance(self._accelerator_flag, (CUDAAccelerator, MPSAccelerator)) or ( - isinstance(self._accelerator_flag, str) and self._accelerator_flag in ("cuda", "gpu", "mps") + supported_accelerators = [CUDAAccelerator, MPSAccelerator] + supported_accelerators_str = ["cuda", "gpu", "mps"] + if _lightning_xpu_available(): + from lightning_xpu.fabric import XPUAccelerator + + supported_accelerators.append(XPUAccelerator) + supported_accelerators_str.append("xpu") + if isinstance(self._accelerator_flag, tuple(supported_accelerators)) or ( + isinstance(self._accelerator_flag, str) and self._accelerator_flag in tuple(supported_accelerators_str) ): device = _determine_root_gpu_device(self._parallel_devices) else: @@ -491,7 +516,12 @@ def _check_and_init_precision(self) -> Precision: if self._precision_input == "16-mixed" else "Using bfloat16 Automatic Mixed Precision (AMP)" ) - device = "cpu" if self._accelerator_flag == "cpu" else "cuda" + device = "cuda" + if self._accelerator_flag == "cpu": + device = "cpu" + elif self._accelerator_flag == "xpu": + device = "xpu" + return MixedPrecision(precision=self._precision_input, device=device) # type: ignore[arg-type] raise RuntimeError("No precision set") diff --git a/src/lightning/fabric/strategies/ddp.py b/src/lightning/fabric/strategies/ddp.py index 0ec5df1a6b0ae..52289aaf441d9 100644 --- a/src/lightning/fabric/strategies/ddp.py +++ b/src/lightning/fabric/strategies/ddp.py @@ -123,10 +123,17 @@ def setup_environment(self) -> None: def setup_module(self, module: Module) -> DistributedDataParallel: """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module.""" device_ids = self._determine_ddp_device_ids() - # https://pytorch.org/docs/stable/notes/cuda.html#id5 - ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext() - with ctx: + ctx = None + if self.root_device.type == "cuda": + # https://pytorch.org/docs/stable/notes/cuda.html#id5 + ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext() + if self.root_device.type == "xpu": + ctx = torch.xpu.stream(torch.xpu.Stream()) if device_ids is not None else nullcontext() + if ctx is None: return DistributedDataParallel(module=module, device_ids=device_ids, **self._ddp_kwargs) + else: + with ctx: + return DistributedDataParallel(module=module, device_ids=device_ids, **self._ddp_kwargs) @override def module_to_device(self, module: Module) -> None: diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py index 2a1a1272b498e..e9010b0bea5f0 100644 --- a/src/lightning/fabric/strategies/deepspeed.py +++ b/src/lightning/fabric/strategies/deepspeed.py @@ -34,6 +34,7 @@ from lightning.fabric.strategies.registry import _StrategyRegistry from lightning.fabric.strategies.strategy import _Sharded from lightning.fabric.utilities.distributed import log +from lightning.fabric.utilities.imports import _lightning_xpu_available from lightning.fabric.utilities.load import _move_state_into from lightning.fabric.utilities.rank_zero import rank_zero_info, 
rank_zero_warn from lightning.fabric.utilities.seed import reset_seed @@ -42,6 +43,9 @@ if TYPE_CHECKING: from deepspeed import DeepSpeedEngine +if _lightning_xpu_available(): + from lightning_xpu.fabric import XPUAccelerator + _DEEPSPEED_AVAILABLE = RequirementCache("deepspeed") @@ -218,7 +222,8 @@ def __init__( contiguous_memory_optimization: Copies partitioned activations so that they are contiguous in memory. Not supported by all models. - synchronize_checkpoint_boundary: Insert :func:`torch.cuda.synchronize` at each checkpoint boundary. + synchronize_checkpoint_boundary: Insert :func:`torch.cuda.synchronize` or :func:`torch.xpu.synchronize` + at each checkpoint boundary. load_full_weights: True when loading a single checkpoint file containing the model state dict when using ZeRO Stage 3. This differs from the DeepSpeed checkpoint which contains shards @@ -503,6 +508,10 @@ def load_checkpoint( optimzer_state_requested = any(isinstance(item, (Optimizer, DeepSpeedOptimizer)) for item in state.values()) torch.cuda.empty_cache() + with suppress(AttributeError): + if _lightning_xpu_available(): + XPUAccelerator.teardown() + _, client_state = engine.load_checkpoint( path, tag="checkpoint", @@ -612,10 +621,14 @@ def _initialize_engine( @override def setup_environment(self) -> None: - if not isinstance(self.accelerator, CUDAAccelerator): + ds_support = False + if isinstance(self.accelerator, CUDAAccelerator): + ds_support = True + if _lightning_xpu_available() and isinstance(self.accelerator, XPUAccelerator): + ds_support = True + if not ds_support: raise RuntimeError( - f"The DeepSpeed strategy is only supported on CUDA GPUs but `{self.accelerator.__class__.__name__}`" - " is used." + f"The DeepSpeed strategy is only supported on CUDA/Intel(R) GPUs but `{self.accelerator.__class__.__name__}` is used." ) super().setup_environment() diff --git a/src/lightning/fabric/strategies/launchers/multiprocessing.py b/src/lightning/fabric/strategies/launchers/multiprocessing.py index 14a063f28f336..cacca2c057f02 100644 --- a/src/lightning/fabric/strategies/launchers/multiprocessing.py +++ b/src/lightning/fabric/strategies/launchers/multiprocessing.py @@ -29,12 +29,15 @@ from lightning.fabric.strategies.launchers.launcher import _Launcher from lightning.fabric.utilities.apply_func import move_data_to_device from lightning.fabric.utilities.distributed import _set_num_threads_if_needed -from lightning.fabric.utilities.imports import _IS_INTERACTIVE +from lightning.fabric.utilities.imports import _IS_INTERACTIVE, _lightning_xpu_available from lightning.fabric.utilities.seed import _collect_rng_states, _set_rng_states if TYPE_CHECKING: from lightning.fabric.strategies import ParallelStrategy +if _lightning_xpu_available(): + from lightning_xpu.fabric import XPUAccelerator, is_xpu_initialized + class _MultiProcessingLauncher(_Launcher): r"""Launches processes that run a given function in parallel, and joins them all at the end. @@ -96,6 +99,8 @@ def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any: """ if self._start_method in ("fork", "forkserver"): _check_bad_cuda_fork() + if _lightning_xpu_available() and XPUAccelerator.is_available(): + _check_bad_xpu_fork() if self._start_method == "spawn": _check_missing_main_guard() @@ -247,3 +252,23 @@ def main(): """ ) raise RuntimeError(message) + + +def _check_bad_xpu_fork() -> None: + """Checks whether it is safe to fork and initialize XPU in the new processes, and raises an exception if not. 
+ + The error message replaces PyTorch's 'Cannot re-initialize XPU in forked subprocess' with helpful advice for + Lightning users. + + """ + if not is_xpu_initialized(): + return + + message = ( + "Lightning can't create new processes if XPU is already initialized. Did you manually call" + " `torch.xpu.*` functions, have moved the model to the device, or allocated memory on the GPU any" + " other way? Please remove any such calls, or change the selected strategy." + ) + if _IS_INTERACTIVE: + message += " You will have to restart the Python kernel." + raise RuntimeError(message) diff --git a/src/lightning/fabric/utilities/device_parser.py b/src/lightning/fabric/utilities/device_parser.py index 16965d944caec..c9f4330033fd1 100644 --- a/src/lightning/fabric/utilities/device_parser.py +++ b/src/lightning/fabric/utilities/device_parser.py @@ -16,6 +16,7 @@ import torch from lightning.fabric.utilities.exceptions import MisconfigurationException +from lightning.fabric.utilities.imports import _lightning_xpu_available from lightning.fabric.utilities.types import _DEVICE @@ -49,6 +50,7 @@ def _parse_gpu_ids( gpus: Optional[Union[int, str, List[int]]], include_cuda: bool = False, include_mps: bool = False, + include_xpu: bool = False, ) -> Optional[List[int]]: """Parses the GPU IDs given in the format as accepted by the :class:`~lightning.pytorch.trainer.trainer.Trainer`. @@ -60,6 +62,7 @@ def _parse_gpu_ids( Any int N > 0 indicates that GPUs [0..N) should be used. include_cuda: A boolean value indicating whether to include CUDA devices for GPU parsing. include_mps: A boolean value indicating whether to include MPS devices for GPU parsing. + include_xpu: A boolean value indicating whether to include Intel GPU devices for GPU parsing. Returns: A list of GPUs to be used or ``None`` if no GPUs were requested @@ -69,7 +72,7 @@ def _parse_gpu_ids( If no GPUs are available but the value of gpus variable indicates request for GPUs .. note:: - ``include_cuda`` and ``include_mps`` default to ``False`` so that you only + ``include_cuda``, ``include_mps`` and ``include_xpu`` default to ``False`` so that you only have to specify which device type to use and all other devices are not disabled. """ @@ -83,7 +86,9 @@ def _parse_gpu_ids( # We know the user requested GPUs therefore if some of the # requested GPUs are not available an exception is thrown. gpus = _normalize_parse_gpu_string_input(gpus) - gpus = _normalize_parse_gpu_input_to_list(gpus, include_cuda=include_cuda, include_mps=include_mps) + gpus = _normalize_parse_gpu_input_to_list( + gpus, include_cuda=include_cuda, include_mps=include_mps, include_xpu=include_xpu + ) if not gpus: raise MisconfigurationException("GPUs requested but none are available.") @@ -91,7 +96,8 @@ def _parse_gpu_ids( torch.distributed.is_available() and torch.distributed.is_torchelastic_launched() and len(gpus) != 1 - and len(_get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps)) == 1 + and len(_get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps, include_xpu=include_xpu)) + == 1 ): # Omit sanity check on torchelastic because by default it shows one visible GPU per process return gpus @@ -99,7 +105,7 @@ def _parse_gpu_ids( # Check that GPUs are unique. Duplicate GPUs are not supported by the backend. 
_check_unique(gpus) - return _sanitize_gpu_ids(gpus, include_cuda=include_cuda, include_mps=include_mps) + return _sanitize_gpu_ids(gpus, include_cuda=include_cuda, include_mps=include_mps, include_xpu=include_xpu) def _normalize_parse_gpu_string_input(s: Union[int, str, List[int]]) -> Union[int, List[int]]: @@ -112,7 +118,9 @@ def _normalize_parse_gpu_string_input(s: Union[int, str, List[int]]) -> Union[in return int(s.strip()) -def _sanitize_gpu_ids(gpus: List[int], include_cuda: bool = False, include_mps: bool = False) -> List[int]: +def _sanitize_gpu_ids( + gpus: List[int], include_cuda: bool = False, include_mps: bool = False, include_xpu: bool = False +) -> List[int]: """Checks that each of the GPUs in the list is actually available. Raises a MisconfigurationException if any of the GPUs is not available. @@ -127,9 +135,11 @@ def _sanitize_gpu_ids(gpus: List[int], include_cuda: bool = False, include_mps: If machine has fewer available GPUs than requested. """ - if sum((include_cuda, include_mps)) == 0: + if sum((include_cuda, include_mps, include_xpu)) == 0: raise ValueError("At least one gpu type should be specified!") - all_available_gpus = _get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps) + all_available_gpus = _get_all_available_gpus( + include_cuda=include_cuda, include_mps=include_mps, include_xpu=include_xpu + ) for gpu in gpus: if gpu not in all_available_gpus: raise MisconfigurationException( @@ -139,7 +149,10 @@ def _sanitize_gpu_ids(gpus: List[int], include_cuda: bool = False, include_mps: def _normalize_parse_gpu_input_to_list( - gpus: Union[int, List[int], Tuple[int, ...]], include_cuda: bool, include_mps: bool + gpus: Union[int, List[int], Tuple[int, ...]], + include_cuda: bool, + include_mps: bool, + include_xpu: bool, ) -> Optional[List[int]]: assert gpus is not None if isinstance(gpus, (MutableSequence, tuple)): @@ -149,12 +162,14 @@ def _normalize_parse_gpu_input_to_list( if not gpus: # gpus==0 return None if gpus == -1: - return _get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps) + return _get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps, include_xpu=include_xpu) return list(range(gpus)) -def _get_all_available_gpus(include_cuda: bool = False, include_mps: bool = False) -> List[int]: +def _get_all_available_gpus( + include_cuda: bool = False, include_mps: bool = False, include_xpu: bool = False +) -> List[int]: """ Returns: A list of all available GPUs @@ -164,7 +179,12 @@ def _get_all_available_gpus(include_cuda: bool = False, include_mps: bool = Fals cuda_gpus = _get_all_visible_cuda_devices() if include_cuda else [] mps_gpus = _get_all_available_mps_gpus() if include_mps else [] - return cuda_gpus + mps_gpus + xpu_gpus = [] + if _lightning_xpu_available(): + from lightning_xpu.fabric import _get_all_visible_xpu_devices + + xpu_gpus += _get_all_visible_xpu_devices() if include_xpu else [] + return cuda_gpus + mps_gpus + xpu_gpus def _check_unique(device_ids: List[int]) -> None: diff --git a/src/lightning/fabric/utilities/distributed.py b/src/lightning/fabric/utilities/distributed.py index 30bfe4e254a07..8b6d6c4c97716 100644 --- a/src/lightning/fabric/utilities/distributed.py +++ b/src/lightning/fabric/utilities/distributed.py @@ -269,7 +269,7 @@ def _init_dist_connection( Args: cluster_environment: ``ClusterEnvironment`` instance - torch_distributed_backend: Backend to use (includes `nccl` and `gloo`) + torch_distributed_backend: Backend to use (includes `nccl`, `gloo` and `ccl`) 
global_rank: Rank of the current process world_size: Number of processes in the group kwargs: Kwargs for ``init_process_group`` @@ -301,7 +301,12 @@ def _init_dist_connection( def _get_default_process_group_backend_for_device(device: torch.device) -> str: - return "nccl" if device.type == "cuda" else "gloo" + if device.type == "cuda": + return "nccl" + elif device.type == "xpu": + return "ccl" + else: + return "gloo" class _DatasetSamplerWrapper(Dataset): diff --git a/src/lightning/fabric/utilities/imports.py b/src/lightning/fabric/utilities/imports.py index 46374e23ad2b5..367c6906831c6 100644 --- a/src/lightning/fabric/utilities/imports.py +++ b/src/lightning/fabric/utilities/imports.py @@ -17,7 +17,7 @@ import platform import sys -from lightning_utilities.core.imports import compare_version +from lightning_utilities.core.imports import RequirementCache, compare_version _IS_WINDOWS = platform.system() == "Windows" @@ -26,14 +26,34 @@ # 2. The inspection mode via `python -i`: https://stackoverflow.com/a/6879085/1162383 _IS_INTERACTIVE = hasattr(sys, "ps1") or bool(sys.flags.interactive) -_TORCH_GREATER_EQUAL_2_1 = compare_version("torch", operator.ge, "2.1.0") -_TORCH_GREATER_EQUAL_2_2 = compare_version("torch", operator.ge, "2.2.0") -_TORCH_GREATER_EQUAL_2_3 = compare_version("torch", operator.ge, "2.3.0") +_TORCH_GREATER_EQUAL_2_1 = compare_version("torch", operator.ge, "2.1.0", use_base_version=True) +_TORCH_GREATER_EQUAL_2_2 = compare_version("torch", operator.ge, "2.2.0", use_base_version=True) +_TORCH_GREATER_EQUAL_2_3 = compare_version("torch", operator.ge, "2.3.0", use_base_version=True) _TORCH_GREATER_EQUAL_2_4 = compare_version("torch", operator.ge, "2.4.0", use_base_version=True) -_TORCH_EQUAL_2_0 = compare_version("torch", operator.ge, "2.0.0") and not _TORCH_GREATER_EQUAL_2_1 +_TORCH_EQUAL_2_0 = ( + compare_version("torch", operator.ge, "2.0.0", use_base_version=True) and not _TORCH_GREATER_EQUAL_2_1 +) _PYTHON_GREATER_EQUAL_3_8_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 8) _PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10) _UTILITIES_GREATER_EQUAL_0_10 = compare_version("lightning_utilities", operator.ge, "0.10.0") + + +@functools.lru_cache(maxsize=128) +def _try_import_module(module_name: str) -> bool: + try: + __import__(module_name) + return True + # also catch AttributeError for the case of imports like pl.LightningModule + except (ImportError, AttributeError) as err: + rank_zero_warn(f"Import of {module_name} package failed for some compatibility issues: \n{err}") + return False + + +@functools.lru_cache(maxsize=1) +def _lightning_xpu_available() -> bool: + # This is defined as a function instead of a constant to avoid circular imports, because `lightning_xpu` + # also imports Lightning + return bool(RequirementCache("lightning-xpu")) and _try_import_module("lightning_xpu") diff --git a/src/lightning/fabric/utilities/seed.py b/src/lightning/fabric/utilities/seed.py index b274bce88fcdf..8e9153fde9c3f 100644 --- a/src/lightning/fabric/utilities/seed.py +++ b/src/lightning/fabric/utilities/seed.py @@ -14,6 +14,7 @@ max_seed_value = np.iinfo(np.uint32).max min_seed_value = np.iinfo(np.uint32).min +from lightning.fabric.utilities.imports import _lightning_xpu_available def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int: @@ -56,6 +57,8 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int: random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) + if
_lightning_xpu_available() and torch.xpu.is_available(): + torch.xpu.manual_seed_all(seed) os.environ["PL_SEED_WORKERS"] = f"{int(workers)}" @@ -102,8 +105,8 @@ def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None: random.seed(stdlib_seed) -def _collect_rng_states(include_cuda: bool = True) -> Dict[str, Any]: - r"""Collect the global random state of :mod:`torch`, :mod:`torch.cuda`, :mod:`numpy` and Python.""" +def _collect_rng_states(include_cuda: bool = True, include_xpu: bool = True) -> Dict[str, Any]: + """Collect the global random state of :mod:`torch`, :mod:`torch.cuda`, :mod:`numpy` and Python.""" states = { "torch": torch.get_rng_state(), "numpy": np.random.get_state(), @@ -111,6 +114,8 @@ def _collect_rng_states(include_cuda: bool = True) -> Dict[str, Any]: } if include_cuda: states["torch.cuda"] = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else [] + if include_xpu and _lightning_xpu_available(): + states["torch.xpu"] = torch.xpu.get_rng_state_all() if torch.xpu.is_available() else [] return states @@ -121,6 +126,8 @@ def _set_rng_states(rng_state_dict: Dict[str, Any]) -> None: # torch.cuda rng_state is only included since v1.8. if "torch.cuda" in rng_state_dict: torch.cuda.set_rng_state_all(rng_state_dict["torch.cuda"]) + if "torch.xpu" in rng_state_dict and _lightning_xpu_available() and torch.xpu.is_available(): + torch.xpu.set_rng_state_all(rng_state_dict["torch.xpu"]) np.random.set_state(rng_state_dict["numpy"]) version, state, gauss = rng_state_dict["python"] python_set_rng_state((version, tuple(state), gauss)) diff --git a/src/lightning/pytorch/strategies/ddp.py b/src/lightning/pytorch/strategies/ddp.py index 9031b6ee177f3..1823116c438e3 100644 --- a/src/lightning/pytorch/strategies/ddp.py +++ b/src/lightning/pytorch/strategies/ddp.py @@ -189,10 +189,17 @@ def _setup_model(self, model: Module) -> DistributedDataParallel: """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module.""" device_ids = self.determine_ddp_device_ids() log.debug(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}") - # https://pytorch.org/docs/stable/notes/cuda.html#id5 - ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext() - with ctx: + ctx = None + if self.root_device.type == "cuda": + # https://pytorch.org/docs/stable/notes/cuda.html#id5 + ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext() + if self.root_device.type == "xpu": + ctx = torch.xpu.stream(torch.xpu.Stream()) if device_ids is not None else nullcontext() + if ctx is None: return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs) + else: + with ctx: + return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs) def setup_distributed(self) -> None: log.debug(f"{self.__class__.__name__}: setting up distributed...") diff --git a/src/lightning/pytorch/strategies/deepspeed.py b/src/lightning/pytorch/strategies/deepspeed.py index 382f8070898f8..f7955f2208012 100644 --- a/src/lightning/pytorch/strategies/deepspeed.py +++ b/src/lightning/pytorch/strategies/deepspeed.py @@ -243,7 +243,8 @@ def __init__( contiguous_memory_optimization: Copies partitioned activations so that they are contiguous in memory. Not supported by all models. - synchronize_checkpoint_boundary: Insert :func:`torch.cuda.synchronize` at each checkpoint boundary. 
+ synchronize_checkpoint_boundary: Insert :func:`torch.cuda.synchronize` or :func:`torch.xpu.synchronize` + at each checkpoint boundary. load_full_weights: True when loading a single checkpoint file containing the model state dict when using ZeRO Stage 3. This differs from the DeepSpeed checkpoint which contains shards diff --git a/src/lightning/pytorch/strategies/launchers/multiprocessing.py b/src/lightning/pytorch/strategies/launchers/multiprocessing.py index aa96da63adb65..ff7bdc03a5711 100644 --- a/src/lightning/pytorch/strategies/launchers/multiprocessing.py +++ b/src/lightning/pytorch/strategies/launchers/multiprocessing.py @@ -30,6 +30,7 @@ import lightning.pytorch as pl from lightning.fabric.strategies.launchers.multiprocessing import ( _check_bad_cuda_fork, + _check_bad_xpu_fork, _check_missing_main_guard, _disable_module_memory_sharing, ) @@ -41,8 +42,12 @@ from lightning.pytorch.strategies.launchers.launcher import _Launcher from lightning.pytorch.trainer.connectors.signal_connector import _SIGNUM from lightning.pytorch.trainer.states import TrainerFn, TrainerState +from lightning.pytorch.utilities.imports import _lightning_xpu_available from lightning.pytorch.utilities.rank_zero import rank_zero_debug +if _lightning_xpu_available(): + from lightning_xpu.pytorch import XPUAccelerator + log = logging.getLogger(__name__) @@ -108,6 +113,8 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"] """ if self._start_method in ("fork", "forkserver"): _check_bad_cuda_fork() + if _lightning_xpu_available() and XPUAccelerator.is_available(): + _check_bad_xpu_fork() if self._start_method == "spawn": _check_missing_main_guard() if self._already_fit and trainer is not None and trainer.state.fn == TrainerFn.FITTING: diff --git a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py index 1c97a223b129e..9c87f8a8dd2e5 100644 --- a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py +++ b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py @@ -63,7 +63,7 @@ ) from lightning.pytorch.strategies.ddp import _DDP_FORK_ALIASES from lightning.pytorch.utilities.exceptions import MisconfigurationException -from lightning.pytorch.utilities.imports import _habana_available_and_importable +from lightning.pytorch.utilities.imports import _habana_available_and_importable, _lightning_xpu_available from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn log = logging.getLogger(__name__) @@ -337,6 +337,11 @@ def _choose_auto_accelerator(self) -> str: if HPUAccelerator.is_available(): return "hpu" + if _lightning_xpu_available(): + from lightning_xpu.pytorch import XPUAccelerator + + if XPUAccelerator.is_available(): + return "xpu" if MPSAccelerator.is_available(): return "mps" if CUDAAccelerator.is_available(): @@ -349,6 +354,11 @@ def _choose_gpu_accelerator_backend() -> str: return "mps" if CUDAAccelerator.is_available(): return "cuda" + if _lightning_xpu_available(): + from lightning_xpu.pytorch import XPUAccelerator + + if XPUAccelerator.is_available(): + return "xpu" raise MisconfigurationException("No supported gpu backend found!") def _set_parallel_devices_and_init_accelerator(self) -> None: @@ -428,6 +438,12 @@ def _choose_strategy(self) -> Union[Strategy, str]: " in https://github.com/Lightning-AI/lightning-Habana/."
) + if self._accelerator_flag == "xpu" and not _lightning_xpu_available(): + raise ImportError( + "You have asked for XPU but the lightning-xpu integration is not installed." + " Please run `pip install lightning-xpu` or see the instructions" + " at https://github.com/Lightning-AI/lightning-XPU/." + ) if self._accelerator_flag == "tpu" or isinstance(self._accelerator_flag, XLAAccelerator): if self._parallel_devices and len(self._parallel_devices) > 1: return XLAStrategy.strategy_name @@ -436,8 +452,16 @@ def _choose_strategy(self) -> Union[Strategy, str]: if self._num_nodes_flag > 1: return "ddp" if len(self._parallel_devices) <= 1: - if isinstance(self._accelerator_flag, (CUDAAccelerator, MPSAccelerator)) or ( - isinstance(self._accelerator_flag, str) and self._accelerator_flag in ("cuda", "gpu", "mps") + accelerator_flags_obj = (CUDAAccelerator, MPSAccelerator) + accelerator_flags_str = ("cuda", "gpu", "mps") + if _lightning_xpu_available(): + from lightning_xpu.pytorch import XPUAccelerator + + if XPUAccelerator.is_available(): + accelerator_flags_obj += (XPUAccelerator,) + accelerator_flags_str += ("xpu",) + if isinstance(self._accelerator_flag, accelerator_flags_obj) or ( + isinstance(self._accelerator_flag, str) and self._accelerator_flag in accelerator_flags_str ): device = _determine_root_gpu_device(self._parallel_devices) else: @@ -518,7 +542,7 @@ def _check_and_init_precision(self) -> Precision: rank_zero_info( f"Using {'16bit' if self._precision_flag == '16-mixed' else 'bfloat16'} Automatic Mixed Precision (AMP)" ) - device = "cpu" if self._accelerator_flag == "cpu" else "cuda" + device = "cpu" if self._accelerator_flag == "cpu" else "xpu" if self._accelerator_flag == "xpu" else "cuda" return MixedPrecision(self._precision_flag, device) # type: ignore[arg-type] raise RuntimeError("No precision set") @@ -663,3 +687,10 @@ def _register_external_accelerators_and_strategies() -> None: HPUParallelStrategy.register_strategies(StrategyRegistry) if "hpu_single" not in StrategyRegistry: SingleHPUStrategy.register_strategies(StrategyRegistry) + + if _lightning_xpu_available(): + from lightning_xpu.pytorch import XPUAccelerator + + # TODO: Prevent registering multiple times + if "xpu" not in AcceleratorRegistry: + XPUAccelerator.register_accelerators(AcceleratorRegistry) diff --git a/src/lightning/pytorch/trainer/setup.py b/src/lightning/pytorch/trainer/setup.py index 00b546b252ac8..d08a47aacdad9 100644 --- a/src/lightning/pytorch/trainer/setup.py +++ b/src/lightning/pytorch/trainer/setup.py @@ -28,7 +28,7 @@ XLAProfiler, ) from lightning.pytorch.utilities.exceptions import MisconfigurationException -from lightning.pytorch.utilities.imports import _habana_available_and_importable +from lightning.pytorch.utilities.imports import _habana_available_and_importable, _lightning_xpu_available from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn @@ -168,11 +168,24 @@ def _log_device_info(trainer: "pl.Trainer") -> None: hpu_available = False rank_zero_info(f"HPU available: {hpu_available}, using: {num_hpus} HPUs") + if _lightning_xpu_available(): + from lightning_xpu.pytorch import XPUAccelerator + + num_xpus = trainer.num_devices if isinstance(trainer.accelerator, XPUAccelerator) else 0 + xpu_available = XPUAccelerator.is_available() + else: + num_xpus = 0 + xpu_available = False + rank_zero_info(f"XPU available: {xpu_available}, using: {num_xpus} XPUs") + if ( CUDAAccelerator.is_available() and not isinstance(trainer.accelerator, CUDAAccelerator) or
MPSAccelerator.is_available() and not isinstance(trainer.accelerator, MPSAccelerator) + or _lightning_xpu_available() + and XPUAccelerator.is_available() + and not isinstance(trainer.accelerator, XPUAccelerator) ): rank_zero_warn( "GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.", diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index bf7d47a880da3..3be256c6ca873 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -151,7 +151,7 @@ def __init__( precision: Double precision (64, '64' or '64-true'), full precision (32, '32' or '32-true'), 16bit mixed precision (16, '16', '16-mixed') or bfloat16 mixed precision ('bf16', 'bf16-mixed'). - Can be used on CPU, GPU, TPUs, or HPUs. + Can be used on CPU, GPU, TPUs, HPUs or XPUs. Default: ``'32-true'``. logger: Logger (or iterable collection of loggers) for experiment tracking. A ``True`` value uses diff --git a/src/lightning/pytorch/utilities/imports.py b/src/lightning/pytorch/utilities/imports.py index 6c0815a6af9dc..5bd32c218e40f 100644 --- a/src/lightning/pytorch/utilities/imports.py +++ b/src/lightning/pytorch/utilities/imports.py @@ -48,3 +48,12 @@ def _habana_available_and_importable() -> bool: # This is defined as a function instead of a constant to avoid circular imports, because `lightning_habana` # also imports Lightning return bool(_LIGHTNING_HABANA_AVAILABLE) and _try_import_module("lightning_habana") + + +_LIGHTNING_XPU_AVAILABLE = RequirementCache("lightning-xpu") + + +def _lightning_xpu_available() -> bool: + # This is defined as a function instead of a constant to avoid circular imports, because `lightning_xpu` + # also imports Lightning + return bool(_LIGHTNING_XPU_AVAILABLE) and _try_import_module("lightning_xpu") diff --git a/src/lightning/pytorch/utilities/seed.py b/src/lightning/pytorch/utilities/seed.py index 4ba9e7f0f960f..4ae40770d384f 100644 --- a/src/lightning/pytorch/utilities/seed.py +++ b/src/lightning/pytorch/utilities/seed.py @@ -20,7 +20,7 @@ @contextmanager -def isolate_rng(include_cuda: bool = True) -> Generator[None, None, None]: +def isolate_rng(include_cuda: bool = True, include_xpu: bool = True) -> Generator[None, None, None]: """A context manager that resets the global random state on exit to what it was before entering. It supports isolating the states for PyTorch, Numpy, and Python built-in random number generators. @@ -41,6 +41,6 @@ def isolate_rng(include_cuda: bool = True) -> Generator[None, None, None]: tensor([0.7576]) """ - states = _collect_rng_states(include_cuda) + states = _collect_rng_states(include_cuda, include_xpu) yield _set_rng_states(states) diff --git a/src/pytorch_lightning/README.md b/src/pytorch_lightning/README.md index a9200baa273dd..9eb7a90075966 100644 --- a/src/pytorch_lightning/README.md +++ b/src/pytorch_lightning/README.md @@ -62,7 +62,7 @@ Lightning forces the following structure to your code which makes it reusable an - Non-essential research code (logging, etc... this goes in Callbacks). - Data (use PyTorch DataLoaders or organize them into a LightningDataModule). -Once you do this, you can train on multiple-GPUs, TPUs, CPUs, HPUs and even in 16-bit precision without changing your code! +Once you do this, you can train on multiple-GPUs, TPUs, CPUs, HPUs, XPUs and even in 16-bit precision without changing your code! [Get started in just 15 minutes](https://lightning.ai/docs/pytorch/latest/starter/introduction.html)
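Taken together, the Trainer-side wiring above makes Intel GPUs selectable like any other accelerator. A minimal end-to-end sketch, assuming ``lightning-xpu`` is installed and using the ``BoringModel`` demo class purely as a placeholder:

.. code-block:: python

    from lightning.pytorch import Trainer
    from lightning.pytorch.demos.boring_classes import BoringModel

    # Requires the optional lightning-xpu integration; without it (or without an
    # Intel GPU), accelerator="auto" keeps the previous CPU/CUDA/MPS behaviour.
    model = BoringModel()
    trainer = Trainer(accelerator="xpu", devices=1, max_epochs=1)
    trainer.fit(model)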