99
1010from .. import envs
1111from .__types__ import Detector , Device , Devices , ManufacturerEnum
12- from .__utils__ import get_pci_devices
12+ from .__utils__ import get_device_files , get_pci_devices
1313
1414logger = logging .getLogger (__name__ )
1515
@@ -82,23 +82,25 @@ def detect(self) -> Devices | None:
8282 dev_runtime_ver = f"{ dev_runtime_ver_t [0 ]} .{ dev_runtime_ver_t [1 ]} "
8383
8484 dev_count = pynvml .nvmlDeviceGetCount ()
85+ dev_files = None
8586 for dev_idx in range (dev_count ):
8687 dev = pynvml .nvmlDeviceGetHandleByIndex (dev_idx )
8788
88- dev_pci_info = pynvml .nvmlDeviceGetPciInfo (dev )
8989 dev_is_vgpu = False
90+ dev_pci_info = pynvml .nvmlDeviceGetPciInfo (dev )
9091 for addr in [dev_pci_info .busIdLegacy , dev_pci_info .busId ]:
9192 if addr in pci_devs :
9293 dev_is_vgpu = _is_vgpu (pci_devs [addr ].config )
9394 break
9495
9596 dev_index = dev_idx
9697 if envs .GPUSTACK_RUNTIME_DETECT_PHYSICAL_INDEX_PRIORITY :
97- dev_index = (
98- dev_pci_info .bus - 1
99- if dev_pci_info .bus > 0
100- else dev_pci_info .bus
101- )
98+ if dev_files is None :
99+ dev_files = get_device_files (pattern = r"nvidia(?P<number>\d+)" )
100+ if len (dev_files ) > dev_idx :
101+ dev_file = dev_files [dev_idx ]
102+ if dev_file .number is not None :
103+ dev_index = dev_file .number
102104 dev_uuid = pynvml .nvmlDeviceGetUUID (dev )
103105 dev_cores = pynvml .nvmlDeviceGetNumGpuCores (dev )
104106 dev_mem = pynvml .nvmlDeviceGetMemoryInfo (dev )
@@ -140,11 +142,7 @@ def detect(self) -> Devices | None:
140142 ret .append (
141143 Device (
142144 manufacturer = self .manufacturer ,
143- indexes = (
144- dev_index
145- if isinstance (dev_index , list )
146- else [dev_index ]
147- ),
145+ index = dev_index ,
148146 name = dev_name ,
149147 uuid = dev_uuid .upper (),
150148 driver_version = sys_driver_ver ,
@@ -176,6 +174,7 @@ def detect(self) -> Devices | None:
176174 for mdev_idx in range (mdev_count ):
177175 mdev = pynvml .nvmlDeviceGetMigDeviceHandleByIndex (dev , mdev_idx )
178176
177+ mdev_index = mdev_idx
179178 mdev_uuid = pynvml .nvmlDeviceGetUUID (mdev )
180179 mdev_mem = pynvml .nvmlDeviceGetMemoryInfo (mdev )
181180 mdev_temp = pynvml .nvmlDeviceGetTemperature (
@@ -192,10 +191,6 @@ def detect(self) -> Devices | None:
192191 mdev_ci_id = pynvml .nvmlDeviceGetComputeInstanceId (mdev )
193192 mdev_appendix ["compute_instance_id" ] = mdev_ci_id
194193
195- mdev_index = mdev_idx
196- if envs .GPUSTACK_RUNTIME_DETECT_PHYSICAL_INDEX_PRIORITY :
197- mdev_index = [mdev_gi_id , mdev_ci_id ]
198-
199194 if not mdev_name :
200195 mdev_attrs = pynvml .nvmlDeviceGetAttributes (mdev )
201196
@@ -272,11 +267,7 @@ def detect(self) -> Devices | None:
272267 ret .append (
273268 Device (
274269 manufacturer = self .manufacturer ,
275- indexes = (
276- mdev_index
277- if isinstance (mdev_index , list )
278- else [mdev_index ]
279- ),
270+ index = mdev_index ,
280271 name = mdev_name ,
281272 uuid = mdev_uuid .upper (),
282273 driver_version = sys_driver_ver ,
0 commit comments