Skip to content

Commit f64a307

Browse files
committed
refactor: support amd detection
Signed-off-by: thxCode <[email protected]>
1 parent 5c631c6 commit f64a307

File tree

11 files changed

+352
-101
lines changed

11 files changed

+352
-101
lines changed

gpustack_runtime/cmds/detector.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -114,14 +114,12 @@ def format_devices_table(devs: Devices) -> str:
114114
# Device rows
115115
device_lines = []
116116
for dev in devs:
117-
row_data = [
118-
dev.index,
117+
row_data: list[str] = [
118+
str(dev.index),
119119
dev.name if dev.name else "N/A",
120-
f"{dev.memory_used}MiB / {dev.memory}MiB"
121-
if dev.memory and dev.memory_used
122-
else "N/A",
123-
f"{dev.cores_utilization}%" if dev.cores_utilization else "N/A",
124-
f"{dev.temperature}C" if dev.temperature else "N/A",
120+
f"{dev.memory_used}MiB / {dev.memory}MiB",
121+
f"{dev.cores_utilization}%",
122+
f"{dev.temperature}C" if dev.temperature is not None else "N/A",
125123
dev.compute_capability if dev.compute_capability else "N/A",
126124
]
127125

gpustack_runtime/detector/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,14 @@
1010
supported_backends,
1111
supported_manufacturers,
1212
)
13+
from .amd import AMDDetector
1314
from .ascend import AscendDetector
1415
from .nvidia import NVIDIADetector
1516

1617
detectors: list[Detector] = [
17-
NVIDIADetector(),
18+
AMDDetector(),
1819
AscendDetector(),
20+
NVIDIADetector(),
1921
]
2022

2123

gpustack_runtime/detector/__types__.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -149,34 +149,34 @@ class Device:
149149
"""
150150
UUID of the device.
151151
"""
152-
driver_version: str = ""
152+
driver_version: str | None = None
153153
"""
154154
Driver version of the device.
155155
"""
156156
driver_version_tuple: list[int | str] | None = None
157157
"""
158158
Driver version tuple of the device.
159-
None if `driver_version` is blank.
159+
None if `driver_version` is missed.
160160
"""
161-
runtime_version: str = ""
161+
runtime_version: str | None = None
162162
"""
163163
Runtime version of the device.
164164
"""
165165
runtime_version_tuple: list[int | str] | None = None
166166
"""
167167
Runtime version tuple of the device.
168-
None if `runtime_version` is blank.
168+
None if `runtime_version` is missed.
169169
"""
170-
compute_capability: str = ""
170+
compute_capability: str | None = None
171171
"""
172172
Compute capability of the device.
173173
"""
174174
compute_capability_tuple: list[int | str] | None = None
175175
"""
176176
Compute capability tuple of the device.
177-
None if `compute_capability` is blank.
177+
None if `compute_capability` is missed.
178178
"""
179-
cores: int = 0
179+
cores: int | None = None
180180
"""
181181
Total cores of the device.
182182
"""
@@ -196,15 +196,15 @@ class Device:
196196
"""
197197
Memory utilization of the device in percentage.
198198
"""
199-
temperature: int = 0
199+
temperature: int | None = None
200200
"""
201201
Temperature of the device in Celsius.
202202
"""
203-
power: int = 0
203+
power: int | None = None
204204
"""
205205
Power consumption of the device in Watts.
206206
"""
207-
power_used: int = 0
207+
power_used: int | None = None
208208
"""
209209
Used power of the device in Watts.
210210
"""

gpustack_runtime/detector/__utils__.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ class PCIDevice:
1111
"""
1212
Vendor ID of the PCI device.
1313
"""
14-
path: Path
14+
path: str
1515
"""
1616
Path to the PCI device in sysfs.
1717
"""
@@ -80,7 +80,7 @@ def get_pci_devices(
8080
pci_devices.append(
8181
PCIDevice(
8282
vendor=dev_vendor,
83-
path=dev_path,
83+
path=str(dev_path),
8484
address=dev_address,
8585
class_=dev_class,
8686
config=dev_config,
@@ -92,7 +92,7 @@ def get_pci_devices(
9292

9393
@dataclass
9494
class DeviceFile:
95-
path: Path
95+
path: str
9696
"""
9797
Path to the device file.
9898
"""
@@ -102,7 +102,7 @@ class DeviceFile:
102102
"""
103103

104104

105-
def get_device_files(pattern: str, directory: Path = Path("/dev")) -> list[DeviceFile]:
105+
def get_device_files(pattern: str, directory: Path | str = "/dev") -> list[DeviceFile]:
106106
r"""
107107
Get device files with the given pattern.
108108
@@ -123,24 +123,27 @@ def get_device_files(pattern: str, directory: Path = Path("/dev")) -> list[Devic
123123
msg = "Pattern must include a regex group for the number, e.g nvidia(?P<number>\\d+)."
124124
raise ValueError(msg)
125125

126+
if isinstance(directory, str):
127+
directory = Path(directory)
128+
126129
device_files = []
127130
if not directory.exists():
128131
return device_files
129132

130133
regex = re.compile(f"^{directory!s}/{pattern}$")
131-
for path in directory.iterdir():
132-
matched = regex.match(str(path))
134+
for file_path in directory.iterdir():
135+
matched = regex.match(str(file_path))
133136
if not matched:
134137
continue
135-
number = matched.group("number")
138+
file_number = matched.group("number")
136139
try:
137-
number = int(number)
140+
file_number = int(file_number)
138141
except ValueError:
139-
number = None
142+
file_number = None
140143
device_files.append(
141144
DeviceFile(
142-
path=path,
143-
number=number,
145+
path=str(file_path),
146+
number=file_number,
144147
),
145148
)
146149

gpustack_runtime/detector/amd.py

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
from __future__ import annotations
2+
3+
import contextlib
4+
import logging
5+
from functools import lru_cache
6+
7+
from .. import envs
8+
from .__types__ import Detector, Device, Devices, ManufacturerEnum
9+
from .__utils__ import PCIDevice, get_device_files, get_pci_devices
10+
11+
logger = logging.getLogger(__name__)
12+
13+
14+
class AMDDetector(Detector):
15+
"""
16+
Detect AMD GPUs.
17+
"""
18+
19+
@staticmethod
20+
@lru_cache
21+
def is_supported() -> bool:
22+
"""
23+
Check if the AMD detector is supported.
24+
25+
Returns:
26+
True if supported, False otherwise.
27+
28+
"""
29+
supported = False
30+
if envs.GPUSTACK_RUNTIME_DETECT.lower() not in ("auto", "amd"):
31+
logger.debug("AMD detection is disabled by environment variable")
32+
return supported
33+
34+
pci_devs = AMDDetector.detect_pci_devices()
35+
if not pci_devs:
36+
logger.debug("No AMD PCI devices found")
37+
return supported
38+
39+
try:
40+
import amdsmi as pyamdsmi # noqa: PLC0415
41+
except ImportError:
42+
if logger.isEnabledFor(logging.DEBUG):
43+
logger.exception("amdsmi module is not installed")
44+
return supported
45+
46+
try:
47+
pyamdsmi.amdsmi_init()
48+
pyamdsmi.amdsmi_shut_down()
49+
supported = True
50+
except pyamdsmi.AmdSmiException:
51+
if logger.isEnabledFor(logging.DEBUG):
52+
logger.exception("Failed to initialize AMD SMI")
53+
54+
return supported
55+
56+
@staticmethod
57+
@lru_cache
58+
def detect_pci_devices() -> dict[str, PCIDevice] | None:
59+
pci_devs = get_pci_devices(vendor="0x1002")
60+
if not pci_devs:
61+
return None
62+
return {dev.address: dev for dev in pci_devs}
63+
64+
def __init__(self):
65+
super().__init__(ManufacturerEnum.AMD)
66+
67+
def detect(self) -> Devices | None:
68+
"""
69+
Detect AMD GPUs.
70+
71+
Returns:
72+
A list of detected AMD GPU devices,
73+
or None if detection fails.
74+
75+
"""
76+
if not self.is_supported():
77+
return None
78+
79+
try:
80+
import amdsmi as pyamdsmi # noqa: PLC0415
81+
except ImportError:
82+
if logger.isEnabledFor(logging.DEBUG):
83+
logger.exception("amdsmi module is not installed")
84+
return None
85+
86+
ret: Devices = []
87+
88+
try:
89+
pyamdsmi.amdsmi_init()
90+
91+
sys_runtime_ver = None
92+
sys_runtime_ver_t = None
93+
94+
devs = pyamdsmi.amdsmi_get_processor_handles()
95+
dev_files = None
96+
for dev_idx, dev in enumerate(devs):
97+
dev_index = dev_idx
98+
if envs.GPUSTACK_RUNTIME_DETECT_PHYSICAL_INDEX_PRIORITY:
99+
if dev_files is None:
100+
dev_files = get_device_files(
101+
pattern=r"card(?P<number>\d+)",
102+
directory="/dev/dri",
103+
)
104+
if len(dev_files) > dev_idx:
105+
dev_file = dev_files[dev_idx]
106+
if dev_file.number is not None:
107+
dev_index = dev_file.number - 1
108+
109+
dev_uuid = pyamdsmi.amdsmi_get_gpu_device_uuid(dev)
110+
111+
dev_gpu_driver_info = pyamdsmi.amdsmi_get_gpu_driver_info(dev)
112+
dev_driver_ver = dev_gpu_driver_info.get("driver_version")
113+
dev_driver_ver_t = [
114+
int(v) if v.isdigit() else v for v in dev_driver_ver.split(".")
115+
]
116+
117+
dev_gpu_board_info = pyamdsmi.amdsmi_get_gpu_board_info(dev)
118+
dev_name = "AMD " + dev_gpu_board_info.get("product_name")
119+
120+
dev_gpu_metrics_info = pyamdsmi.amdsmi_get_gpu_metrics_info(dev)
121+
dev_cores_util = dev_gpu_metrics_info.get("average_gfx_activity", 0)
122+
dev_gpu_vram_usage = pyamdsmi.amdsmi_get_gpu_vram_usage(dev)
123+
dev_mem = dev_gpu_vram_usage.get("vram_total")
124+
dev_mem_used = dev_gpu_vram_usage.get("vram_used")
125+
dev_temp = dev_gpu_metrics_info.get("temperature_hotspot", 0)
126+
127+
dev_power_info = pyamdsmi.amdsmi_get_power_info(dev)
128+
dev_power = dev_power_info.get("power_limit", 0) // 1000000 # uW to W
129+
dev_power_used = (
130+
dev_power_info.get("current_socket_power")
131+
if dev_power_info.get("current_socket_power", "N/A") != "N/A"
132+
else dev_power_info.get("average_socket_power", 0)
133+
)
134+
135+
dev_compute_partition = None
136+
with contextlib.suppress(pyamdsmi.AmdSmiException):
137+
dev_compute_partition = pyamdsmi.amdsmi_get_gpu_compute_partition(
138+
dev,
139+
)
140+
141+
dev_appendix = {
142+
"vgpu": dev_compute_partition is not None,
143+
}
144+
145+
ret.append(
146+
Device(
147+
manufacturer=self.manufacturer,
148+
index=dev_index,
149+
name=dev_name,
150+
uuid=dev_uuid,
151+
driver_version=dev_driver_ver,
152+
driver_version_tuple=dev_driver_ver_t,
153+
runtime_version=sys_runtime_ver,
154+
runtime_version_tuple=sys_runtime_ver_t,
155+
cores_utilization=dev_cores_util,
156+
memory=dev_mem,
157+
memory_used=dev_mem_used,
158+
memory_utilization=(
159+
(dev_mem_used * 100 // dev_mem) if dev_mem > 0 else 0
160+
),
161+
temperature=dev_temp,
162+
power=dev_power,
163+
power_used=dev_power_used,
164+
appendix=dev_appendix,
165+
),
166+
)
167+
except pyamdsmi.AmdSmiException:
168+
if logger.isEnabledFor(logging.DEBUG):
169+
logger.exception("Failed to fetch devices")
170+
raise
171+
except Exception:
172+
if logger.isEnabledFor(logging.DEBUG):
173+
logger.exception("Failed to process devices fetching")
174+
raise
175+
finally:
176+
pyamdsmi.amdsmi_shut_down()
177+
178+
return ret

0 commit comments

Comments
 (0)