From 5543f3fc0e8d67748209a389a388350222c1f57f Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Tue, 22 Apr 2025 09:42:35 -0700 Subject: [PATCH 1/2] [ET-VK][ez] Streamline + fix enabling device extensions Pull Request resolved: https://github.com/pytorch/executorch/pull/10352 ## Changes Simplify the logic to construct a linked list of physical device feature structs When constructing the logical device, construct the linked list of device features to enable instead of using a stored pointer from the constructor of `PhysicalDevice`. The reason is that due to possible moves, the stored pointer may be invalid by the time the `Adapter` instance is constructed. It is safer to reconstruct the device features linked list at the time it is needed. ghstack-source-id: 279563505 @exported-using-ghexport Differential Revision: [D73438721](https://our.internmc.facebook.com/intern/diff/D73438721/) --- backends/vulkan/runtime/vk_api/Adapter.cpp | 28 +++++++++++++++- backends/vulkan/runtime/vk_api/Device.cpp | 37 ++++++++-------------- backends/vulkan/runtime/vk_api/Device.h | 3 -- 3 files changed, 40 insertions(+), 28 deletions(-) diff --git a/backends/vulkan/runtime/vk_api/Adapter.cpp b/backends/vulkan/runtime/vk_api/Adapter.cpp index db6fdc2909a..5b698096bb5 100644 --- a/backends/vulkan/runtime/vk_api/Adapter.cpp +++ b/backends/vulkan/runtime/vk_api/Adapter.cpp @@ -107,7 +107,33 @@ VkDevice create_logical_device( nullptr, // pEnabledFeatures }; - device_create_info.pNext = physical_device.extension_features; + void* extension_list_top = nullptr; + +#ifdef VK_KHR_16bit_storage + VkPhysicalDevice16BitStorageFeatures shader_16bit_storage{ + physical_device.shader_16bit_storage}; + + shader_16bit_storage.pNext = extension_list_top; + extension_list_top = &shader_16bit_storage; +#endif /* VK_KHR_16bit_storage */ + +#ifdef VK_KHR_8bit_storage + VkPhysicalDevice8BitStorageFeatures shader_8bit_storage{ + physical_device.shader_8bit_storage}; + + shader_8bit_storage.pNext = extension_list_top; + extension_list_top = &shader_8bit_storage; +#endif /* VK_KHR_8bit_storage */ + +#ifdef VK_KHR_shader_float16_int8 + VkPhysicalDeviceShaderFloat16Int8Features shader_float16_int8_types{ + physical_device.shader_float16_int8_types}; + + shader_float16_int8_types.pNext = extension_list_top; + extension_list_top = &shader_float16_int8_types; +#endif /* VK_KHR_shader_float16_int8 */ + + device_create_info.pNext = extension_list_top; VkDevice handle = nullptr; VK_CHECK(vkCreateDevice( diff --git a/backends/vulkan/runtime/vk_api/Device.cpp b/backends/vulkan/runtime/vk_api/Device.cpp index c4119e04b78..e87510cd9f8 100644 --- a/backends/vulkan/runtime/vk_api/Device.cpp +++ b/backends/vulkan/runtime/vk_api/Device.cpp @@ -34,7 +34,6 @@ PhysicalDevice::PhysicalDevice(VkPhysicalDevice physical_device_handle) shader_float16_int8_types{ VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_FLOAT16_INT8_FEATURES_KHR}, #endif /* VK_KHR_shader_float16_int8 */ - extension_features{nullptr}, queue_families{}, num_compute_queues(0), supports_int16_shader_types(false), @@ -57,34 +56,24 @@ PhysicalDevice::PhysicalDevice(VkPhysicalDevice physical_device_handle) // Create linked list to query availability of extensions + void* extension_list_top = nullptr; + #ifdef VK_KHR_16bit_storage - extension_features = &shader_16bit_storage; - features2.pNext = &shader_16bit_storage; -#elif defined(VK_KHR_8bit_storage) - extension_features = &shader_8bit_storage; - features2.pNext = &shader_8bit_storage; -#elif defined(VK_KHR_shader_float16_int8) - extension_features = &shader_float16_int8_types; - features2.pNext = &shader_float16_int8_types; + shader_16bit_storage.pNext = extension_list_top; + extension_list_top = &shader_16bit_storage; #endif /* VK_KHR_16bit_storage */ -#if defined(VK_KHR_16bit_storage) && defined(VK_KHR_8bit_storage) - shader_16bit_storage.pNext = &shader_8bit_storage; -#elif defined(VK_KHR_16bit_storage) && defined(VK_KHR_shader_float16_int8) - shader_16bit_storage.pNext = &shader_float16_int8_types; -#elif defined(VK_KHR_16bit_storage) - shader_16bit_storage.pNext = nullptr; -#endif - -#if defined(VK_KHR_8bit_storage) && defined(VK_KHR_shader_float16_int8) - shader_8bit_storage.pNext = &shader_float16_int8_types; -#elif defined(VK_KHR_8bit_storage) - shader_8bit_storage.pNext = nullptr; -#endif +#ifdef VK_KHR_8bit_storage + shader_8bit_storage.pNext = extension_list_top; + extension_list_top = &shader_8bit_storage; +#endif /* VK_KHR_8bit_storage */ #ifdef VK_KHR_shader_float16_int8 - shader_float16_int8_types.pNext = nullptr; -#endif + shader_float16_int8_types.pNext = extension_list_top; + extension_list_top = &shader_float16_int8_types; +#endif /* VK_KHR_shader_float16_int8 */ + + features2.pNext = extension_list_top; vkGetPhysicalDeviceFeatures2(handle, &features2); diff --git a/backends/vulkan/runtime/vk_api/Device.h b/backends/vulkan/runtime/vk_api/Device.h index 70d5b1db5af..6bb17b01223 100644 --- a/backends/vulkan/runtime/vk_api/Device.h +++ b/backends/vulkan/runtime/vk_api/Device.h @@ -37,9 +37,6 @@ struct PhysicalDevice final { VkPhysicalDeviceShaderFloat16Int8Features shader_float16_int8_types; #endif /* VK_KHR_shader_float16_int8 */ - // Head of the linked list of extensions to be requested - void* extension_features; - // Available GPU queues std::vector queue_families; From 15e7111495342366d1bb0833a205164e2cc767de Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Tue, 22 Apr 2025 09:42:36 -0700 Subject: [PATCH 2/2] [ET-VK][ez] Store physical device identity metadata Pull Request resolved: https://github.com/pytorch/executorch/pull/10353 ## Context Lay the groundwork for storing device identity metadata. This will allow operator implementations to select and/or configure compute shaders based on the GPU that is being used. Differential Revision: [D73438723](https://our.internmc.facebook.com/intern/diff/D73438723/) ghstack-source-id: 279563504 --- backends/vulkan/runtime/vk_api/Device.cpp | 24 ++++++++++++++++++++++- backends/vulkan/runtime/vk_api/Device.h | 12 ++++++++++++ 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/backends/vulkan/runtime/vk_api/Device.cpp b/backends/vulkan/runtime/vk_api/Device.cpp index e87510cd9f8..b9e3b444db2 100644 --- a/backends/vulkan/runtime/vk_api/Device.cpp +++ b/backends/vulkan/runtime/vk_api/Device.cpp @@ -12,7 +12,9 @@ #include +#include #include +#include #include namespace vkcompute { @@ -40,7 +42,9 @@ PhysicalDevice::PhysicalDevice(VkPhysicalDevice physical_device_handle) has_unified_memory(false), has_timestamps(false), timestamp_period(0), - min_ubo_alignment(0) { + min_ubo_alignment(0), + device_name{}, + device_type{DeviceType::UNKNOWN} { // Extract physical device properties vkGetPhysicalDeviceProperties(handle, &properties); @@ -107,6 +111,24 @@ PhysicalDevice::PhysicalDevice(VkPhysicalDevice physical_device_handle) num_compute_queues += p.queueCount; } } + + // Obtain device identity metadata + device_name = std::string(properties.deviceName); + std::transform( + device_name.begin(), + device_name.end(), + device_name.begin(), + [](unsigned char c) { return std::tolower(c); }); + + if (device_name.find("adreno") != std::string::npos) { + device_type = DeviceType::ADRENO; + } else if (device_name.find("swiftshader") != std::string::npos) { + device_type = DeviceType::SWIFTSHADER; + } else if (device_name.find("nvidia") != std::string::npos) { + device_type = DeviceType::NVIDIA; + } else if (device_name.find("mali") != std::string::npos) { + device_type = DeviceType::MALI; + } } // diff --git a/backends/vulkan/runtime/vk_api/Device.h b/backends/vulkan/runtime/vk_api/Device.h index 6bb17b01223..3fdfcc04a49 100644 --- a/backends/vulkan/runtime/vk_api/Device.h +++ b/backends/vulkan/runtime/vk_api/Device.h @@ -18,6 +18,14 @@ namespace vkcompute { namespace vkapi { +enum class DeviceType : uint32_t { + UNKNOWN, + NVIDIA, + MALI, + ADRENO, + SWIFTSHADER, +}; + struct PhysicalDevice final { // Handle VkPhysicalDevice handle; @@ -48,6 +56,10 @@ struct PhysicalDevice final { float timestamp_period; size_t min_ubo_alignment; + // Device identity + std::string device_name; + DeviceType device_type; + explicit PhysicalDevice(VkPhysicalDevice); };