diff --git a/src/lightning/pytorch/accelerators/__init__.py b/src/lightning/pytorch/accelerators/__init__.py
index 4cadee51f64c7..dc2e438d35b6b 100644
--- a/src/lightning/pytorch/accelerators/__init__.py
+++ b/src/lightning/pytorch/accelerators/__init__.py
@@ -19,6 +19,7 @@
 from lightning.pytorch.accelerators.cpu import CPUAccelerator  # noqa: F401
 from lightning.pytorch.accelerators.cuda import CUDAAccelerator  # noqa: F401
 from lightning.pytorch.accelerators.mps import MPSAccelerator  # noqa: F401
+from lightning.pytorch.accelerators.npu import NPUAccelerator  # noqa: F401
 from lightning.pytorch.accelerators.xla import XLAAccelerator  # noqa: F401

 AcceleratorRegistry = _AcceleratorRegistry()
diff --git a/src/lightning/pytorch/accelerators/accelerator.py b/src/lightning/pytorch/accelerators/accelerator.py
index 0490c2d86431c..fd881c17e58d7 100644
--- a/src/lightning/pytorch/accelerators/accelerator.py
+++ b/src/lightning/pytorch/accelerators/accelerator.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from abc import ABC
+from contextlib import nullcontext
 from typing import Any, Dict

 import lightning.pytorch as pl
@@ -45,3 +46,9 @@ def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]:

         """
         raise NotImplementedError
+
+    def get_distribute_name(self) -> str:
+        return "gloo"
+
+    def get_stream_context(self, device_id: Any) -> Any:
+        return nullcontext()
diff --git a/src/lightning/pytorch/accelerators/cuda.py b/src/lightning/pytorch/accelerators/cuda.py
index 6df3bc6b468ee..b7fc8e81e3a54 100644
--- a/src/lightning/pytorch/accelerators/cuda.py
+++ b/src/lightning/pytorch/accelerators/cuda.py
@@ -15,6 +15,7 @@
 import os
 import shutil
 import subprocess
+from contextlib import nullcontext
 from typing import Any, Dict, List, Optional, Union

 import torch
@@ -104,6 +105,14 @@ def auto_device_count() -> int:
     def is_available() -> bool:
         return num_cuda_devices() > 0

+    @override
+    def get_distribute_name(self) -> str:
+        return "nccl"
+
+    @override
+    def get_stream_context(self, device_id: Optional[List[int]]) -> Any:
+        return torch.cuda.stream(torch.cuda.Stream()) if device_id is not None else nullcontext()
+
     @classmethod
     @override
     def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
diff --git a/src/lightning/pytorch/accelerators/npu.py b/src/lightning/pytorch/accelerators/npu.py
new file mode 100644
index 0000000000000..52fc5eba9c971
--- /dev/null
+++ b/src/lightning/pytorch/accelerators/npu.py
@@ -0,0 +1,112 @@
+# Copyright The Lightning AI team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from contextlib import nullcontext
+from typing import Any, Dict, List, Optional, Union
+
+import torch
+from typing_extensions import override
+
+from lightning.fabric.accelerators import _AcceleratorRegistry
+from lightning.fabric.utilities.types import _DEVICE
+from lightning.pytorch.accelerators.accelerator import Accelerator
+from lightning.pytorch.utilities.exceptions import MisconfigurationException
+
+
+class NPUAccelerator(Accelerator):
+    """Accelerator for Ascend NPU devices."""
+
+    @override
+    def setup_device(self, device: torch.device) -> None:
+        """
+        Raises:
+            MisconfigurationException:
+                If the selected device is not NPU.
+        """
+        if device.type != "npu":
+            raise MisconfigurationException(f"Device should be NPU, got {device} instead.")
+        torch.npu.set_device(device)
+
+    @override
+    def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]:
+        return torch.npu.memory_stats(device)
+
+    @override
+    def teardown(self) -> None:
+        torch.npu.empty_cache()
+
+    @staticmethod
+    @override
+    def parse_devices(devices: Union[int, str, List[int]]) -> Optional[List[int]]:
+        """Accelerator device parsing logic.
+
+        -1 or '-1' means use all NPUs.
+
+        """
+
+        if isinstance(devices, list):
+            return devices
+        if isinstance(devices, str):
+            if devices == "-1":
+                return list(range(torch.npu.device_count()))
+            if "," in devices:
+                return [int(x.strip()) for x in devices.split(",") if len(x) > 0]
+            return list(range(int(devices.strip())))
+        if isinstance(devices, int):
+            if devices == -1:
+                return list(range(torch.npu.device_count()))
+            return list(range(devices))
+
+        return None
+
+    @staticmethod
+    @override
+    def get_parallel_devices(devices: List[int]) -> List[torch.device]:
+        """Gets parallel devices for the Accelerator."""
+
+        return [torch.device("npu", i) for i in devices]
+
+    @staticmethod
+    @override
+    def auto_device_count() -> int:
+        """Get the devices when set to auto."""
+
+        return torch.npu.device_count()
+
+    @staticmethod
+    @override
+    def is_available() -> bool:
+        try:
+            import torch_npu  # noqa: F401
+
+            return torch.npu.device_count() > 0
+        except ImportError:
+            # `torch_npu` is required for NPU support; NPUs are unavailable if it cannot be imported.
+            return False
+
+    @override
+    def get_distribute_name(self) -> str:
+        return "hccl"
+
+    @override
+    def get_stream_context(self, device_id: Optional[List[int]]) -> Any:
+        return torch.npu.stream(torch.npu.Stream()) if device_id is not None else nullcontext()
+
+    @classmethod
+    @override
+    def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
+        accelerator_registry.register(
+            "npu",
+            cls,
+            description=cls.__name__,
+        )
diff --git a/src/lightning/pytorch/strategies/ddp.py b/src/lightning/pytorch/strategies/ddp.py
index 506e5b1c89283..c1ca20a32414a 100644
--- a/src/lightning/pytorch/strategies/ddp.py
+++ b/src/lightning/pytorch/strategies/ddp.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from contextlib import nullcontext
 from datetime import timedelta
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Union

@@ -31,7 +30,6 @@
 from lightning.fabric.strategies import _StrategyRegistry
 from lightning.fabric.utilities.distributed import (
     _distributed_is_initialized,
-    _get_default_process_group_backend_for_device,
     _init_dist_connection,
     _sync_ddp_if_available,
 )
@@ -193,7 +191,8 @@ def _setup_model(self, model: Module) -> DistributedDataParallel:
         device_ids = self.determine_ddp_device_ids()
         log.debug(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}")
         # https://pytorch.org/docs/stable/notes/cuda.html#id5
-        ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext()
+        assert self.accelerator is not None
+        ctx = self.accelerator.get_stream_context(device_ids)
         with ctx:
             return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs)

@@ -206,7 +205,8 @@ def setup_distributed(self) -> None:
         _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)

     def _get_process_group_backend(self) -> str:
-        return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)
+        assert self.accelerator is not None
+        return self._process_group_backend or self.accelerator.get_distribute_name()

     def set_world_ranks(self) -> None:
         if self.cluster_environment is not None:
diff --git a/src/lightning/pytorch/strategies/deepspeed.py b/src/lightning/pytorch/strategies/deepspeed.py
index 4e7a3bb122a55..3900d9461a499 100644
--- a/src/lightning/pytorch/strategies/deepspeed.py
+++ b/src/lightning/pytorch/strategies/deepspeed.py
@@ -39,6 +39,7 @@
 from lightning.fabric.utilities.seed import reset_seed
 from lightning.fabric.utilities.types import _PATH, LRScheduler, ReduceLROnPlateau
 from lightning.pytorch.accelerators.cuda import CUDAAccelerator
+from lightning.pytorch.accelerators.npu import NPUAccelerator
 from lightning.pytorch.core.optimizer import _init_optimizers_and_lr_schedulers
 from lightning.pytorch.plugins.precision import Precision
 from lightning.pytorch.strategies.ddp import DDPStrategy
@@ -315,10 +316,10 @@ def __init__(

     @override
     def setup_environment(self) -> None:
-        if not isinstance(self.accelerator, CUDAAccelerator):
+        if not isinstance(self.accelerator, (CUDAAccelerator, NPUAccelerator)):
             raise RuntimeError(
-                f"The DeepSpeed strategy is only supported on CUDA GPUs but `{self.accelerator.__class__.__name__}`"
-                " is used."
+                "The DeepSpeed strategy is only supported on CUDA GPUs or Ascend NPUs but"
+                f" `{self.accelerator.__class__.__name__}` is used."
             )
         super().setup_environment()

diff --git a/src/lightning/pytorch/strategies/fsdp.py b/src/lightning/pytorch/strategies/fsdp.py
index b339838920d19..e577faf4ebf96 100644
--- a/src/lightning/pytorch/strategies/fsdp.py
+++ b/src/lightning/pytorch/strategies/fsdp.py
@@ -47,7 +47,6 @@
 )
 from lightning.fabric.utilities.distributed import (
     _distributed_is_initialized,
-    _get_default_process_group_backend_for_device,
     _init_dist_connection,
     _sync_ddp_if_available,
 )
@@ -261,7 +260,8 @@ def setup_environment(self) -> None:
         _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)

     def _get_process_group_backend(self) -> str:
-        return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)
+        assert self.accelerator is not None
+        return self._process_group_backend or self.accelerator.get_distribute_name()

     def set_world_ranks(self) -> None:
         if self.cluster_environment is not None:
diff --git a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py
index a3057351ca348..318034066a4c1 100644
--- a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py
+++ b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py
@@ -34,6 +34,7 @@
 from lightning.pytorch.accelerators.accelerator import Accelerator
 from lightning.pytorch.accelerators.cuda import CUDAAccelerator
 from lightning.pytorch.accelerators.mps import MPSAccelerator
+from lightning.pytorch.accelerators.npu import NPUAccelerator
 from lightning.pytorch.accelerators.xla import XLAAccelerator
 from lightning.pytorch.plugins import (
     _PLUGIN_INPUT,
@@ -355,6 +356,8 @@ def _choose_auto_accelerator(self) -> str:
             return "mps"
         if CUDAAccelerator.is_available():
             return "cuda"
+        if NPUAccelerator.is_available():
+            return "npu"
         return "cpu"

     @staticmethod
@@ -462,7 +465,7 @@ def _choose_strategy(self) -> Union[Strategy, str]:
             return "ddp"
         if len(self._parallel_devices) <= 1:
             if isinstance(self._accelerator_flag, (CUDAAccelerator, MPSAccelerator)) or (
-                isinstance(self._accelerator_flag, str) and self._accelerator_flag in ("cuda", "gpu", "mps")
+                isinstance(self._accelerator_flag, str) and self._accelerator_flag in ("cuda", "gpu", "mps", "npu")
             ):
                 device = _determine_root_gpu_device(self._parallel_devices)
             else:
@@ -482,9 +485,9 @@ def _check_strategy_and_fallback(self) -> None:

         if (
             strategy_flag in FSDPStrategy.get_registered_strategies() or isinstance(self._strategy_flag, FSDPStrategy)
-        ) and self._accelerator_flag not in ("cuda", "gpu"):
+        ) and self._accelerator_flag not in ("cuda", "gpu", "npu"):
             raise MisconfigurationException(
-                f"You selected strategy to be `{FSDPStrategy.strategy_name}`, but GPU accelerator is not used."
+                f"You selected strategy to be `{FSDPStrategy.strategy_name}`, but no GPU or NPU accelerator is used."
             )
         if strategy_flag in _DDP_FORK_ALIASES and "fork" not in torch.multiprocessing.get_all_start_methods():
             raise ValueError(
diff --git a/src/lightning/pytorch/trainer/setup.py b/src/lightning/pytorch/trainer/setup.py
index 2dd5af675a383..d61dedfd72d4d 100644
--- a/src/lightning/pytorch/trainer/setup.py
+++ b/src/lightning/pytorch/trainer/setup.py
@@ -17,7 +17,7 @@

 import lightning.pytorch as pl
 from lightning.fabric.utilities.warnings import PossibleUserWarning
-from lightning.pytorch.accelerators import CUDAAccelerator, MPSAccelerator, XLAAccelerator
+from lightning.pytorch.accelerators import CUDAAccelerator, MPSAccelerator, NPUAccelerator, XLAAccelerator
 from lightning.pytorch.loggers.logger import DummyLogger
 from lightning.pytorch.profilers import (
     AdvancedProfiler,
@@ -178,6 +178,9 @@ def _log_device_info(trainer: "pl.Trainer") -> None:
         hpu_available = False
     rank_zero_info(f"HPU available: {hpu_available}, using: {num_hpus} HPUs")

+    number_npu_cores = trainer.num_devices if isinstance(trainer.accelerator, NPUAccelerator) else 0
+    rank_zero_info(f"NPU available: {NPUAccelerator.is_available()}, using: {number_npu_cores} NPU cores")
+
     if (
         CUDAAccelerator.is_available()
         and not isinstance(trainer.accelerator, CUDAAccelerator)
@@ -203,3 +206,6 @@ def _log_device_info(trainer: "pl.Trainer") -> None:

     if HPUAccelerator.is_available() and not isinstance(trainer.accelerator, HPUAccelerator):
         rank_zero_warn("HPU available but not used. You can set it by doing `Trainer(accelerator='hpu')`.")
+
+    if NPUAccelerator.is_available() and not isinstance(trainer.accelerator, NPUAccelerator):
+        rank_zero_warn("NPU available but not used. You can set it by doing `Trainer(accelerator='npu')`.")
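
For context, a minimal usage sketch of the accelerator added above. It is not part of the patch: it assumes the patched `lightning` build is importable and that `torch_npu` is installed so `NPUAccelerator.is_available()` returns True. The `accelerator="npu"` flag and the `hccl` backend selection come from this diff; the model and data are purely illustrative.

```python
# Hypothetical usage sketch (not part of the patch): exercises the "npu" registry
# entry added above. Requires the patched lightning build plus `torch_npu`.
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

import lightning.pytorch as pl


class TinyModel(pl.LightningModule):
    def __init__(self) -> None:
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        # Batches are moved to the NPU root device by the Trainer/strategy.
        return nn.functional.cross_entropy(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    dataset = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))
    loader = DataLoader(dataset, batch_size=8)
    # "npu" resolves through NPUAccelerator.register_accelerators(); with strategy="ddp",
    # _get_process_group_backend() picks the backend from get_distribute_name() -> "hccl".
    trainer = pl.Trainer(accelerator="npu", devices=2, strategy="ddp", max_epochs=1)
    trainer.fit(TinyModel(), loader)
```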