Commit 99662c5

support Musa device
1 parent 25c9922 commit 99662c5
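
Below is a minimal usage sketch (not part of the commit) showing how the new accelerator could be selected once it is registered under the name "musa". It assumes a PyTorch build that exposes the `torch.musa` backend and at least one MUSA device.

import torch
from lightning.fabric import Fabric

# "musa" is the name MUSAAccelerator registers itself under; "auto" and "gpu" also
# resolve to it when neither MPS nor CUDA is available (see connector.py below).
fabric = Fabric(accelerator="musa", devices=1)
fabric.launch()

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = fabric.setup(model, optimizer)

x = torch.randn(8, 4, device=fabric.device)  # tensors are placed on a `musa` device
loss = model(x).sum()
fabric.backward(loss)
optimizer.step()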

13 files changed: +355, -20 lines

docs/source-fabric/api/accelerators.rst

Lines changed: 1 addition & 0 deletions

@@ -20,3 +20,4 @@ Accelerators
     CUDAAccelerator
     MPSAccelerator
     XLAAccelerator
+    MUSAAccelerator

docs/source-pytorch/api_references.rst

Lines changed: 1 addition & 0 deletions

@@ -14,6 +14,7 @@ accelerators
     CPUAccelerator
     CUDAAccelerator
     XLAAccelerator
+    MUSAAccelerator

 callbacks
 ---------

docs/source-pytorch/extensions/accelerator.rst

Lines changed: 1 addition & 0 deletions

@@ -128,3 +128,4 @@ Accelerator API
     CUDAAccelerator
     MPSAccelerator
     XLAAccelerator
+    MUSAAccelerator

src/lightning/fabric/accelerators/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -18,6 +18,7 @@
 from lightning.fabric.accelerators.mps import MPSAccelerator  # noqa: F401
 from lightning.fabric.accelerators.registry import _AcceleratorRegistry
 from lightning.fabric.accelerators.xla import XLAAccelerator  # noqa: F401
+from lightning.fabric.accelerators.musa import MUSAAccelerator  # noqa: F401
 from lightning.fabric.utilities.registry import _register_classes

 ACCELERATOR_REGISTRY = _AcceleratorRegistry()

src/lightning/fabric/accelerators/musa.py (new file)

Lines changed: 186 additions & 0 deletions

@@ -0,0 +1,186 @@
+# Copyright The Lightning AI team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from functools import lru_cache
+from typing import Optional, Union
+
+import torch
+from typing_extensions import override
+
+from lightning.fabric.accelerators.accelerator import Accelerator
+from lightning.fabric.accelerators.registry import _AcceleratorRegistry
+from lightning.fabric.utilities.rank_zero import rank_zero_info
+
+
+class MUSAAccelerator(Accelerator):
+    """Accelerator for MUSA devices."""
+
+    @override
+    def setup_device(self, device: torch.device) -> None:
+        """
+        Raises:
+            ValueError:
+                If the selected device is not of type MUSA.
+        """
+        if device.type != "musa":
+            raise ValueError(f"Device should be MUSA, got {device} instead.")
+        _check_musa_matmul_precision(device)
+        torch.musa.set_device(device)
+
+    @override
+    def teardown(self) -> None:
+        _clear_musa_memory()
+
+    @staticmethod
+    @override
+    def parse_devices(devices: Union[int, str, list[int]]) -> Optional[list[int]]:
+        """Accelerator device parsing logic."""
+        from lightning.fabric.utilities.device_parser import _parse_gpu_ids
+
+        return _parse_gpu_ids(devices, include_musa=True)
+
+    @staticmethod
+    @override
+    def get_parallel_devices(devices: list[int]) -> list[torch.device]:
+        """Gets parallel devices for the Accelerator."""
+        return [torch.device("musa", i) for i in devices]
+
+    @staticmethod
+    @override
+    def auto_device_count() -> int:
+        """Get the devices when set to auto."""
+        return num_musa_devices()
+
+    @staticmethod
+    @override
+    def is_available() -> bool:
+        return num_musa_devices() > 0
+
+    @staticmethod
+    @override
+    def name() -> str:
+        return "musa"
+
+    @classmethod
+    @override
+    def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
+        accelerator_registry.register(
+            cls.name(),
+            cls,
+            description=cls.__name__,
+        )
+
+
+def find_usable_musa_devices(num_devices: int = -1) -> list[int]:
+    """Returns a list of all available and usable MUSA GPU devices.
+
+    A GPU is considered usable if we can successfully move a tensor to the device, and this is what this function
+    tests for each GPU on the system until the target number of usable devices is found.
+
+    A subset of GPUs on the system might be used by other processes, and if the GPU is configured to operate in
+    'exclusive' mode (configurable by the admin), then only one process is allowed to occupy it.
+
+    Args:
+        num_devices: The number of devices you want to request. By default, this function will return as many as there
+            are usable MUSA GPU devices available.
+
+    Warning:
+        If multiple processes call this function at the same time, there can be race conditions in the case where
+        both processes determine that the device is unoccupied, leading to one of them crashing later on.
+
+    """
+    if num_devices == 0:
+        return []
+    visible_devices = _get_all_visible_musa_devices()
+    if not visible_devices:
+        raise ValueError(
+            f"You requested to find {num_devices} devices but there are no visible MUSA devices on this machine."
+        )
+    if num_devices > len(visible_devices):
+        raise ValueError(
+            f"You requested to find {num_devices} devices but this machine only has {len(visible_devices)} GPUs."
+        )
+
+    available_devices = []
+    unavailable_devices = []
+
+    for gpu_idx in visible_devices:
+        try:
+            torch.tensor(0, device=torch.device("musa", gpu_idx))
+        except RuntimeError:
+            unavailable_devices.append(gpu_idx)
+            continue
+
+        available_devices.append(gpu_idx)
+        if len(available_devices) == num_devices:
+            # exit early if we found the right number of GPUs
+            break
+
+    if num_devices != -1 and len(available_devices) != num_devices:
+        raise RuntimeError(
+            f"You requested to find {num_devices} devices but only {len(available_devices)} are currently available."
+            f" The devices {unavailable_devices} are occupied by other processes and can't be used at the moment."
+        )
+    return available_devices
+
+
+def _get_all_visible_musa_devices() -> list[int]:
+    """Returns a list of all visible MUSA GPU devices.
+
+    Devices masked by the environment variable ``MUSA_VISIBLE_DEVICES`` won't be returned here. For example, assume you
+    have 8 physical GPUs. If ``MUSA_VISIBLE_DEVICES="1,3,6"``, then this function will return the list ``[0, 1, 2]``
+    because these are the three visible GPUs after applying the mask ``MUSA_VISIBLE_DEVICES``.
+
+    """
+    return list(range(num_musa_devices()))
+
+
+def num_musa_devices() -> int:
+    """Returns the number of available MUSA devices."""
+    return torch.musa.device_count()
+
+
+def is_musa_available() -> bool:
+    """Returns a bool indicating if MUSA is currently available."""
+    return torch.musa.is_available()
+
+
+def _is_ampere_or_later(device: Optional[torch.device] = None) -> bool:
+    major, _ = torch.musa.get_device_capability(device)
+    return major >= 8  # Ampere and later leverage tensor cores, where this setting becomes useful
+
+
+@lru_cache(1)  # show the warning only ever once
+def _check_musa_matmul_precision(device: torch.device) -> None:
+    if not torch.musa.is_available() or not _is_ampere_or_later(device):
+        return
+    # check that the user hasn't changed the precision already, this works for both `allow_tf32 = True` and
+    # `set_float32_matmul_precision`
+    if torch.get_float32_matmul_precision() == "highest":  # default
+        rank_zero_info(
+            f"You are using a MUSA device ({torch.musa.get_device_name(device)!r}) that has Tensor Cores. To properly"
+            " utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off"
+            " precision for performance. For more details, read https://pytorch.org/docs/stable/generated/"
+            "torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision"
+        )
+    # note: there is no need to change `torch.backends.cudnn.allow_tf32` as it's enabled by default:
+    # https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+
+
+def _clear_musa_memory() -> None:
+    # strangely, the attribute can be undefined when torch.compile is used
+    if hasattr(torch._C, "_musa_clearMublasWorkspaces"):
+        # https://github.com/pytorch/pytorch/issues/95668
+        torch._C._musa_clearMublasWorkspaces()
+    torch.musa.empty_cache()
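
For illustration, a short sketch (not part of the diff) of how the helpers defined above could be used; it assumes a `torch.musa`-enabled build with at least one usable MUSA device.

from lightning.fabric.accelerators.musa import (
    MUSAAccelerator,
    find_usable_musa_devices,
    is_musa_available,
    num_musa_devices,
)

if is_musa_available():
    print(f"visible MUSA devices: {num_musa_devices()}")
    # find_usable_musa_devices probes each device by allocating a small tensor on it (see above)
    usable = find_usable_musa_devices(num_devices=1)
    devices = MUSAAccelerator.get_parallel_devices(usable)
    print(devices)  # e.g. [device(type='musa', index=0)]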

src/lightning/fabric/cli.py

Lines changed: 4 additions & 2 deletions

@@ -21,7 +21,7 @@
 from lightning_utilities.core.imports import RequirementCache
 from typing_extensions import get_args

-from lightning.fabric.accelerators import CPUAccelerator, CUDAAccelerator, MPSAccelerator
+from lightning.fabric.accelerators import CPUAccelerator, CUDAAccelerator, MPSAccelerator, MUSAAccelerator
 from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT_STR, _PRECISION_INPUT_STR_ALIAS
 from lightning.fabric.strategies import STRATEGY_REGISTRY
 from lightning.fabric.utilities.consolidate_checkpoint import _process_cli_args
@@ -196,9 +196,11 @@ def _get_num_processes(accelerator: str, devices: str) -> int:
     else:
         raise ValueError(f"Cannot default to '1' device for accelerator='{accelerator}'")
     if accelerator == "gpu":
-        parsed_devices = _parse_gpu_ids(devices, include_cuda=True, include_mps=True)
+        parsed_devices = _parse_gpu_ids(devices, include_cuda=True, include_mps=True, include_musa=True)
     elif accelerator == "cuda":
         parsed_devices = CUDAAccelerator.parse_devices(devices)
+    elif accelerator == "musa":
+        parsed_devices = MUSAAccelerator.parse_devices(devices)
     elif accelerator == "mps":
         parsed_devices = MPSAccelerator.parse_devices(devices)
     elif accelerator == "tpu":
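
As an illustration of the new branch (not part of the diff), the launcher's private process-count helper now accepts the "musa" accelerator string; the concrete result assumes a machine with at least two usable MUSA devices.

from lightning.fabric.cli import _get_num_processes

# e.g. what a launch with --accelerator=musa --devices=2 would use to decide
# how many local processes to spawn (assuming two usable MUSA devices):
print(_get_num_processes(accelerator="musa", devices="2"))  # -> 2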

src/lightning/fabric/connector.py

Lines changed: 7 additions & 2 deletions

@@ -24,6 +24,7 @@
 from lightning.fabric.accelerators.cuda import CUDAAccelerator
 from lightning.fabric.accelerators.mps import MPSAccelerator
 from lightning.fabric.accelerators.xla import XLAAccelerator
+from lightning.fabric.accelerators.musa import MUSAAccelerator
 from lightning.fabric.plugins import (
     BitsandbytesPrecision,
     CheckpointIO,
@@ -322,6 +323,8 @@ def _choose_auto_accelerator() -> str:
             return "mps"
         if CUDAAccelerator.is_available():
             return "cuda"
+        if MUSAAccelerator.is_available():
+            return "musa"
         return "cpu"

     @staticmethod
@@ -330,6 +333,8 @@ def _choose_gpu_accelerator_backend() -> str:
             return "mps"
         if CUDAAccelerator.is_available():
             return "cuda"
+        if MUSAAccelerator.is_available():
+            return "musa"
         raise RuntimeError("No supported gpu backend found!")

     def _set_parallel_devices_and_init_accelerator(self) -> None:
@@ -400,8 +405,8 @@ def _choose_strategy(self) -> Union[Strategy, str]:
         if self._num_nodes_flag > 1:
             return "ddp"
         if len(self._parallel_devices) <= 1:
-            if isinstance(self._accelerator_flag, (CUDAAccelerator, MPSAccelerator)) or (
-                isinstance(self._accelerator_flag, str) and self._accelerator_flag in ("cuda", "gpu", "mps")
+            if isinstance(self._accelerator_flag, (CUDAAccelerator, MPSAccelerator, MUSAAccelerator)) or (
+                isinstance(self._accelerator_flag, str) and self._accelerator_flag in ("cuda", "gpu", "mps", "musa")
             ):
                 device = _determine_root_gpu_device(self._parallel_devices)
             else:
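
To make the new fallback order concrete, a sketch using the private connector helpers changed above; the outputs assume a machine with no MPS or CUDA support but at least one MUSA device.

from lightning.fabric.connector import _Connector

# MPS and CUDA are checked first, so "musa" is only chosen when neither is available.
print(_Connector._choose_auto_accelerator())         # -> "musa"
print(_Connector._choose_gpu_accelerator_backend())  # -> "musa"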

src/lightning/fabric/utilities/device_parser.py

Lines changed: 18 additions & 11 deletions

@@ -50,6 +50,7 @@ def _parse_gpu_ids(
     gpus: Optional[Union[int, str, list[int]]],
     include_cuda: bool = False,
     include_mps: bool = False,
+    include_musa: bool = False,
 ) -> Optional[list[int]]:
     """Parses the GPU IDs given in the format as accepted by the :class:`~lightning.pytorch.trainer.trainer.Trainer`.

@@ -61,6 +62,7 @@ def _parse_gpu_ids(
             Any int N > 0 indicates that GPUs [0..N) should be used.
         include_cuda: A boolean value indicating whether to include CUDA devices for GPU parsing.
         include_mps: A boolean value indicating whether to include MPS devices for GPU parsing.
+        include_musa: A boolean value indicating whether to include MUSA devices for GPU parsing.

     Returns:
         A list of GPUs to be used or ``None`` if no GPUs were requested
@@ -70,7 +72,7 @@ def _parse_gpu_ids(
             If no GPUs are available but the value of gpus variable indicates request for GPUs

     .. note::
-        ``include_cuda`` and ``include_mps`` default to ``False`` so that you only
+        ``include_cuda``, ``include_mps``, and ``include_musa`` default to ``False`` so that you only
         have to specify which device type to use and all other devices are not disabled.

     """
@@ -84,23 +86,23 @@ def _parse_gpu_ids(
     # We know the user requested GPUs therefore if some of the
     # requested GPUs are not available an exception is thrown.
     gpus = _normalize_parse_gpu_string_input(gpus)
-    gpus = _normalize_parse_gpu_input_to_list(gpus, include_cuda=include_cuda, include_mps=include_mps)
+    gpus = _normalize_parse_gpu_input_to_list(gpus, include_cuda=include_cuda, include_mps=include_mps, include_musa=include_musa)
     if not gpus:
         raise MisconfigurationException("GPUs requested but none are available.")

     if (
         torch.distributed.is_available()
         and torch.distributed.is_torchelastic_launched()
         and len(gpus) != 1
-        and len(_get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps)) == 1
+        and len(_get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps, include_musa=include_musa)) == 1
     ):
         # Omit sanity check on torchelastic because by default it shows one visible GPU per process
         return gpus

     # Check that GPUs are unique. Duplicate GPUs are not supported by the backend.
     _check_unique(gpus)

-    return _sanitize_gpu_ids(gpus, include_cuda=include_cuda, include_mps=include_mps)
+    return _sanitize_gpu_ids(gpus, include_cuda=include_cuda, include_mps=include_mps, include_musa=include_musa)


 def _normalize_parse_gpu_string_input(s: Union[int, str, list[int]]) -> Union[int, list[int]]:
@@ -113,7 +115,7 @@ def _normalize_parse_gpu_string_input(s: Union[int, str, list[int]]) -> Union[in
         return int(s.strip())


-def _sanitize_gpu_ids(gpus: list[int], include_cuda: bool = False, include_mps: bool = False) -> list[int]:
+def _sanitize_gpu_ids(gpus: list[int], include_cuda: bool = False, include_mps: bool = False, include_musa: bool = False) -> list[int]:
     """Checks that each of the GPUs in the list is actually available. Raises a MisconfigurationException if any of the
     GPUs is not available.

@@ -128,9 +130,9 @@ def _sanitize_gpu_ids(gpus: list[int], include_cuda: bool = False, include_mps:
             If machine has fewer available GPUs than requested.

     """
-    if sum((include_cuda, include_mps)) == 0:
+    if sum((include_cuda, include_mps, include_musa)) == 0:
         raise ValueError("At least one gpu type should be specified!")
-    all_available_gpus = _get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps)
+    all_available_gpus = _get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps, include_musa=include_musa)
     for gpu in gpus:
         if gpu not in all_available_gpus:
             raise MisconfigurationException(
@@ -140,7 +142,7 @@ def _sanitize_gpu_ids(gpus: list[int], include_cuda: bool = False, include_mps:


 def _normalize_parse_gpu_input_to_list(
-    gpus: Union[int, list[int], tuple[int, ...]], include_cuda: bool, include_mps: bool
+    gpus: Union[int, list[int], tuple[int, ...]], include_cuda: bool, include_mps: bool, include_musa: bool
 ) -> Optional[list[int]]:
     assert gpus is not None
     if isinstance(gpus, (MutableSequence, tuple)):
@@ -150,22 +152,24 @@ def _normalize_parse_gpu_input_to_list(
     if not gpus:  # gpus==0
         return None
     if gpus == -1:
-        return _get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps)
+        return _get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps, include_musa=include_musa)

     return list(range(gpus))


-def _get_all_available_gpus(include_cuda: bool = False, include_mps: bool = False) -> list[int]:
+def _get_all_available_gpus(include_cuda: bool = False, include_mps: bool = False, include_musa: bool = False) -> list[int]:
     """
     Returns:
         A list of all available GPUs
     """
     from lightning.fabric.accelerators.cuda import _get_all_visible_cuda_devices
     from lightning.fabric.accelerators.mps import _get_all_available_mps_gpus
+    from lightning.fabric.accelerators.musa import _get_all_visible_musa_devices

     cuda_gpus = _get_all_visible_cuda_devices() if include_cuda else []
     mps_gpus = _get_all_available_mps_gpus() if include_mps else []
-    return cuda_gpus + mps_gpus
+    musa_gpus = _get_all_visible_musa_devices() if include_musa else []
+    return cuda_gpus + mps_gpus + musa_gpus


 def _check_unique(device_ids: list[int]) -> None:
@@ -211,11 +215,14 @@ def _select_auto_accelerator() -> str:
     from lightning.fabric.accelerators.cuda import CUDAAccelerator
     from lightning.fabric.accelerators.mps import MPSAccelerator
     from lightning.fabric.accelerators.xla import XLAAccelerator
+    from lightning.fabric.accelerators.musa import MUSAAccelerator

     if XLAAccelerator.is_available():
         return "tpu"
     if MPSAccelerator.is_available():
         return "mps"
     if CUDAAccelerator.is_available():
         return "cuda"
+    if MUSAAccelerator.is_available():
+        return "musa"
     return "cpu"
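
Finally, a sketch of the extended parsing path (private helpers, shown for illustration only); the return values assume a machine with four visible MUSA devices.

from lightning.fabric.utilities.device_parser import _parse_gpu_ids

print(_parse_gpu_ids(-1, include_musa=True))     # all MUSA devices, e.g. [0, 1, 2, 3]
print(_parse_gpu_ids("0,2", include_musa=True))  # a specific subset -> [0, 2]
print(_parse_gpu_ids(2, include_musa=True))      # the first two devices -> [0, 1]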
