Skip to content

Commit b0f9a05

Browse files
committed
work on this
1 parent cb59ba9 commit b0f9a05

File tree

3 files changed

+367
-146
lines changed

3 files changed

+367
-146
lines changed

src/device.rs

Lines changed: 167 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1-
use ash::{
2-
vk::{self},
3-
Instance,
4-
};
1+
use ash::vk;
2+
use ash::Instance;
3+
use std::ffi::CStr;
54

65
use crate::vendor::Vendor;
76

7+
/// Represents a physical GPU device.
8+
#[derive(Debug)]
89
pub struct PhysicalDevice {
910
pub vendor: Vendor,
1011
pub device_name: String,
@@ -20,58 +21,73 @@ pub struct PhysicalDevice {
2021
pub characteristics: GPUCharacteristics,
2122
}
2223

24+
/// Contains various characteristics of a GPU.
25+
/// Vendor-specific properties are stored as `Option`s and stay `None` for vendors they do not apply to.
26+
/// Also includes some general device limits.
27+
#[derive(Debug)]
2328
pub struct GPUCharacteristics {
29+
/// Memory pressure as computed from VRAM usage (0.0 to 1.0); `NaN` when the reported heap budget is 0.
2430
pub memory_pressure: f32,
25-
pub compute_units: u32,
26-
pub shader_engines: u32,
27-
pub shader_arrays_per_engine_count: u32,
28-
pub compute_units_per_shader_array: u32,
29-
pub simd_per_compute_unit: u32,
30-
pub wavefronts_per_simd: u32,
31-
pub wavefront_size: u32,
32-
// Nvidia specific
31+
// AMD-specific properties.
32+
pub compute_units: Option<u32>,
33+
pub shader_engines: Option<u32>,
34+
pub shader_arrays_per_engine_count: Option<u32>,
35+
pub compute_units_per_shader_array: Option<u32>,
36+
pub simd_per_compute_unit: Option<u32>,
37+
pub wavefronts_per_simd: Option<u32>,
38+
pub wavefront_size: Option<u32>,
39+
// NVIDIA-specific properties.
3340
pub streaming_multiprocessors: Option<u32>,
3441
pub warps_per_sm: Option<u32>,
42+
// General device limits (useful for performance and capability queries).
43+
pub max_image_dimension_2d: u32,
44+
pub max_compute_shared_memory_size: u32,
45+
pub max_compute_work_group_invocations: u32,
3546
}
3647

3748
impl PhysicalDevice {
49+
/// Constructs a new `PhysicalDevice` by querying Vulkan properties; panics if the device's vendor ID is not recognized.
3850
pub fn new(instance: &Instance, physical_device: vk::PhysicalDevice) -> Self {
51+
// Get the core properties and limits.
3952
let physical_device_properties: vk::PhysicalDeviceProperties =
4053
unsafe { instance.get_physical_device_properties(physical_device) };
54+
let limits = physical_device_properties.limits;
4155

56+
// Query additional driver properties.
4257
let mut driver_properties: vk::PhysicalDeviceDriverProperties =
4358
vk::PhysicalDeviceDriverProperties::default();
44-
4559
let mut properties2: vk::PhysicalDeviceProperties2 =
4660
vk::PhysicalDeviceProperties2::default().push_next(&mut driver_properties);
47-
4861
unsafe {
4962
instance.get_physical_device_properties2(physical_device, &mut properties2);
50-
};
63+
}
5164

5265
let vendor_id = physical_device_properties.vendor_id;
66+
let vendor = Vendor::from_vendor_id(vendor_id).unwrap_or_else(|| {
67+
eprintln!("Unknown vendor: {}", vendor_id);
68+
panic!();
69+
});
5370

54-
let vendor = match Vendor::from_vendor_id(vendor_id) {
55-
Some(v) => v,
56-
None => {
57-
eprintln!("Unknown vendor: {}", vendor_id);
58-
panic!();
59-
}
60-
};
61-
62-
let device_name =
63-
cstring_to_string(physical_device_properties.device_name_as_c_str().unwrap());
64-
71+
let device_name = cstring_to_string(
72+
physical_device_properties
73+
.device_name_as_c_str()
74+
.unwrap_or_else(|_| CStr::from_bytes_with_nul(b"Unknown\0").unwrap()),
75+
);
6576
let device_type = DeviceType::from(physical_device_properties.device_type.as_raw());
66-
6777
let device_id = physical_device_properties.device_id;
68-
6978
let api_version = decode_version_number(physical_device_properties.api_version);
79+
let driver_name = cstring_to_string(
80+
driver_properties
81+
.driver_name_as_c_str()
82+
.unwrap_or_else(|_| CStr::from_bytes_with_nul(b"Unknown\0").unwrap()),
83+
);
84+
let driver_info = cstring_to_string(
85+
driver_properties
86+
.driver_info_as_c_str()
87+
.unwrap_or_else(|_| CStr::from_bytes_with_nul(b"Unknown\0").unwrap()),
88+
);
7089

71-
let driver_name = cstring_to_string(driver_properties.driver_name_as_c_str().unwrap());
72-
73-
let driver_info = cstring_to_string(driver_properties.driver_info_as_c_str().unwrap());
74-
90+
// Query VRAM details.
7591
let mut memory_budget = vk::PhysicalDeviceMemoryBudgetPropertiesEXT::default();
7692
let mut memory_properties2 =
7793
vk::PhysicalDeviceMemoryProperties2::default().push_next(&mut memory_budget);
@@ -80,17 +96,13 @@ impl PhysicalDevice {
8096
.get_physical_device_memory_properties2(physical_device, &mut memory_properties2);
8197
}
8298
let memory_properties = memory_properties2.memory_properties;
83-
84-
// Determine VRAM heap index (first DEVICE_LOCAL heap)
8599
let vram_heap_index = (0..memory_properties.memory_heap_count)
86100
.find(|&i| {
87101
memory_properties.memory_heaps[i as usize]
88102
.flags
89103
.contains(vk::MemoryHeapFlags::DEVICE_LOCAL)
90104
})
91105
.unwrap_or(0);
92-
93-
// Compute heapsize, budget, and memory pressure
94106
let heapsize = memory_properties.memory_heaps[vram_heap_index as usize].size;
95107
let heapbudget = memory_budget.heap_budget[vram_heap_index as usize];
96108
let memory_pressure = if heapbudget > 0 {
@@ -99,8 +111,27 @@ impl PhysicalDevice {
99111
f32::NAN
100112
};
101113

102-
// Get vendor-specific characteristics.
103-
let characteristics = match vendor {
114+
// Initialize common GPUCharacteristics fields.
115+
let mut characteristics = GPUCharacteristics {
116+
memory_pressure,
117+
// Vendor-specific fields start as None.
118+
compute_units: None,
119+
shader_engines: None,
120+
shader_arrays_per_engine_count: None,
121+
compute_units_per_shader_array: None,
122+
simd_per_compute_unit: None,
123+
wavefronts_per_simd: None,
124+
wavefront_size: None,
125+
streaming_multiprocessors: None,
126+
warps_per_sm: None,
127+
// General limits:
128+
max_image_dimension_2d: limits.max_image_dimension2_d,
129+
max_compute_shared_memory_size: limits.max_compute_shared_memory_size,
130+
max_compute_work_group_invocations: limits.max_compute_work_group_invocations,
131+
};
132+
133+
// Query vendor-specific properties.
134+
match vendor {
104135
Vendor::AMD => {
105136
let mut shader_core_properties =
106137
vk::PhysicalDeviceShaderCorePropertiesAMD::default();
@@ -112,22 +143,21 @@ impl PhysicalDevice {
112143
unsafe {
113144
instance.get_physical_device_properties2(physical_device, &mut amd_properties2);
114145
}
115-
GPUCharacteristics {
116-
memory_pressure,
117-
compute_units: shader_core_properties.shader_engine_count
146+
characteristics.compute_units = Some(
147+
shader_core_properties.shader_engine_count
118148
* shader_core_properties.shader_arrays_per_engine_count
119149
* shader_core_properties.compute_units_per_shader_array,
120-
shader_engines: shader_core_properties.shader_engine_count,
121-
shader_arrays_per_engine_count: shader_core_properties
122-
.shader_arrays_per_engine_count,
123-
compute_units_per_shader_array: shader_core_properties
124-
.compute_units_per_shader_array,
125-
simd_per_compute_unit: shader_core_properties.simd_per_compute_unit,
126-
wavefronts_per_simd: shader_core_properties.wavefronts_per_simd,
127-
wavefront_size: shader_core_properties.wavefront_size,
128-
streaming_multiprocessors: None,
129-
warps_per_sm: None,
130-
}
150+
);
151+
characteristics.shader_engines = Some(shader_core_properties.shader_engine_count);
152+
characteristics.shader_arrays_per_engine_count =
153+
Some(shader_core_properties.shader_arrays_per_engine_count);
154+
characteristics.compute_units_per_shader_array =
155+
Some(shader_core_properties.compute_units_per_shader_array);
156+
characteristics.simd_per_compute_unit =
157+
Some(shader_core_properties.simd_per_compute_unit);
158+
characteristics.wavefronts_per_simd =
159+
Some(shader_core_properties.wavefronts_per_simd);
160+
characteristics.wavefront_size = Some(shader_core_properties.wavefront_size);
131161
}
132162
Vendor::Nvidia => {
133163
let mut sm_builtins = vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV::default();
@@ -136,32 +166,12 @@ impl PhysicalDevice {
136166
unsafe {
137167
instance.get_physical_device_properties2(physical_device, &mut nv_properties2);
138168
}
139-
GPUCharacteristics {
140-
memory_pressure,
141-
// For NVIDIA, AMD-specific values are not applicable.
142-
compute_units: 0,
143-
shader_engines: 0,
144-
shader_arrays_per_engine_count: 0,
145-
compute_units_per_shader_array: 0,
146-
simd_per_compute_unit: 0,
147-
wavefronts_per_simd: 0,
148-
wavefront_size: 0,
149-
streaming_multiprocessors: Some(sm_builtins.shader_sm_count),
150-
warps_per_sm: Some(sm_builtins.shader_warps_per_sm),
151-
}
169+
characteristics.streaming_multiprocessors = Some(sm_builtins.shader_sm_count);
170+
characteristics.warps_per_sm = Some(sm_builtins.shader_warps_per_sm);
171+
}
172+
_ => {
173+
// For other vendors, vendor-specific fields remain None.
152174
}
153-
_ => GPUCharacteristics {
154-
memory_pressure,
155-
compute_units: 0,
156-
shader_engines: 0,
157-
shader_arrays_per_engine_count: 0,
158-
compute_units_per_shader_array: 0,
159-
simd_per_compute_unit: 0,
160-
wavefronts_per_simd: 0,
161-
wavefront_size: 0,
162-
streaming_multiprocessors: None,
163-
warps_per_sm: None,
164-
},
165175
};
166176

167177
PhysicalDevice {
@@ -180,6 +190,8 @@ impl PhysicalDevice {
180190
}
181191
}
182192

193+
/// Represents the type of device.
194+
#[derive(Debug)]
183195
pub enum DeviceType {
184196
Other = 0,
185197
IntegratedGPU = 1,
@@ -190,6 +202,7 @@ pub enum DeviceType {
190202
}
191203

192204
impl DeviceType {
205+
/// Converts a raw Vulkan physical-device type value (e.g. from `device_type.as_raw()`) into a `DeviceType`.
193206
pub fn from(id: i32) -> Self {
194207
match id {
195208
0 => DeviceType::Other,
@@ -201,6 +214,7 @@ impl DeviceType {
201214
}
202215
}
203216

217+
/// Returns a human-readable name for this device type.
204218
pub fn name(&self) -> &'static str {
205219
match self {
206220
DeviceType::Other => "Other",
@@ -213,12 +227,7 @@ impl DeviceType {
213227
}
214228
}
215229

216-
/*
217-
The variant is a 3-bit integer packed into bits 31-29.
218-
The major version is a 7-bit integer packed into bits 28-22.
219-
The minor version number is a 10-bit integer packed into bits 21-12.
220-
The patch version number is a 12-bit integer packed into bits 11-0.
221-
*/
230+
/// Decodes a packed Vulkan version number (variant in bits 31-29, major in bits 28-22, minor in bits 21-12, patch in bits 11-0) into a string of the form "variant.major.minor.patch".
222231
pub fn decode_version_number(version: u32) -> String {
223232
let variant = (version >> 29) & 0b111;
224233
let major = (version >> 22) & 0b1111111;
@@ -227,6 +236,79 @@ pub fn decode_version_number(version: u32) -> String {
227236
format!("{}.{}.{}.{}", variant, major, minor, patch)
228237
}
229238

230-
pub fn cstring_to_string(cstr: &std::ffi::CStr) -> String {
231-
cstr.to_string_lossy().to_string()
239+
/// Converts a `CStr` into an owned `String`, replacing any invalid UTF-8 sequences with U+FFFD.
240+
pub fn cstring_to_string(cstr: &CStr) -> String {
241+
cstr.to_string_lossy().into_owned()
242+
}
243+
244+
#[cfg(test)]
245+
mod tests {
246+
use super::*;
247+
use ash::vk;
248+
use std::ffi::CString;
249+
250+
// Helper that builds a `CString` from a string slice; panics if `s` contains an interior NUL byte.
251+
fn dummy_cstr(s: &str) -> CString {
252+
CString::new(s).unwrap()
253+
}
254+
255+
#[test]
256+
fn test_decode_version_number() {
257+
// Simulate a Vulkan version: variant 0, version 1.2.3
258+
let version: u32 = (0 << 29) | (1 << 22) | (2 << 12) | 3;
259+
let decoded = decode_version_number(version);
260+
assert_eq!(decoded, "0.1.2.3");
261+
}
262+
263+
#[test]
264+
fn test_cstring_to_string() {
265+
let original = "Hello, world!";
266+
let cstr = dummy_cstr(original);
267+
let s = cstring_to_string(cstr.as_c_str());
268+
assert_eq!(s, original);
269+
}
270+
271+
#[test]
272+
fn test_device_type_from() {
273+
assert_eq!(DeviceType::from(0).name(), "Other");
274+
assert_eq!(DeviceType::from(1).name(), "Integrated GPU");
275+
assert_eq!(DeviceType::from(2).name(), "Discrete GPU");
276+
assert_eq!(DeviceType::from(3).name(), "Virtual GPU");
277+
assert_eq!(DeviceType::from(4).name(), "CPU");
278+
assert_eq!(DeviceType::from(99).name(), "Unknown");
279+
}
280+
281+
#[test]
282+
fn test_gpu_characteristics_defaults() {
283+
// Create dummy limits.
284+
let limits = vk::PhysicalDeviceLimits {
285+
max_image_dimension2_d: 8192,
286+
max_compute_shared_memory_size: 16384,
287+
max_compute_work_group_invocations: 1024,
288+
// Other fields can use defaults.
289+
..Default::default()
290+
};
291+
292+
// Construct dummy GPUCharacteristics with only common limits.
293+
let characteristics = GPUCharacteristics {
294+
memory_pressure: 0.5,
295+
compute_units: None,
296+
shader_engines: None,
297+
shader_arrays_per_engine_count: None,
298+
compute_units_per_shader_array: None,
299+
simd_per_compute_unit: None,
300+
wavefronts_per_simd: None,
301+
wavefront_size: None,
302+
streaming_multiprocessors: None,
303+
warps_per_sm: None,
304+
max_image_dimension_2d: limits.max_image_dimension2_d,
305+
max_compute_shared_memory_size: limits.max_compute_shared_memory_size,
306+
max_compute_work_group_invocations: limits.max_compute_work_group_invocations,
307+
};
308+
309+
assert_eq!(characteristics.max_image_dimension_2d, 8192);
310+
assert_eq!(characteristics.max_compute_shared_memory_size, 16384);
311+
assert_eq!(characteristics.max_compute_work_group_invocations, 1024);
312+
assert!(characteristics.compute_units.is_none());
313+
}
232314
}

0 commit comments

Comments
 (0)