# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import lru_cache
from typing import Optional, Union

import torch
from typing_extensions import override

from lightning.fabric.accelerators.accelerator import Accelerator
from lightning.fabric.accelerators.registry import _AcceleratorRegistry
from lightning.fabric.utilities.rank_zero import rank_zero_info


class MUSAAccelerator(Accelerator):
    """Accelerator for MUSA devices."""

    @override
    def setup_device(self, device: torch.device) -> None:
        """
        Raises:
            ValueError:
                If the selected device is not of type MUSA.
        """
        if device.type != "musa":
            raise ValueError(f"Device should be MUSA, got {device} instead.")
        _check_musa_matmul_precision(device)
        torch.musa.set_device(device)

    @override
    def teardown(self) -> None:
        _clear_musa_memory()

    @staticmethod
    @override
    def parse_devices(devices: Union[int, str, list[int]]) -> Optional[list[int]]:
        """Accelerator device parsing logic."""
        from lightning.fabric.utilities.device_parser import _parse_gpu_ids

        return _parse_gpu_ids(devices, include_musa=True)

    @staticmethod
    @override
    def get_parallel_devices(devices: list[int]) -> list[torch.device]:
        """Gets parallel devices for the Accelerator."""
        return [torch.device("musa", i) for i in devices]

    @staticmethod
    @override
    def auto_device_count() -> int:
        """Get the device count when set to auto."""
        return num_musa_devices()

    @staticmethod
    @override
    def is_available() -> bool:
        return num_musa_devices() > 0

    @staticmethod
    @override
    def name() -> str:
        return "musa"

    @classmethod
    @override
    def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
        accelerator_registry.register(
            cls.name(),
            cls,
            description=cls.__name__,
        )

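# A minimal usage sketch (assumes the `torch_musa` extension is installed so that
# `torch.musa` and "musa" devices exist); the accelerator is looked up by the
# name "musa" once `register_accelerators` has run:
#
#     accelerator = MUSAAccelerator()
#     if MUSAAccelerator.is_available():
#         device = MUSAAccelerator.get_parallel_devices([0])[0]  # musa:0
#         accelerator.setup_device(device)
#         ...
#         accelerator.teardown()
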
def find_usable_musa_devices(num_devices: int = -1) -> list[int]:
    """Returns a list of all available and usable MUSA GPU devices.

    A GPU is considered usable if we can successfully move a tensor to it. This function tests each GPU on the
    system until the target number of usable devices is found.

    A subset of GPUs on the system might be used by other processes, and if the GPU is configured to operate in
    'exclusive' mode (configurable by the admin), then only one process is allowed to occupy it.

    Args:
        num_devices: The number of devices you want to request. By default, this function will return as many as there
            are usable MUSA GPU devices available.

    Warning:
        If multiple processes call this function at the same time, there can be race conditions in the case where
        both processes determine that the device is unoccupied, leading to one of them crashing later on.

    """
    if num_devices == 0:
        return []
    visible_devices = _get_all_visible_musa_devices()
    if not visible_devices:
        raise ValueError(
            f"You requested to find {num_devices} devices but there are no visible MUSA devices on this machine."
        )
    if num_devices > len(visible_devices):
        raise ValueError(
            f"You requested to find {num_devices} devices but this machine only has {len(visible_devices)} GPUs."
        )

    available_devices = []
    unavailable_devices = []

    for gpu_idx in visible_devices:
        try:
            torch.tensor(0, device=torch.device("musa", gpu_idx))
        except RuntimeError:
            unavailable_devices.append(gpu_idx)
            continue

        available_devices.append(gpu_idx)
        if len(available_devices) == num_devices:
            # exit early if we found the right number of GPUs
            break

    if num_devices != -1 and len(available_devices) != num_devices:
        raise RuntimeError(
            f"You requested to find {num_devices} devices but only {len(available_devices)} are currently available."
            f" The devices {unavailable_devices} are occupied by other processes and can't be used at the moment."
        )
    return available_devices

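# A minimal usage sketch (device indices are hypothetical): request two free GPUs
# and hand them to Fabric; `find_usable_musa_devices` raises a `RuntimeError` if
# fewer than the requested number are free:
#
#     from lightning.fabric import Fabric
#
#     devices = find_usable_musa_devices(2)  # e.g. [0, 2] if GPU 1 is occupied
#     fabric = Fabric(accelerator="musa", devices=devices)
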
def _get_all_visible_musa_devices() -> list[int]:
    """Returns a list of all visible MUSA GPU devices.

    Devices masked by the environment variable ``MUSA_VISIBLE_DEVICES`` won't be returned here. For example, assume
    you have 8 physical GPUs. If ``MUSA_VISIBLE_DEVICES="1,3,6"``, then this function will return the list
    ``[0, 1, 2]`` because these are the three visible GPUs after applying the mask ``MUSA_VISIBLE_DEVICES``.

    """
    return list(range(num_musa_devices()))

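# An illustrative sketch (assumes the MUSA runtime applies `MUSA_VISIBLE_DEVICES`
# the same way CUDA applies `CUDA_VISIBLE_DEVICES`): with 8 physical GPUs and the
# mask set before the process starts, visible devices are renumbered from zero:
#
#     # MUSA_VISIBLE_DEVICES="1,3,6" python train.py
#     _get_all_visible_musa_devices()  # -> [0, 1, 2]
#     torch.device("musa", 1)          # refers to physical GPU 3
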
def num_musa_devices() -> int:
    """Returns the number of available MUSA devices."""
    return torch.musa.device_count()


def is_musa_available() -> bool:
    """Returns a bool indicating if MUSA is currently available."""
    return torch.musa.is_available()

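# A minimal sketch: guarding MUSA-specific code paths with the helpers above:
#
#     if is_musa_available():
#         print(f"Found {num_musa_devices()} MUSA device(s)")
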
def _is_ampere_or_later(device: Optional[torch.device] = None) -> bool:
    major, _ = torch.musa.get_device_capability(device)
    return major >= 8  # Ampere and later leverage tensor cores, where this setting becomes useful

@lru_cache(1)  # show the warning only once
def _check_musa_matmul_precision(device: torch.device) -> None:
    if not torch.musa.is_available() or not _is_ampere_or_later(device):
        return
    # check that the user hasn't changed the precision already, this works for both `allow_tf32 = True` and
    # `set_float32_matmul_precision`
    if torch.get_float32_matmul_precision() == "highest":  # default
        rank_zero_info(
            f"You are using a MUSA device ({torch.musa.get_device_name(device)!r}) that has Tensor Cores. To properly"
            " utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade"
            " off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/"
            "torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision"
        )
    # note: no need to change `torch.backends.cudnn.allow_tf32` as it's enabled by default:
    # https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices

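# A minimal sketch: acting on the hint emitted above by opting into the faster
# matmul modes ("high" or "medium") before training starts, which silences the
# message since the precision is no longer "highest":
#
#     import torch
#
#     torch.set_float32_matmul_precision("high")
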
def _clear_musa_memory() -> None:
    # strangely, this attribute can be undefined when `torch.compile` is used
    if hasattr(torch._C, "_musa_clearMublasWorkspaces"):
        # https://github.com/pytorch/pytorch/issues/95668
        torch._C._musa_clearMublasWorkspaces()
    torch.musa.empty_cache()