1 change: 1 addition & 0 deletions src/lightning/pytorch/accelerators/__init__.py
@@ -19,6 +19,7 @@
from lightning.pytorch.accelerators.cpu import CPUAccelerator # noqa: F401
from lightning.pytorch.accelerators.cuda import CUDAAccelerator # noqa: F401
from lightning.pytorch.accelerators.mps import MPSAccelerator # noqa: F401
from lightning.pytorch.accelerators.npu import NPUAccelerator # noqa: F401
from lightning.pytorch.accelerators.xla import XLAAccelerator # noqa: F401

AcceleratorRegistry = _AcceleratorRegistry()
7 changes: 7 additions & 0 deletions src/lightning/pytorch/accelerators/accelerator.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC
from contextlib import nullcontext
from typing import Any, Dict

import lightning.pytorch as pl
@@ -45,3 +46,9 @@ def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]:

"""
raise NotImplementedError

def get_distribute_name(self) -> str:
return "gloo"

def get_stream_context(self, device_id: Any) -> Any:
return nullcontext()
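
For context on the two new hooks: get_distribute_name() supplies the default process-group backend and get_stream_context() the context manager used while wrapping the model for DDP. The base class answers "gloo" and a no-op context, which is what the CPU accelerator inherits. A quick sketch of those defaults, assuming this diff is applied (illustration only, not part of the change):

from lightning.pytorch.accelerators import CPUAccelerator

acc = CPUAccelerator()
print(acc.get_distribute_name())    # "gloo" -- the base-class default backend
with acc.get_stream_context(None):  # nullcontext(), i.e. a no-op stream context
    ...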
9 changes: 9 additions & 0 deletions src/lightning/pytorch/accelerators/cuda.py
@@ -15,6 +15,7 @@
import os
import shutil
import subprocess
from contextlib import nullcontext
from typing import Any, Dict, List, Optional, Union

import torch
@@ -104,6 +105,14 @@ def auto_device_count() -> int:
def is_available() -> bool:
return num_cuda_devices() > 0

@override
def get_distribute_name(self) -> str:
return "nccl"

@override
def get_stream_context(self, device_id: Optional[List[int]]) -> Any:
return torch.cuda.stream(torch.cuda.Stream()) if device_id is not None else nullcontext()

@classmethod
@override
def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
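
The CUDA override mirrors what the DDP strategy used to hard-code: NCCL as the backend and a side stream around model wrapping. A guarded sketch of the behaviour, assuming a CUDA-enabled machine (illustration only):

from lightning.pytorch.accelerators import CUDAAccelerator

if CUDAAccelerator.is_available():
    acc = CUDAAccelerator()
    print(acc.get_distribute_name())   # "nccl"
    with acc.get_stream_context([0]):  # enters a fresh torch.cuda side stream
        pass                           # DDP wrapping happens under this context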
112 changes: 112 additions & 0 deletions src/lightning/pytorch/accelerators/npu.py
@@ -0,0 +1,112 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import nullcontext
from typing import Any, Dict, List, Optional, Union

import torch
from typing_extensions import override

from lightning.fabric.accelerators import _AcceleratorRegistry
from lightning.fabric.utilities.types import _DEVICE
from lightning.pytorch.accelerators.accelerator import Accelerator
from lightning.pytorch.utilities.exceptions import MisconfigurationException


class NPUAccelerator(Accelerator):
"""Accelerator for Ascend NPU devices."""

@override
def setup_device(self, device: torch.device) -> None:
"""
Raises:
MisconfigurationException:
If the selected device is not NPU.
"""
if device.type != "npu":
raise MisconfigurationException(f"Device should be NPU, got {device} instead.")
torch.npu.set_device(device)

@override
def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]:
return torch.npu.memory_stats(device)

@override
def teardown(self) -> None:
torch.npu.empty_cache()

@staticmethod
@override
def parse_devices(devices: Union[int, str, List[int]]) -> Optional[List[int]]:
"""Accelerator device parsing logic.

-1 or '-1' means use all NPUs.

"""

if isinstance(devices, list):
return devices
if isinstance(devices, str):
if devices == "-1":
return list(range(torch.npu.device_count()))
if "," in devices:
return [int(x.strip()) for x in devices.split(",") if len(x) > 0]
return list(range(int(devices.strip())))
if isinstance(devices, int):
if devices == -1:
return list(range(torch.npu.device_count()))
return list(range(devices))

return None

@staticmethod
@override
def get_parallel_devices(devices: List[int]) -> List[torch.device]:
"""Gets parallel devices for the Accelerator."""

return [torch.device("npu", i) for i in devices]

@staticmethod
@override
def auto_device_count() -> int:
"""Get the devices when set to auto."""

return torch.npu.device_count()

@staticmethod
@override
def is_available() -> bool:
try:
import torch_npu # noqa: F401

return torch.npu.device_count() > 0
except ImportError:
# torch_npu is missing or the NPU environment is not properly configured.
return False

@override
def get_distribute_name(self) -> str:
return "hccl"

@override
def get_stream_context(self, device_id: Optional[List[int]]) -> Any:
return torch.npu.stream(torch.npu.Stream()) if device_id is not None else nullcontext()

@classmethod
@override
def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
accelerator_registry.register(
"npu",
cls,
description=cls.__name__,
)
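
A hypothetical end-to-end usage sketch for the new accelerator; it assumes torch_npu is installed and at least one Ascend device is visible, and is not part of the diff:

import lightning.pytorch as pl

# "npu" resolves to NPUAccelerator via the registry entry above.
trainer = pl.Trainer(accelerator="npu", devices=2, strategy="ddp")
# devices accepts the same forms parse_devices handles:
#   -1 or "-1"  -> all visible NPUs
#   "0,1"       -> NPUs 0 and 1
#   2           -> NPUs 0 and 1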
8 changes: 4 additions & 4 deletions src/lightning/pytorch/strategies/ddp.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from contextlib import nullcontext
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Union

@@ -31,7 +30,6 @@
from lightning.fabric.strategies import _StrategyRegistry
from lightning.fabric.utilities.distributed import (
_distributed_is_initialized,
_get_default_process_group_backend_for_device,
_init_dist_connection,
_sync_ddp_if_available,
)
@@ -193,7 +191,8 @@ def _setup_model(self, model: Module) -> DistributedDataParallel:
device_ids = self.determine_ddp_device_ids()
log.debug(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}")
# https://pytorch.org/docs/stable/notes/cuda.html#id5
ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext()
assert self.accelerator is not None
ctx = self.accelerator.get_stream_context(device_ids)
with ctx:
return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs)

@@ -206,7 +205,8 @@ def setup_distributed(self) -> None:
_init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)

def _get_process_group_backend(self) -> str:
return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)
assert self.accelerator is not None
return self._process_group_backend or self.accelerator.get_distribute_name()

def set_world_ranks(self) -> None:
if self.cluster_environment is not None:
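
The net effect of the ddp.py change is that the strategy no longer assumes CUDA: the side-stream context and the process-group backend are both delegated to the accelerator, so CPU, CUDA, and NPU each supply their own answer. The resolution order is unchanged in spirit; a condensed sketch (names are illustrative, not the actual private API):

def resolve_backend(explicit_backend, accelerator):
    # An explicit DDPStrategy(process_group_backend=...) still wins;
    # the accelerator's answer ("nccl", "hccl", "gloo") is only the fallback.
    return explicit_backend or accelerator.get_distribute_name()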
7 changes: 4 additions & 3 deletions src/lightning/pytorch/strategies/deepspeed.py
@@ -39,6 +39,7 @@
from lightning.fabric.utilities.seed import reset_seed
from lightning.fabric.utilities.types import _PATH, LRScheduler, ReduceLROnPlateau
from lightning.pytorch.accelerators.cuda import CUDAAccelerator
from lightning.pytorch.accelerators.npu import NPUAccelerator
from lightning.pytorch.core.optimizer import _init_optimizers_and_lr_schedulers
from lightning.pytorch.plugins.precision import Precision
from lightning.pytorch.strategies.ddp import DDPStrategy
@@ -315,10 +316,10 @@ def __init__(

@override
def setup_environment(self) -> None:
if not isinstance(self.accelerator, CUDAAccelerator):
if not isinstance(self.accelerator, (CUDAAccelerator, NPUAccelerator)):
raise RuntimeError(
f"The DeepSpeed strategy is only supported on CUDA GPUs but `{self.accelerator.__class__.__name__}`"
" is used."
"The DeepSpeed strategy is only supported on CUDA GPUs or Ascend NPUs but"
" `{self.accelerator.__class__.__name__}` is used."
)
super().setup_environment()

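
With the widened isinstance check, a DeepSpeed run on Ascend hardware is no longer rejected up front. A hypothetical invocation, assuming both deepspeed and torch_npu are installed (not part of the diff):

from lightning.pytorch import Trainer

trainer = Trainer(accelerator="npu", devices=8, strategy="deepspeed_stage_2")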
4 changes: 2 additions & 2 deletions src/lightning/pytorch/strategies/fsdp.py
@@ -47,7 +47,6 @@
)
from lightning.fabric.utilities.distributed import (
_distributed_is_initialized,
_get_default_process_group_backend_for_device,
_init_dist_connection,
_sync_ddp_if_available,
)
@@ -261,7 +260,8 @@ def setup_environment(self) -> None:
_init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)

def _get_process_group_backend(self) -> str:
return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)
assert self.accelerator is not None
return self._process_group_backend or self.accelerator.get_distribute_name()

def set_world_ranks(self) -> None:
if self.cluster_environment is not None:
src/lightning/pytorch/trainer/connectors/accelerator_connector.py
@@ -34,6 +34,7 @@
from lightning.pytorch.accelerators.accelerator import Accelerator
from lightning.pytorch.accelerators.cuda import CUDAAccelerator
from lightning.pytorch.accelerators.mps import MPSAccelerator
from lightning.pytorch.accelerators.npu import NPUAccelerator
from lightning.pytorch.accelerators.xla import XLAAccelerator
from lightning.pytorch.plugins import (
_PLUGIN_INPUT,
@@ -355,6 +356,8 @@ def _choose_auto_accelerator(self) -> str:
return "mps"
if CUDAAccelerator.is_available():
return "cuda"
if NPUAccelerator.is_available():
return "npu"
return "cpu"

@staticmethod
@@ -462,7 +465,7 @@ def _choose_strategy(self) -> Union[Strategy, str]:
return "ddp"
if len(self._parallel_devices) <= 1:
if isinstance(self._accelerator_flag, (CUDAAccelerator, MPSAccelerator)) or (
isinstance(self._accelerator_flag, str) and self._accelerator_flag in ("cuda", "gpu", "mps")
isinstance(self._accelerator_flag, str) and self._accelerator_flag in ("cuda", "gpu", "mps", "npu")
):
device = _determine_root_gpu_device(self._parallel_devices)
else:
@@ -482,9 +485,9 @@ def _check_strategy_and_fallback(self) -> None:

if (
strategy_flag in FSDPStrategy.get_registered_strategies() or isinstance(self._strategy_flag, FSDPStrategy)
) and self._accelerator_flag not in ("cuda", "gpu"):
) and self._accelerator_flag not in ("cuda", "gpu", "npu"):
raise MisconfigurationException(
f"You selected strategy to be `{FSDPStrategy.strategy_name}`, but GPU accelerator is not used."
f"You selected strategy to be `{FSDPStrategy.strategy_name}`, but GPU nor NPU accelerator is not used."
)
if strategy_flag in _DDP_FORK_ALIASES and "fork" not in torch.multiprocessing.get_all_start_methods():
raise ValueError(
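
With the connector change, accelerator="auto" can now land on an NPU, but only after the MPS and CUDA checks above it come up empty. A small sketch of the resulting behaviour on an Ascend-only host (assumption, not part of the diff):

from lightning.pytorch import Trainer

# On a machine with no MPS/CUDA but with torch_npu and visible Ascend devices,
# this now resolves to NPUAccelerator instead of falling back to CPU.
trainer = Trainer(accelerator="auto", devices="auto")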
8 changes: 7 additions & 1 deletion src/lightning/pytorch/trainer/setup.py
@@ -17,7 +17,7 @@

import lightning.pytorch as pl
from lightning.fabric.utilities.warnings import PossibleUserWarning
from lightning.pytorch.accelerators import CUDAAccelerator, MPSAccelerator, XLAAccelerator
from lightning.pytorch.accelerators import CUDAAccelerator, MPSAccelerator, NPUAccelerator, XLAAccelerator
from lightning.pytorch.loggers.logger import DummyLogger
from lightning.pytorch.profilers import (
AdvancedProfiler,
@@ -178,6 +178,9 @@ def _log_device_info(trainer: "pl.Trainer") -> None:
hpu_available = False
rank_zero_info(f"HPU available: {hpu_available}, using: {num_hpus} HPUs")

number_npu_cores = trainer.num_devices if isinstance(trainer.accelerator, NPUAccelerator) else 0
rank_zero_info(f"NPU available: {NPUAccelerator.is_available()}, using: {number_npu_cores} NPU cores")

if (
CUDAAccelerator.is_available()
and not isinstance(trainer.accelerator, CUDAAccelerator)
@@ -203,3 +206,6 @@

if HPUAccelerator.is_available() and not isinstance(trainer.accelerator, HPUAccelerator):
rank_zero_warn("HPU available but not used. You can set it by doing `Trainer(accelerator='hpu')`.")

if NPUAccelerator.is_available() and not isinstance(trainer.accelerator, NPUAccelerator):
rank_zero_warn("NPU available but not used. You can set it by doing `Trainer(accelerator='npu')`.")