Commit 17b25ad

device type helper

1 parent c7e59dd

3 files changed: 48 additions, 3 deletions

src/accelerate/launchers.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -24,6 +24,7 @@
     PrepareForLaunch,
     are_libraries_initialized,
     check_cuda_p2p_ib_support,
+    get_current_device_type,
     get_gpu_info,
     is_mps_available,
     is_torch_version,
@@ -203,8 +204,8 @@ def train(*args):
         # process here (the other ones will be set be the launcher).
         with patch_environment(**patched_env):
             # First dummy launch
-            device_type = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
-            distributed_type = "MULTI_XPU" if device_type == "xpu" else "MULTI_GPU"
+            # Determine device type without initializing any device (which would break fork)
+            device_type, distributed_type = get_current_device_type()
             if os.environ.get("ACCELERATE_DEBUG_MODE", "false").lower() == "true":
                 launcher = PrepareForLaunch(test_launch, distributed_type=distributed_type)
             try:
```
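The two deleted lines show why the helper exists: `torch.accelerator.current_accelerator()` can initialize the accelerator in the parent process, which breaks fork-based launching, as the new comment notes. A minimal repro of that failure mode on a CUDA machine (illustrative sketch, not part of this commit):

```python
# Illustrative sketch, not from this commit: a CUDA context created in the
# parent process is unusable in fork()ed children.
import torch
import torch.multiprocessing as mp


def worker(rank):
    # Raises "Cannot re-initialize CUDA in forked subprocess" because the
    # parent already holds a CUDA context.
    print(rank, torch.ones(1, device="cuda"))


if __name__ == "__main__":
    torch.ones(1, device="cuda")  # initializes CUDA in the parent process
    mp.start_processes(worker, nprocs=2, start_method="fork")
```

The replacement calls `get_current_device_type()`, which relies only on availability checks and so leaves the device uninitialized until the workers have forked.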

src/accelerate/utils/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -75,6 +75,7 @@
     clear_environment,
     convert_dict_to_env_variables,
     get_cpu_distributed_information,
+    get_current_device_type,
     get_gpu_info,
     get_int_from_env,
     parse_choice_from_env,
```

src/accelerate/utils/environment.py

Lines changed: 44 additions & 1 deletion
````diff
@@ -98,6 +98,49 @@ def are_libraries_initialized(*library_names: str) -> list[str]:
     return [lib_name for lib_name in library_names if lib_name in sys.modules.keys()]
 
 
+def get_current_device_type() -> tuple[str, str]:
+    """
+    Determines the current device type and distributed type without initializing any device.
+
+    This is particularly important when using fork-based multiprocessing, as device initialization
+    before forking can cause errors.
+
+    The device detection order follows the same priority as state.py:_prepare_backend():
+    MLU -> SDAA -> MUSA -> NPU -> HPU -> CUDA -> XPU
+
+    Returns:
+        tuple[str, str]: A tuple of (device_type, distributed_type)
+            - device_type: The device string (e.g., "cuda", "npu", "xpu")
+            - distributed_type: The distributed type string (e.g., "MULTI_GPU", "MULTI_NPU")
+
+    Example:
+    ```python
+    >>> device_type, distributed_type = get_current_device_type()
+    >>> print(device_type)  # "cuda"
+    >>> print(distributed_type)  # "MULTI_GPU"
+    ```
+    """
+    from .imports import is_hpu_available, is_mlu_available, is_musa_available, is_npu_available, is_sdaa_available, is_xpu_available
+
+    if is_mlu_available():
+        return "mlu", "MULTI_MLU"
+    elif is_sdaa_available():
+        return "sdaa", "MULTI_SDAA"
+    elif is_musa_available():
+        return "musa", "MULTI_MUSA"
+    elif is_npu_available():
+        return "npu", "MULTI_NPU"
+    elif is_hpu_available():
+        return "hpu", "MULTI_HPU"
+    elif torch.cuda.is_available():
+        return "cuda", "MULTI_GPU"
+    elif is_xpu_available():
+        return "xpu", "MULTI_XPU"
+    else:
+        # Default to CUDA even if not available (for CPU-only scenarios where CUDA code paths are still used)
+        return "cuda", "MULTI_GPU"
+
+
 def _nvidia_smi():
     """
     Returns the right nvidia-smi command based on the system.
````
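Since `get_current_device_type` is re-exported from `accelerate.utils` (the `__init__.py` hunk above), callers can use it directly; a quick sketch of the expected behavior:

```python
from accelerate.utils import get_current_device_type

# Walks the MLU -> SDAA -> MUSA -> NPU -> HPU -> CUDA -> XPU priority chain;
# on a CPU-only machine it falls back to ("cuda", "MULTI_GPU").
device_type, distributed_type = get_current_device_type()
print(device_type, distributed_type)  # "cuda MULTI_GPU" when CUDA is available
```

The file's second hunk, below, updates the install hint raised by the NUMA-affinity helper.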
```diff
@@ -248,7 +291,7 @@ def override_numa_affinity(local_process_index: int, verbose: Optional[bool] = None):
 
     if not is_pynvml_available():
         raise ImportError(
-            "To set CPU affinity on CUDA GPUs the `pynvml` package must be available. (`pip install pynvml`)"
+            "To set CPU affinity on CUDA GPUs the `nvidia-ml-py` package must be available. (`pip install nvidia-ml-py`)"
         )
     import pynvml as nvml
 
```
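The message change tracks the PyPI packaging situation: `nvidia-ml-py` is NVIDIA's maintained distribution of the NVML bindings, and it still installs the `pynvml` module, so the `import pynvml as nvml` on the following line keeps working. A small sanity check (assumed environment, not part of this commit):

```python
# After `pip install nvidia-ml-py`, the import name is still `pynvml`.
import pynvml as nvml

nvml.nvmlInit()
print("GPUs visible to NVML:", nvml.nvmlDeviceGetCount())
nvml.nvmlShutdown()
```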

0 commit comments