
Commit 205cd61

feature(pu): adapt to npu
1 parent d0b21d0 commit 205cd61

File tree

4 files changed (+219, -12 lines)


ding/policy/base_policy.py

Lines changed: 33 additions & 11 deletions

@@ -9,6 +9,7 @@
 from ding.model import create_model
 from ding.utils import import_module, allreduce, allreduce_with_indicator, broadcast, get_rank, allreduce_async, \
     synchronize, deep_merge_dicts, POLICY_REGISTRY
+from ding.torch_utils import auto_device_init, move_to_device


 class Policy(ABC):

@@ -83,8 +84,12 @@ def default_config(cls: type) -> EasyDict:
     config = dict(
         # (bool) Whether the learning policy is the same as the collecting data policy (on-policy).
         on_policy=False,
-        # (bool) Whether to use cuda in policy.
+        # (bool) Whether to use cuda in policy (deprecated, use 'device' instead).
         cuda=False,
+        # (str) Device to use for policy. Can be 'auto', 'cuda', 'npu', or 'cpu'.
+        # 'auto' will automatically detect NPU > GPU > CPU.
+        # If not specified, will use 'cuda' config for backward compatibility.
+        device='auto',
         # (bool) Whether to use data parallel multi-gpu mode in policy.
         multi_gpu=False,
         # (bool) Whether to synchronize update the model parameters after allreduce the gradients of model parameters.
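
Since Policy.default_config() returns a deep copy of this class-level dict, the new key can be inspected without constructing a policy. A minimal sketch (assuming a DI-engine checkout with this commit applied):

    from ding.policy import Policy

    cfg = Policy.default_config()
    # Both knobs coexist: 'cuda' is kept for backward compatibility,
    # 'device' is the new entry point.
    print(cfg.cuda, cfg.device)  # -> False auto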
@@ -136,25 +141,42 @@ def __init__(
 
         if len(set(self._enable_field).intersection(set(['learn', 'collect', 'eval']))) > 0:
             model = self._create_model(cfg, model)
-            self._cuda = cfg.cuda and torch.cuda.is_available()
+
+            # Device initialization with auto-detection support for NPU/GPU/CPU
+            # Backward compatibility: if 'device' not in cfg, use 'cuda' config
+            if hasattr(cfg, 'device') and cfg.device is not None:
+                # New way: use 'device' config for auto-detection or explicit setting
+                cfg_device = cfg.device
+            else:
+                # Legacy way: convert 'cuda' boolean to device string
+                cfg_device = 'cuda' if (hasattr(cfg, 'cuda') and cfg.cuda) else 'cpu'
+
             # now only support multi-gpu for only enable learn mode
             if len(set(self._enable_field).intersection(set(['learn']))) > 0:
                 multi_gpu = self._cfg.multi_gpu
                 self._rank = get_rank() if multi_gpu else 0
-                if self._cuda:
-                    # model.cuda() is an in-place operation.
-                    model.cuda()
+            else:
+                self._rank = 0
+
+            # Auto-detect or set device
+            self._device_type, self._use_accelerator, self._device = auto_device_init(cfg_device, self._rank)
+
+            # Keep backward compatibility with _cuda attribute
+            self._cuda = self._use_accelerator and self._device_type == 'cuda'
+
+            # Move model to the detected/configured device
+            if self._use_accelerator:
+                move_to_device(model, self._device_type, self._rank)
+
+            # Multi-GPU initialization
+            if len(set(self._enable_field).intersection(set(['learn']))) > 0:
+                multi_gpu = self._cfg.multi_gpu
                 if multi_gpu:
                     bp_update_sync = self._cfg.bp_update_sync
                     self._bp_update_sync = bp_update_sync
                     self._init_multi_gpu_setting(model, bp_update_sync)
-            else:
-                self._rank = 0
-                if self._cuda:
-                    # model.cuda() is an in-place operation.
-                    model.cuda()
+
             self._model = model
-            self._device = 'cuda:{}'.format(self._rank % torch.cuda.device_count()) if self._cuda else 'cpu'
         else:
             self._cuda = False
             self._rank = 0
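
The resolution order above can be exercised on its own: an explicit 'device' entry wins, otherwise the legacy boolean 'cuda' flag is translated. A minimal sketch (the EasyDict configs are made up for illustration; assumes this commit is applied):

    from easydict import EasyDict
    from ding.torch_utils import auto_device_init

    new_cfg = EasyDict(dict(device='auto'))   # new style: 'device' takes precedence
    legacy_cfg = EasyDict(dict(cuda=True))    # old style: only the boolean flag

    for cfg in (new_cfg, legacy_cfg):
        # Mirror the resolution logic from Policy.__init__ above.
        if hasattr(cfg, 'device') and cfg.device is not None:
            cfg_device = cfg.device
        else:
            cfg_device = 'cuda' if (hasattr(cfg, 'cuda') and cfg.cuda) else 'cpu'
        device_type, use_accelerator, device_str = auto_device_init(cfg_device, rank=0)
        print(cfg_device, '->', device_type, use_accelerator, device_str)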

ding/torch_utils/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -12,3 +12,5 @@
 from .dataparallel import DataParallel
 from .reshape_helper import fold_batch, unfold_batch, unsqueeze_repeat
 from .parameter import NonegativeParameter, TanhParameter
+from .device_helper import get_available_device, get_device_count, move_to_device, get_device_string, \
+    auto_device_init, is_npu_available, is_cuda_available
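
With the re-export in place, callers can import the helpers from ding.torch_utils directly instead of reaching into ding.torch_utils.device_helper. A quick check (sketch, assuming the commit is applied):

    from ding.torch_utils import get_available_device, is_cuda_available, is_npu_available

    print('NPU available:', is_npu_available())
    print('CUDA available:', is_cuda_available())
    # Detection priority is NPU > GPU > CPU.
    print('Auto-detected:', get_available_device())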

ding/torch_utils/device_helper.py

Lines changed: 183 additions & 0 deletions

@@ -0,0 +1,183 @@
+"""
+Copyright 2020 Sensetime X-lab. All Rights Reserved.
+
+Device helper utilities for automatic detection of NPU and GPU devices.
+Supports Huawei Ascend NPU (torch_npu) and NVIDIA GPU (torch.cuda).
+"""
+
+import torch
+from typing import Tuple, Optional
+import logging
+
+# Try to import torch_npu for Huawei NPU support
+try:
+    import torch_npu
+    TORCH_NPU_AVAILABLE = True
+except ImportError:
+    TORCH_NPU_AVAILABLE = False
+
+logger = logging.getLogger(__name__)
+
+
+def get_available_device() -> Tuple[str, bool]:
+    """
+    Overview:
+        Automatically detect the available device (NPU or GPU or CPU).
+        Priority: NPU > GPU > CPU
+    Returns:
+        - device_type (:obj:`str`): Device type string, one of 'npu', 'cuda', 'cpu'
+        - is_accelerator (:obj:`bool`): Whether an accelerator (NPU/GPU) is available
+    Examples:
+        >>> device_type, is_accelerator = get_available_device()
+        >>> print(f"Using device: {device_type}")
+    """
+    # Check for NPU first (Huawei Ascend)
+    if TORCH_NPU_AVAILABLE and torch.npu.is_available():
+        npu_count = torch.npu.device_count()
+        logger.info(f"Detected {npu_count} NPU device(s), using NPU")
+        return 'npu', True
+
+    # Check for CUDA GPU
+    if torch.cuda.is_available():
+        gpu_count = torch.cuda.device_count()
+        logger.info(f"Detected {gpu_count} CUDA GPU device(s), using GPU")
+        return 'cuda', True
+
+    # Fallback to CPU
+    logger.info("No NPU or GPU detected, using CPU")
+    return 'cpu', False
+
+
+def get_device_count(device_type: str) -> int:
+    """
+    Overview:
+        Get the number of available devices for the specified device type.
+    Arguments:
+        - device_type (:obj:`str`): Device type, one of 'npu', 'cuda', 'cpu'
+    Returns:
+        - count (:obj:`int`): Number of available devices
+    """
+    if device_type == 'npu' and TORCH_NPU_AVAILABLE:
+        return torch.npu.device_count()
+    elif device_type == 'cuda':
+        return torch.cuda.device_count()
+    else:
+        return 1  # CPU always has 1 "device"
+
+
+def move_to_device(model: torch.nn.Module, device_type: str, rank: int = 0) -> torch.nn.Module:
+    """
+    Overview:
+        Move a PyTorch model to the specified device.
+        Supports NPU, CUDA, and CPU devices.
+    Arguments:
+        - model (:obj:`torch.nn.Module`): The model to move
+        - device_type (:obj:`str`): Device type, one of 'npu', 'cuda', 'cpu'
+        - rank (:obj:`int`): Device rank for multi-device setups
+    Returns:
+        - model (:obj:`torch.nn.Module`): The model moved to the device (in-place operation)
+    """
+    if device_type == 'npu' and TORCH_NPU_AVAILABLE:
+        device_count = torch.npu.device_count()
+        device_id = rank % device_count if device_count > 0 else 0
+        model.npu(device_id)
+        logger.debug(f"Moved model to NPU device {device_id}")
+    elif device_type == 'cuda':
+        device_count = torch.cuda.device_count()
+        device_id = rank % device_count if device_count > 0 else 0
+        model.cuda(device_id)
+        logger.debug(f"Moved model to CUDA device {device_id}")
+    # CPU case: no need to move
+    return model
+
+
+def get_device_string(device_type: str, rank: int = 0) -> str:
+    """
+    Overview:
+        Get the device string for PyTorch tensor operations.
+    Arguments:
+        - device_type (:obj:`str`): Device type, one of 'npu', 'cuda', 'cpu'
+        - rank (:obj:`int`): Device rank for multi-device setups
+    Returns:
+        - device_str (:obj:`str`): Device string like 'npu:0', 'cuda:0', or 'cpu'
+    """
+    if device_type in ['npu', 'cuda']:
+        device_count = get_device_count(device_type)
+        device_id = rank % device_count if device_count > 0 else 0
+        return f'{device_type}:{device_id}'
+    else:
+        return 'cpu'
+
+
+def auto_device_init(cfg_device: Optional[str], rank: int = 0) -> Tuple[str, bool, str]:
+    """
+    Overview:
+        Initialize device settings based on config.
+        Supports automatic detection, explicit device type, or legacy 'cuda' boolean.
+    Arguments:
+        - cfg_device (:obj:`Optional[str]`): Device configuration from config.
+          Can be 'auto', 'npu', 'cuda', 'cpu', or None (defaults to 'auto')
+        - rank (:obj:`int`): Device rank for multi-device setups
+    Returns:
+        - device_type (:obj:`str`): Detected device type ('npu', 'cuda', or 'cpu')
+        - use_accelerator (:obj:`bool`): Whether an accelerator is being used
+        - device_str (:obj:`str`): Full device string for PyTorch operations
+    Examples:
+        >>> device_type, use_accelerator, device_str = auto_device_init('auto')
+        >>> # Returns ('npu', True, 'npu:0') if NPU available
+        >>> # Returns ('cuda', True, 'cuda:0') if GPU available
+        >>> # Returns ('cpu', False, 'cpu') otherwise
+    """
+    # Default to auto detection if not specified
+    if cfg_device is None or cfg_device == 'auto':
+        device_type, use_accelerator = get_available_device()
+    else:
+        # Explicit device type specified
+        device_type = cfg_device.lower()
+
+        # Validate the device type is available
+        if device_type == 'npu':
+            if TORCH_NPU_AVAILABLE and torch.npu.is_available():
+                use_accelerator = True
+                logger.info("Using NPU as explicitly configured")
+            else:
+                logger.warning("NPU requested but not available, falling back to CPU")
+                device_type = 'cpu'
+                use_accelerator = False
+        elif device_type == 'cuda':
+            if torch.cuda.is_available():
+                use_accelerator = True
+                logger.info("Using CUDA GPU as explicitly configured")
+            else:
+                logger.warning("CUDA requested but not available, falling back to CPU")
+                device_type = 'cpu'
+                use_accelerator = False
+        else:
+            # CPU or any other value
+            device_type = 'cpu'
+            use_accelerator = False
+            logger.info("Using CPU as configured")
+
+    device_str = get_device_string(device_type, rank)
+
+    return device_type, use_accelerator, device_str
+
+
+def is_npu_available() -> bool:
+    """
+    Overview:
+        Check if Huawei NPU is available.
+    Returns:
+        - available (:obj:`bool`): True if NPU is available
+    """
+    return TORCH_NPU_AVAILABLE and torch.npu.is_available()
+
+
+def is_cuda_available() -> bool:
+    """
+    Overview:
+        Check if NVIDIA CUDA GPU is available.
+    Returns:
+        - available (:obj:`bool`): True if CUDA is available
+    """
+    return torch.cuda.is_available()
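
Taken together, the helpers cover the whole flow: detect a backend, move the model, and create tensors on a matching device string. A short end-to-end sketch (on a machine without accelerators everything degrades to 'cpu'):

    import torch
    from ding.torch_utils import auto_device_init, move_to_device

    model = torch.nn.Linear(4, 2)
    device_type, use_accelerator, device_str = auto_device_init('auto', rank=0)
    if use_accelerator:
        # In-place move, mirroring what Policy.__init__ now does.
        move_to_device(model, device_type, rank=0)
    # Tensors built from the returned string land on the same device as the model.
    obs = torch.randn(8, 4, device=device_str)
    logits = model(obs)
    print(device_str, logits.shape)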

dizoo/classic_control/cartpole/config/cartpole_ppo_config.py

Lines changed: 1 addition & 1 deletion

@@ -9,7 +9,7 @@
         stop_value=195,
     ),
     policy=dict(
-        cuda=False,
+        device='auto',  # Auto-detect NPU > GPU > CPU
         action_space='discrete',
         model=dict(
             obs_shape=4,
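
'auto' is only one option here: a config can also pin a backend explicitly, in which case auto_device_init warns and falls back to CPU when that backend is absent, while configs that still set the boolean cuda flag resolve through the legacy branch in Policy.__init__. Hypothetical fragments, for illustration only:

    from easydict import EasyDict

    pinned = EasyDict(dict(policy=dict(device='npu')))  # explicit NPU; CPU fallback with a warning if unavailable
    legacy = EasyDict(dict(policy=dict(cuda=True)))     # old-style flag, still honored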
