Skip to content

Commit 196fa8a

Browse files
authored
fix extract_device_id bug when there is only one device (#74428)
1 parent 90e4213 commit 196fa8a

File tree

2 files changed

+24
-8
lines changed

2 files changed

+24
-8
lines changed

python/paddle/device/__init__.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -635,10 +635,18 @@ def extract_device_id(device: _CustomPlaceLike, op_name: str) -> int:
635635
else:
636636
device_type = None
637637
available_custom_devices = core.get_available_custom_device()
638-
for d in available_custom_devices:
639-
dev_type, dev_id = d.split(':')
640-
if int(dev_id) == device:
641-
device_type = dev_type
638+
if len(available_custom_devices) == 1:
639+
if device == 0:
640+
device_type = available_custom_devices[0]
641+
else:
642+
raise ValueError(
643+
f"Device id {device} not found in available_custom_devices: [{available_custom_devices[0]}:0]"
644+
)
645+
else:
646+
for d in available_custom_devices:
647+
dev_type, dev_id = d.split(':')
648+
if int(dev_id) == device:
649+
device_type = dev_type
642650
if device_type is None:
643651
raise ValueError(
644652
f"Device id {device} not found in available_custom_devices: {available_custom_devices}"

python/paddle/device/cuda/__init__.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -212,10 +212,18 @@ def extract_cuda_device_id(device: _CudaPlaceLike, op_name: str) -> int:
212212
else:
213213
device_type = None
214214
available_custom_devices = core.get_available_custom_device()
215-
for d in available_custom_devices:
216-
dev_type, dev_id = d.split(':')
217-
if int(dev_id) == device:
218-
device_type = dev_type
215+
if len(available_custom_devices) == 1:
216+
if device == 0:
217+
device_type = available_custom_devices[0]
218+
else:
219+
raise ValueError(
220+
f"Device id {device} not found in available_custom_devices: [{available_custom_devices[0]}:0]"
221+
)
222+
else:
223+
for d in available_custom_devices:
224+
dev_type, dev_id = d.split(':')
225+
if int(dev_id) == device:
226+
device_type = dev_type
219227
if device_type is None:
220228
raise ValueError(
221229
f"Device id {device} not found in available_custom_devices: {available_custom_devices}"

0 commit comments

Comments
 (0)