@@ -20,21 +20,51 @@ get_nvidia_gpus() {
2020 init: {
2121 deviceRequests: {
2222 Driver: "nvidia",
23- DeviceIDs : [$uuid]
23+ Devices : [$uuid]
2424 }
2525 }
2626 }'
2727 done
2828 fi
2929}
3030
31+ # Declare the associative array (hashmap) globally
32+ declare -A gpu_map
33+
34+ map_pci_to_primary () {
35+ # Iterate over all card nodes in /sys/class/drm
36+ # We filter for 'card*' to ignore 'renderD*' nodes for the primary map
37+ for card_path in /sys/class/drm/card* ; do
38+
39+ # logical check to ensure the glob matched a file
40+ [ -e " $card_path " ] || continue
41+
42+ # Resolve the symlink to the actual PCI device directory
43+ # Example result: /sys/devices/pci0000:00/.../0000:03:00.0
44+ real_device_path=$( readlink -f " $card_path /device" )
45+
46+ # The last part of that path is the PCI ID (e.g., 0000:03:00.0)
47+ pci_id=$( basename " $real_device_path " )
48+
49+ # The last part of the card_path is the card name (e.g., card0)
50+ card_name=$( basename " $card_path " )
51+
52+ # Store in the hashmap
53+ # Key: PCI ID, Value: /dev/dri/cardX
54+ gpu_map[" $pci_id " ]=" /dev/dri/$card_name "
55+ done
56+ }
57+
58+
59+
3160# Function to check for other GPUs (AMD, Intel, etc.) via lspci
3261get_generic_gpus () {
3362 # Check if lspci is available
3463 if ! command -v lspci & > /dev/null; then
3564 return
3665 fi
3766
67+ map_pci_to_primary
3868 # Iterate over VGA and 3D controllers
3969 lspci -mm -n -d ::0300 | while read -r line; do process_pci_line " $line " ; done
4070 lspci -mm -n -d ::0302 | while read -r line; do process_pci_line " $line " ; done
@@ -44,10 +74,10 @@ process_pci_line() {
4474 line=" $1 "
4575
4676 slot=$( echo " $line " | awk ' {print $1}' )
47- vendor_id =$( echo " $line " | awk ' {print $3}' | tr -d ' "' )
77+ vendor_id_hex =$( echo " $line " | awk ' {print $3}' | tr -d ' "' )
4878
4979 # We want to exclude NVIDIA here if we already handled them via nvidia-smi.
50- if [[ " $vendor_id " == " 10de" ]] && command -v nvidia-smi & > /dev/null; then
80+ if [[ " $vendor_id_hex " == " 10de" ]] && command -v nvidia-smi & > /dev/null; then
5181 return
5282 fi
5383
@@ -61,24 +91,130 @@ process_pci_line() {
6191
6292 # Determine driver
6393 driver=" "
64- if [[ " $vendor_id " == " 1002" ]]; then # AMD
65- driver=" amdgpu"
94+ if [[ " $vendor_id_hex " == " 1002" ]]; then # AMD
95+ driver=" amdgpu"
96+ elif [[ " $vendor_id_hex " = " 8086" ]]; then # Intel
97+ driver=" intel"
98+ fi
99+
100+ device_id=" "
101+ card_path=" "
102+ if [ -n " ${gpu_map[$pci_id]} " ]; then
103+ # Get device id from /sys/class/drm map
104+ device_id=" ${gpu_map[$pci_id]} " # e.g. /dev/dri/card0
105+ # Reconstruct sysfs card path from the device path
106+ # device_id is /dev/dri/cardX, we want /sys/class/drm/cardX
107+ card_name=$( basename " $device_id " )
108+ card_path=" /sys/class/drm/$card_name "
109+ else
110+ # If it can't be found, default to pci id
111+ device_id=" ${pci_id} "
112+ fi
113+
114+ local devices=()
115+ local binds=()
116+ local cap_add=()
117+ local group_add=()
118+ local ipc_mode=" null"
119+ local shm_size=" null"
120+ local security_opt=" null"
121+
122+ # Only perform detailed checks if we found the card path
123+ if [ -n " $card_path " ] && [ -e " $card_path " ]; then
124+
125+ # Resolve real device path for getting sibling render node
126+ local real_device_path=$( readlink -f " $card_path /device" )
127+ local render_name=" "
128+ if [ -d " $real_device_path /drm" ]; then
129+ render_name=$( ls " $real_device_path /drm" | grep " ^renderD" | head -n 1)
130+ fi
131+
132+ case " $vendor_id_hex " in
133+ " 1002" ) # AMD (0x1002)
134+ # Devices
135+ [ -e " /dev/dxg" ] && devices+=(" /dev/dxg" )
136+ devices+=(" $device_id " ) # /dev/dri/cardX
137+
138+ # Binds
139+ [ -e " /usr/lib/wsl/lib/libdxcore.so" ] && \
140+ binds+=(" /usr/lib/wsl/lib/libdxcore.so:/usr/lib/libdxcore.so" )
141+ [ -e " /opt/rocm/lib/libhsa-runtime64.so.1" ] && \
142+ binds+=(" /opt/rocm/lib/libhsa-runtime64.so.1:/opt/rocm/lib/libhsa-runtime64.so.1" )
143+
144+ # Configs
145+ cap_add+=(" SYS_PTRACE" )
146+ ipc_mode=" \" host\" "
147+ shm_size=" 8589934592"
148+ # SecurityOpt is a JSON object
149+ security_opt=' {"seccomp": "unconfined"}'
150+ ;;
151+
152+ " 8086" ) # Intel (0x8086)
153+ # Devices
154+ [ -n " $render_name " ] && devices+=(" /dev/dri/$render_name " )
155+ devices+=(" $device_id " )
156+
157+ # Configs
158+ group_add+=(" video" " render" )
159+ cap_add+=(" SYS_ADMIN" )
160+ ;;
161+ esac
162+ else
163+ # Fallback if we don't have the card path mapped, but still want to add the primary device if applicable
164+ # This preserves behavior for devices that might not map correctly but are enumerated
165+ if [[ " $vendor_id_hex " == " 1002" ]] || [[ " $vendor_id_hex " == " 8086" ]]; then
166+ if [[ " $device_id " == /dev/* ]]; then
167+ devices+=(" $device_id " )
168+ fi
169+ fi
170+ fi
171+
172+
173+ # --- Construct JSON ---
174+
175+ # Helper to convert bash arrays using jq
176+ # (re-using the logic, but localized vars)
177+ json_devices=$( printf ' %s\n' " ${devices[@]} " | jq -R . | jq -s . | jq ' map(select(length > 0))' )
178+ json_binds=$( printf ' %s\n' " ${binds[@]} " | jq -R . | jq -s . | jq ' map(select(length > 0))' )
179+ json_cap=$( printf ' %s\n' " ${cap_add[@]} " | jq -R . | jq -s . | jq ' map(select(length > 0))' )
180+ json_group=$( printf ' %s\n' " ${group_add[@]} " | jq -R . | jq -s . | jq ' map(select(length > 0))' )
181+
182+ # If Devices array is empty, ensure at least the ID we found is there (unless it was already added)
183+ # Using 'index' to check if device_id is present is tricky with jq on the fly,
184+ # but standardizing on what we found is safer.
185+ # If the detailed logic above didn't populate devices (e.g. unknown vendor), we fall back to just the ID.
186+ if [ " $( echo " $json_devices " | jq length) " -eq 0 ]; then
187+ json_devices=" [\" $device_id \" ]"
66188 fi
67189
68- # Construct JSON
190+
69191 jq -c -n \
70192 --arg desc " $description " \
71193 --arg driver " $driver " \
72- --arg pci_id " $pci_id " \
194+ --arg device_id " $device_id " \
195+ --argjson dev " $json_devices " \
196+ --argjson bind " $json_binds " \
197+ --argjson cap " $json_cap " \
198+ --argjson group " $json_group " \
199+ --argjson sec " $security_opt " \
200+ --argjson shm " $shm_size " \
201+ --argjson ipc " $ipc_mode " \
73202 ' {
74203 description: $desc,
75204 init: {
76205 deviceRequests: {
77206 Driver: (if $driver != "" then $driver else null end),
78- DeviceIDs: [$pci_id]
79- }
207+ Devices: $dev,
208+ Capabilities: [["gpu"]]
209+ },
210+ Binds: $bind,
211+ CapAdd: $cap,
212+ GroupAdd: $group,
213+ SecurityOpt: $sec,
214+ ShmSize: $shm,
215+ IpcMode: $ipc
80216 }
81- }'
217+ } | del(.. | select(. == null)) | del(.. | select(. == [])) '
82218}
83219
84220# Function to get all GPUs in JSON array format
@@ -95,10 +231,16 @@ get_all_gpus_json() {
95231 init: {
96232 deviceRequests: {
97233 Driver: .[0].init.deviceRequests.Driver,
98- DeviceIDs: (map(.init.deviceRequests.DeviceIDs[] ) | unique),
234+ (if .[0].init.deviceRequests.Driver == "nvidia" then " DeviceIDs" else "Devices" end) : (map(.init.deviceRequests.Devices[]? ) | unique),
99235 Capabilities: [["gpu"]]
100- }
101- }
236+ },
237+ Binds: (map(.init.Binds[]?) | unique),
238+ CapAdd: (map(.init.CapAdd[]?) | unique),
239+ GroupAdd: (map(.init.GroupAdd[]?) | unique),
240+ SecurityOpt: .[0].init.SecurityOpt,
241+ ShmSize: .[0].init.ShmSize,
242+ IpcMode: .[0].init.IpcMode
243+ } | del(.. | select(. == null)) | del(.. | select(. == []))
102244 }) | map(if .init.deviceRequests.Driver == null then del(.init.deviceRequests.Driver) else . end)
103245 '
104246}
0 commit comments