@@ -5,6 +5,7 @@ use std::fs::{self};
55use std:: path:: Path ;
66use std:: process:: Command ;
77use std:: collections:: HashMap ;
8+ use utils:: os:: Platform ;
89
910// GPU Configuration - Because every GPU needs its marching orders! 🎮
1011#[ derive( Debug , Serialize , Deserialize , Clone ) ]
@@ -17,12 +18,16 @@ pub struct GPUConfig {
1718#[ derive( Debug , Serialize , Deserialize , Clone ) ]
1819pub struct GPUDevice {
1920 pub id : String , // Every star needs a unique name
20- pub vendor_id : String , // Who's your manufacturer? 🏭
21- pub device_id : String , // Model number - because we're all unique!
22- pub pci_address : String , // Where to find this beauty on the PCI runway
23- pub iommu_group : Option < String > , // The VIP lounge number (if we're fancy enough)
24- pub temperature : f64 , // Temperature of the GPU
25- pub utilization : f64 , // Utilization of the GPU
21+ pub vendor : String , // Who's your manufacturer? 🏭
22+ pub model : String , // Model number - because we're all unique!
23+ pub vram_mb : u64 , // VRAM in MB
24+ pub driver_version : String , // GPU driver version
25+ #[ serde( skip_serializing_if = "Option::is_none" ) ]
26+ pub metal_support : Option < bool > ,
27+ #[ serde( skip_serializing_if = "Option::is_none" ) ]
28+ pub vulkan_support : Option < bool > ,
29+ #[ serde( skip_serializing_if = "Option::is_none" ) ]
30+ pub directx_version : Option < f32 > ,
2631}
2732
2833// The mastermind behind our GPU operations! 🧙♂️
@@ -38,30 +43,85 @@ impl GPUManager {
3843 } )
3944 }
4045
41- // Let's discover what GPUs are hiding in this machine! 🔍
42- pub fn discover_gpus ( & mut self ) -> Result < Vec < GPUDevice > > {
43- let mut devices = Vec :: new ( ) ;
44- let pci_devices = fs:: read_dir ( "/sys/bus/pci/devices" ) ?;
46+ /// Unified GPU detection across platforms
47+ pub fn detect_gpus ( & mut self ) -> Result < ( ) > {
48+ match Platform :: current ( ) {
49+ Platform :: Linux => self . detect_linux_gpus ( ) ,
50+ Platform :: MacOS => self . detect_macos_gpus ( ) ,
51+ Platform :: Windows => self . detect_windows_gpus ( ) ,
52+ _ => Err ( GpuError :: UnsupportedPlatform (
53+ "Unknown platform" . to_string ( )
54+ ) ) ,
55+ }
56+ }
57+
58+ #[ cfg( target_os = "linux" ) ]
59+ fn detect_linux_gpus ( & mut self ) -> Result < ( ) > {
60+ use nvml_wrapper:: Nvml ;
4561
46- for entry in pci_devices {
47- let path = entry?. path ( ) ;
48- let vendor = fs:: read_to_string ( path. join ( "vendor" ) ) ?;
49- let device = fs:: read_to_string ( path. join ( "device" ) ) ?;
50-
51- if is_gpu_device ( & vendor, & device) {
52- let iommu_group = get_iommu_group ( & path) ?;
53- devices. push ( GPUDevice {
54- id : format ! ( "{}:{}" , vendor. trim( ) , device. trim( ) ) ,
55- vendor_id : vendor. trim ( ) . to_string ( ) ,
56- device_id : device. trim ( ) . to_string ( ) ,
57- pci_address : path. file_name ( ) . unwrap ( ) . to_str ( ) . unwrap ( ) . to_string ( ) ,
58- iommu_group,
59- temperature : read_gpu_temperature ( & path) ?,
60- utilization : read_gpu_utilization ( & path) ?,
62+ // NVIDIA detection
63+ if let Ok ( nvml) = Nvml :: init ( ) {
64+ for i in 0 ..nvml. device_count ( ) ? {
65+ let device = nvml. device_by_index ( i) ?;
66+ self . devices . push ( GPUDevice {
67+ id : device. uuid ( ) ?,
68+ vendor : "NVIDIA" . into ( ) ,
69+ model : device. name ( ) ?,
70+ vram_mb : device. memory_info ( ) ?. total / 1024 / 1024 ,
71+ driver_version : nvml. sys_driver_version ( ) ?,
72+ vulkan_support : Some ( true ) ,
73+ ..Default :: default ( )
6174 } ) ;
6275 }
6376 }
64- Ok ( devices)
77+
78+ // AMD detection (using amdgpu driver)
79+ // ... AMD detection logic ...
80+
81+ Ok ( ( ) )
82+ }
83+
84+ #[ cfg( target_os = "macos" ) ]
85+ fn detect_macos_gpus ( & mut self ) -> Result < ( ) > {
86+ use metal:: Device ;
87+
88+ for device in Device :: all ( ) {
89+ self . devices . push ( GPUDevice {
90+ id : device. registry_id ( ) . to_string ( ) ,
91+ vendor : "Apple" . into ( ) ,
92+ model : device. name ( ) . to_string ( ) ,
93+ vram_mb : device. recommended_max_vram ( ) / 1024 / 1024 ,
94+ metal_support : Some ( true ) ,
95+ ..Default :: default ( )
96+ } ) ;
97+ }
98+
99+ Ok ( ( ) )
100+ }
101+
102+ #[ cfg( target_os = "windows" ) ]
103+ fn detect_windows_gpus ( & mut self ) -> Result < ( ) > {
104+ use dxgi:: Factory ;
105+
106+ let factory = Factory :: new ( ) ?;
107+ for adapter in factory. adapters ( ) {
108+ let desc = adapter. get_desc ( ) ?;
109+ self . devices . push ( GPUDevice {
110+ id : format ! ( "PCI\\ VEN_{:04X}&DEV_{:04X}" , desc. vendor_id, desc. device_id) ,
111+ vendor : match desc. vendor_id {
112+ 0x10DE => "NVIDIA" . into ( ) ,
113+ 0x1002 => "AMD" . into ( ) ,
114+ 0x8086 => "Intel" . into ( ) ,
115+ _ => "Unknown" . into ( ) ,
116+ } ,
117+ model : desc. description . to_string ( ) ,
118+ vram_mb : ( desc. dedicated_video_memory / 1024 / 1024 ) as u64 ,
119+ directx_version : Some ( desc. revision as f32 / 10.0 ) ,
120+ ..Default :: default ( )
121+ } ) ;
122+ }
123+
124+ Ok ( ( ) )
65125 }
66126
67127 // Assign those GPUs to their IOMMU groups - like assigning students to classrooms!
@@ -82,7 +142,7 @@ impl GPUManager {
82142 // Match GPUs with their corresponding IOMMU groups
83143 for gpu in & mut self . devices {
84144 // Find IOMMU group from GPU's PCI address
85- let pci_path = Path :: new ( "/sys/bus/pci/devices" ) . join ( & gpu. pci_address ) ;
145+ let pci_path = Path :: new ( "/sys/bus/pci/devices" ) . join ( & gpu. id ) ;
86146 if let Some ( group) = get_iommu_group ( & pci_path) ? {
87147 // Collect all devices in the group
88148 let group_devices = iommu_groups. get ( & group)
@@ -209,7 +269,7 @@ fn get_gpu_info() -> Result<Vec<GPUDevice>> {
209269 for display in CGDisplay :: active_displays ( ) ? {
210270 gpus. push ( GPUDevice {
211271 id : format ! ( "display-{}" , display) ,
212- vendor_id : "Apple" . into ( ) ,
272+ vendor : "Apple" . into ( ) ,
213273 // MacOS specific GPU info
214274 } ) ;
215275 }
@@ -225,7 +285,7 @@ fn get_gpu_info() -> Result<Vec<GPUDevice>> {
225285 for adapter in factory. adapters ( ) {
226286 gpus. push ( GPUDevice {
227287 id : adapter. get_info ( ) . name ,
228- vendor_id : "NVIDIA/AMD/Intel" . into ( ) ,
288+ vendor : "NVIDIA/AMD/Intel" . into ( ) ,
229289 // Windows specific data
230290 } ) ;
231291 }
0 commit comments