Skip to content

Commit 4cbdce5

Browse files
committed
refactor: make the cuda/rocmsmi initialization asap
Signed-off-by: thxCode <[email protected]>
1 parent 28be0f1 commit 4cbdce5

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

gpustack_runtime/detector/amd.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,10 @@ def detect(self) -> Devices | None:
8686
hsa_agents = {hsa_agent.uuid: hsa_agent for hsa_agent in pyhsa.get_agents()}
8787

8888
pyamdsmi.amdsmi_init()
89+
try:
90+
pyrocmsmi.rsmi_init()
91+
except pyrocmsmi.ROCMSMIError:
92+
debug_log_exception(logger, "Failed to initialize ROCm SMI")
8993

9094
sys_runtime_ver_original = pyrocmcore.getROCmVersion()
9195
sys_runtime_ver = get_brief_version(sys_runtime_ver_original)
@@ -106,7 +110,6 @@ def detect(self) -> Devices | None:
106110
dev_card_id = dev_kfd_info.get("node_id")
107111
else:
108112
with contextlib.suppress(pyrocmsmi.ROCMSMIError):
109-
pyrocmsmi.rsmi_init()
110113
dev_card_id = pyrocmsmi.rsmi_dev_node_id_get(dev_idx)
111114

112115
dev_gpudev_info = None
@@ -130,7 +133,6 @@ def detect(self) -> Devices | None:
130133
dev_name = dev_gpu_asic_info.get("market_name")
131134
dev_cc = None
132135
with contextlib.suppress(pyrocmsmi.ROCMSMIError):
133-
pyrocmsmi.rsmi_init()
134136
dev_cc = pyrocmsmi.rsmi_dev_target_graphics_version_get(dev_idx)
135137

136138
dev_cores = None
@@ -147,7 +149,6 @@ def detect(self) -> Devices | None:
147149
dev_temp = dev_gpu_metrics_info.get("temperature_hotspot", 0)
148150
except pyamdsmi.AmdSmiException:
149151
with contextlib.suppress(pyrocmsmi.ROCMSMIError):
150-
pyrocmsmi.rsmi_init()
151152
dev_cores_util = pyrocmsmi.rsmi_dev_busy_percent_get(dev_idx)
152153
dev_temp = pyrocmsmi.rsmi_dev_temp_metric_get(dev_idx)
153154
if dev_cores_util is None:
@@ -166,7 +167,6 @@ def detect(self) -> Devices | None:
166167
dev_mem_used = dev_gpu_vram_usage.get("vram_used")
167168
except pyamdsmi.AmdSmiException:
168169
with contextlib.suppress(pyrocmsmi.ROCMSMIError):
169-
pyrocmsmi.rsmi_init()
170170
dev_mem = byte_to_mebibyte( # byte to MiB
171171
pyrocmsmi.rsmi_dev_memory_total_get(dev_idx),
172172
)
@@ -188,7 +188,6 @@ def detect(self) -> Devices | None:
188188
)
189189
except pyamdsmi.AmdSmiException:
190190
with contextlib.suppress(pyrocmsmi.ROCMSMIError):
191-
pyrocmsmi.rsmi_init()
192191
dev_power = pyrocmsmi.rsmi_dev_power_cap_get(dev_idx)
193192
dev_power_used = pyrocmsmi.rsmi_dev_power_get(dev_idx)
194193

gpustack_runtime/detector/nvidia.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,10 @@ def detect(self) -> Devices | None: # noqa: PLR0915
9191
pci_devs = NVIDIADetector.detect_pci_devices()
9292

9393
pynvml.nvmlInit()
94+
try:
95+
pycuda.cuInit()
96+
except pycuda.CUDAError:
97+
debug_log_exception(logger, "Failed to initialize CUDA")
9498

9599
sys_driver_ver = pynvml.nvmlSystemGetDriverVersion()
96100

@@ -133,7 +137,6 @@ def detect(self) -> Devices | None: # noqa: PLR0915
133137

134138
dev_cores = None
135139
with contextlib.suppress(pycuda.CUDAError):
136-
pycuda.cuInit()
137140
dev_gpudev = pycuda.cuDeviceGet(dev_idx)
138141
dev_cores = pycuda.cuDeviceGetAttribute(
139142
dev_gpudev,

0 commit comments

Comments
 (0)