Skip to content

Commit 1ef6cdd

Browse files
committed
feat: enrich gpu information using sysfs
1 parent 12e6154 commit 1ef6cdd

File tree

1 file changed

+155
-13
lines changed

1 file changed

+155
-13
lines changed

scripts/list_gpus.sh

Lines changed: 155 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,21 +20,51 @@ get_nvidia_gpus() {
2020
init: {
2121
deviceRequests: {
2222
Driver: "nvidia",
23-
DeviceIDs: [$uuid]
23+
Devices: [$uuid]
2424
}
2525
}
2626
}'
2727
done
2828
fi
2929
}
3030

31+
# Declare the associative array (hashmap) globally
32+
declare -A gpu_map
33+
34+
map_pci_to_primary() {
35+
# Iterate over all card nodes in /sys/class/drm
36+
# We filter for 'card*' to ignore 'renderD*' nodes for the primary map
37+
for card_path in /sys/class/drm/card*; do
38+
39+
# logical check to ensure the glob matched a file
40+
[ -e "$card_path" ] || continue
41+
42+
# Resolve the symlink to the actual PCI device directory
43+
# Example result: /sys/devices/pci0000:00/.../0000:03:00.0
44+
real_device_path=$(readlink -f "$card_path/device")
45+
46+
# The last part of that path is the PCI ID (e.g., 0000:03:00.0)
47+
pci_id=$(basename "$real_device_path")
48+
49+
# The last part of the card_path is the card name (e.g., card0)
50+
card_name=$(basename "$card_path")
51+
52+
# Store in the hashmap
53+
# Key: PCI ID, Value: /dev/dri/cardX
54+
gpu_map["$pci_id"]="/dev/dri/$card_name"
55+
done
56+
}
57+
58+
59+
3160
# Function to check for other GPUs (AMD, Intel, etc.) via lspci
3261
get_generic_gpus() {
3362
# Check if lspci is available
3463
if ! command -v lspci &> /dev/null; then
3564
return
3665
fi
3766

67+
map_pci_to_primary
3868
# Iterate over VGA and 3D controllers
3969
lspci -mm -n -d ::0300 | while read -r line; do process_pci_line "$line"; done
4070
lspci -mm -n -d ::0302 | while read -r line; do process_pci_line "$line"; done
@@ -44,10 +74,10 @@ process_pci_line() {
4474
line="$1"
4575

4676
slot=$(echo "$line" | awk '{print $1}')
47-
vendor_id=$(echo "$line" | awk '{print $3}' | tr -d '"')
77+
vendor_id_hex=$(echo "$line" | awk '{print $3}' | tr -d '"')
4878

4979
# We want to exclude NVIDIA here if we already handled them via nvidia-smi.
50-
if [[ "$vendor_id" == "10de" ]] && command -v nvidia-smi &> /dev/null; then
80+
if [[ "$vendor_id_hex" == "10de" ]] && command -v nvidia-smi &> /dev/null; then
5181
return
5282
fi
5383

@@ -61,24 +91,130 @@ process_pci_line() {
6191

6292
# Determine driver
6393
driver=""
64-
if [[ "$vendor_id" == "1002" ]]; then # AMD
65-
driver="amdgpu"
94+
if [[ "$vendor_id_hex" == "1002" ]]; then # AMD
95+
driver="amdgpu"
96+
elif [[ "$vendor_id_hex" = "8086" ]]; then # Intel
97+
driver="intel"
98+
fi
99+
100+
device_id=""
101+
card_path=""
102+
if [ -n "${gpu_map[$pci_id]}" ]; then
103+
# Get device id from /sys/class/drm map
104+
device_id="${gpu_map[$pci_id]}" # e.g. /dev/dri/card0
105+
# Reconstruct sysfs card path from the device path
106+
# device_id is /dev/dri/cardX, we want /sys/class/drm/cardX
107+
card_name=$(basename "$device_id")
108+
card_path="/sys/class/drm/$card_name"
109+
else
110+
# If it can't be found, default to pci id
111+
device_id="${pci_id}"
112+
fi
113+
114+
local devices=()
115+
local binds=()
116+
local cap_add=()
117+
local group_add=()
118+
local ipc_mode="null"
119+
local shm_size="null"
120+
local security_opt="null"
121+
122+
# Only perform detailed checks if we found the card path
123+
if [ -n "$card_path" ] && [ -e "$card_path" ]; then
124+
125+
# Resolve real device path for getting sibling render node
126+
local real_device_path=$(readlink -f "$card_path/device")
127+
local render_name=""
128+
if [ -d "$real_device_path/drm" ]; then
129+
render_name=$(ls "$real_device_path/drm" | grep "^renderD" | head -n 1)
130+
fi
131+
132+
case "$vendor_id_hex" in
133+
"1002") # AMD (0x1002)
134+
# Devices
135+
[ -e "/dev/dxg" ] && devices+=("/dev/dxg")
136+
devices+=("$device_id") # /dev/dri/cardX
137+
138+
# Binds
139+
[ -e "/usr/lib/wsl/lib/libdxcore.so" ] && \
140+
binds+=("/usr/lib/wsl/lib/libdxcore.so:/usr/lib/libdxcore.so")
141+
[ -e "/opt/rocm/lib/libhsa-runtime64.so.1" ] && \
142+
binds+=("/opt/rocm/lib/libhsa-runtime64.so.1:/opt/rocm/lib/libhsa-runtime64.so.1")
143+
144+
# Configs
145+
cap_add+=("SYS_PTRACE")
146+
ipc_mode="\"host\""
147+
shm_size="8589934592"
148+
# SecurityOpt is a JSON object
149+
security_opt='{"seccomp": "unconfined"}'
150+
;;
151+
152+
"8086") # Intel (0x8086)
153+
# Devices
154+
[ -n "$render_name" ] && devices+=("/dev/dri/$render_name")
155+
devices+=("$device_id")
156+
157+
# Configs
158+
group_add+=("video" "render")
159+
cap_add+=("SYS_ADMIN")
160+
;;
161+
esac
162+
else
163+
# Fallback if we don't have the card path mapped, but still want to add the primary device if applicable
164+
# This preserves behavior for devices that might not map correctly but are enumerated
165+
if [[ "$vendor_id_hex" == "1002" ]] || [[ "$vendor_id_hex" == "8086" ]]; then
166+
if [[ "$device_id" == /dev/* ]]; then
167+
devices+=("$device_id")
168+
fi
169+
fi
170+
fi
171+
172+
173+
# --- Construct JSON ---
174+
175+
# Helper to convert bash arrays using jq
176+
# (re-using the logic, but localized vars)
177+
json_devices=$(printf '%s\n' "${devices[@]}" | jq -R . | jq -s . | jq 'map(select(length > 0))')
178+
json_binds=$(printf '%s\n' "${binds[@]}" | jq -R . | jq -s . | jq 'map(select(length > 0))')
179+
json_cap=$(printf '%s\n' "${cap_add[@]}" | jq -R . | jq -s . | jq 'map(select(length > 0))')
180+
json_group=$(printf '%s\n' "${group_add[@]}" | jq -R . | jq -s . | jq 'map(select(length > 0))')
181+
182+
# If Devices array is empty, ensure at least the ID we found is there (unless it was already added)
183+
# Using 'index' to check if device_id is present is tricky with jq on the fly,
184+
# but standardizing on what we found is safer.
185+
# If the detailed logic above didn't populate devices (e.g. unknown vendor), we fall back to just the ID.
186+
if [ "$(echo "$json_devices" | jq length)" -eq 0 ]; then
187+
json_devices="[\"$device_id\"]"
66188
fi
67189

68-
# Construct JSON
190+
69191
jq -c -n \
70192
--arg desc "$description" \
71193
--arg driver "$driver" \
72-
--arg pci_id "$pci_id" \
194+
--arg device_id "$device_id" \
195+
--argjson dev "$json_devices" \
196+
--argjson bind "$json_binds" \
197+
--argjson cap "$json_cap" \
198+
--argjson group "$json_group" \
199+
--argjson sec "$security_opt" \
200+
--argjson shm "$shm_size" \
201+
--argjson ipc "$ipc_mode" \
73202
'{
74203
description: $desc,
75204
init: {
76205
deviceRequests: {
77206
Driver: (if $driver != "" then $driver else null end),
78-
DeviceIDs: [$pci_id]
79-
}
207+
Devices: $dev,
208+
Capabilities: [["gpu"]]
209+
},
210+
Binds: $bind,
211+
CapAdd: $cap,
212+
GroupAdd: $group,
213+
SecurityOpt: $sec,
214+
ShmSize: $shm,
215+
IpcMode: $ipc
80216
}
81-
}'
217+
} | del(.. | select(. == null)) | del(.. | select(. == []))'
82218
}
83219

84220
# Function to get all GPUs in JSON array format
@@ -95,10 +231,16 @@ get_all_gpus_json() {
95231
init: {
96232
deviceRequests: {
97233
Driver: .[0].init.deviceRequests.Driver,
98-
DeviceIDs: (map(.init.deviceRequests.DeviceIDs[]) | unique),
234+
(if .[0].init.deviceRequests.Driver == "nvidia" then "DeviceIDs" else "Devices" end): (map(.init.deviceRequests.Devices[]?) | unique),
99235
Capabilities: [["gpu"]]
100-
}
101-
}
236+
},
237+
Binds: (map(.init.Binds[]?) | unique),
238+
CapAdd: (map(.init.CapAdd[]?) | unique),
239+
GroupAdd: (map(.init.GroupAdd[]?) | unique),
240+
SecurityOpt: .[0].init.SecurityOpt,
241+
ShmSize: .[0].init.ShmSize,
242+
IpcMode: .[0].init.IpcMode
243+
} | del(.. | select(. == null)) | del(.. | select(. == []))
102244
}) | map(if .init.deviceRequests.Driver == null then del(.init.deviceRequests.Driver) else . end)
103245
'
104246
}

0 commit comments

Comments
 (0)