diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp index fff530d57cb..f40a6b0f286 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.cpp +++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp @@ -155,6 +155,11 @@ ComputeGraph::ComputeGraph(GraphConfig config) config_.execute_threshold_node_count = 128; config_.execute_initial_threshold_node_count = 64; } + + // Check if the underlying GPU can access accelerated integer dot product + // instructions + can_use_int8_dot_product_ = + context_->adapter_ptr()->supports_int8_dot_product(); } ComputeGraph::~ComputeGraph() { diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index 4257f63fab6..78fb79e65e8 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -221,6 +221,10 @@ class ComputeGraph final { // config.execute_threshold_node_count. size_t execute_threshold_node_count_ = 0; + // Whether the underlying GPU supports accelerated integer dot product + // extensions + bool can_use_int8_dot_product_ = false; + public: // // Accessors @@ -1013,6 +1017,10 @@ class ComputeGraph final { return execute_count_; } + inline bool can_use_int8_dot_product() const { + return can_use_int8_dot_product_; + } + /* * Check whether the GPU supports 8 bit buffers. 
*/ diff --git a/backends/vulkan/runtime/vk_api/Adapter.cpp b/backends/vulkan/runtime/vk_api/Adapter.cpp index e08491c656b..5f939f564a3 100644 --- a/backends/vulkan/runtime/vk_api/Adapter.cpp +++ b/backends/vulkan/runtime/vk_api/Adapter.cpp @@ -160,6 +160,14 @@ VkDevice create_logical_device( extension_list_top = &shader_float16_int8_types; #endif /* VK_KHR_shader_float16_int8 */ +#ifdef VK_KHR_shader_integer_dot_product + VkPhysicalDeviceShaderIntegerDotProductFeaturesKHR + shader_int_dot_product_features{ + physical_device.shader_int_dot_product_features}; + shader_int_dot_product_features.pNext = extension_list_top; + extension_list_top = &shader_int_dot_product_features; +#endif /* VK_KHR_shader_integer_dot_product */ + device_create_info.pNext = extension_list_top; VkDevice handle = nullptr; @@ -401,6 +409,107 @@ std::string Adapter::stringize() const { #endif /* VK_KHR_shader_float16_int8 */ ss << " }" << std::endl; +#ifdef VK_KHR_shader_integer_dot_product + ss << " Shader Integer Dot Product Features {" << std::endl; + PRINT_PROP( + physical_device_.shader_int_dot_product_features, + shaderIntegerDotProduct); + ss << " }" << std::endl; + + ss << " Shader Integer Dot Product Properties {" << std::endl; + PRINT_PROP( + physical_device_.shader_int_dot_product_properties, + integerDotProduct8BitUnsignedAccelerated); + PRINT_PROP( + physical_device_.shader_int_dot_product_properties, + integerDotProduct8BitSignedAccelerated); + PRINT_PROP( + physical_device_.shader_int_dot_product_properties, + integerDotProduct8BitMixedSignednessAccelerated); + PRINT_PROP( + physical_device_.shader_int_dot_product_properties, + integerDotProduct4x8BitPackedUnsignedAccelerated); + PRINT_PROP( + physical_device_.shader_int_dot_product_properties, + integerDotProduct4x8BitPackedSignedAccelerated); + PRINT_PROP( + physical_device_.shader_int_dot_product_properties, + integerDotProduct4x8BitPackedMixedSignednessAccelerated); + PRINT_PROP( + 
physical_device_.shader_int_dot_product_properties, + integerDotProduct16BitUnsignedAccelerated); + PRINT_PROP( + physical_device_.shader_int_dot_product_properties, + integerDotProduct16BitSignedAccelerated); + PRINT_PROP( + physical_device_.shader_int_dot_product_properties, + integerDotProduct16BitMixedSignednessAccelerated); + PRINT_PROP( + physical_device_.shader_int_dot_product_properties, + integerDotProduct32BitUnsignedAccelerated); + PRINT_PROP( + physical_device_.shader_int_dot_product_properties, + integerDotProduct32BitSignedAccelerated); + PRINT_PROP( + physical_device_.shader_int_dot_product_properties, + integerDotProduct32BitMixedSignednessAccelerated); + PRINT_PROP( + physical_device_.shader_int_dot_product_properties, + integerDotProduct64BitUnsignedAccelerated); + PRINT_PROP( + physical_device_.shader_int_dot_product_properties, + integerDotProduct64BitSignedAccelerated); + PRINT_PROP( + physical_device_.shader_int_dot_product_properties, + integerDotProduct64BitMixedSignednessAccelerated); + PRINT_PROP( + physical_device_.shader_int_dot_product_properties, + integerDotProductAccumulatingSaturating8BitUnsignedAccelerated); + PRINT_PROP( + physical_device_.shader_int_dot_product_properties, + integerDotProductAccumulatingSaturating8BitSignedAccelerated); + PRINT_PROP( + physical_device_.shader_int_dot_product_properties, + integerDotProductAccumulatingSaturating8BitMixedSignednessAccelerated); + PRINT_PROP( + physical_device_.shader_int_dot_product_properties, + integerDotProductAccumulatingSaturating4x8BitPackedUnsignedAccelerated); + PRINT_PROP( + physical_device_.shader_int_dot_product_properties, + integerDotProductAccumulatingSaturating4x8BitPackedSignedAccelerated); + PRINT_PROP( + physical_device_.shader_int_dot_product_properties, + integerDotProductAccumulatingSaturating4x8BitPackedMixedSignednessAccelerated); + PRINT_PROP( + physical_device_.shader_int_dot_product_properties, + 
integerDotProductAccumulatingSaturating16BitUnsignedAccelerated); + PRINT_PROP( + physical_device_.shader_int_dot_product_properties, + integerDotProductAccumulatingSaturating16BitSignedAccelerated); + PRINT_PROP( + physical_device_.shader_int_dot_product_properties, + integerDotProductAccumulatingSaturating16BitMixedSignednessAccelerated); + PRINT_PROP( + physical_device_.shader_int_dot_product_properties, + integerDotProductAccumulatingSaturating32BitUnsignedAccelerated); + PRINT_PROP( + physical_device_.shader_int_dot_product_properties, + integerDotProductAccumulatingSaturating32BitSignedAccelerated); + PRINT_PROP( + physical_device_.shader_int_dot_product_properties, + integerDotProductAccumulatingSaturating32BitMixedSignednessAccelerated); + PRINT_PROP( + physical_device_.shader_int_dot_product_properties, + integerDotProductAccumulatingSaturating64BitUnsignedAccelerated); + PRINT_PROP( + physical_device_.shader_int_dot_product_properties, + integerDotProductAccumulatingSaturating64BitSignedAccelerated); + PRINT_PROP( + physical_device_.shader_int_dot_product_properties, + integerDotProductAccumulatingSaturating64BitMixedSignednessAccelerated); + ss << " }" << std::endl; +#endif /* VK_KHR_shader_integer_dot_product */ + const VkPhysicalDeviceMemoryProperties& mem_props = physical_device_.memory_properties; diff --git a/backends/vulkan/runtime/vk_api/Adapter.h b/backends/vulkan/runtime/vk_api/Adapter.h index aa4c659c6d8..6a68b487348 100644 --- a/backends/vulkan/runtime/vk_api/Adapter.h +++ b/backends/vulkan/runtime/vk_api/Adapter.h @@ -212,6 +212,15 @@ class Adapter final { #endif /* VK_KHR_shader_float16_int8 */ } + inline bool supports_int8_dot_product() { /* NOTE(review): returns only the shaderIntegerDotProduct feature bit; none of the ...Accelerated property fields are consulted here, although callers describe the result as "accelerated" support -- confirm this is intended */ +#ifdef VK_KHR_shader_integer_dot_product + return 
physical_device_.shader_int_dot_product_features + .shaderIntegerDotProduct == VK_TRUE; +#else + return false; +#endif /* VK_KHR_shader_integer_dot_product */ + } + inline bool supports_int16_shader_types() { return 
physical_device_.supports_int16_shader_types; } diff --git a/backends/vulkan/runtime/vk_api/Device.cpp b/backends/vulkan/runtime/vk_api/Device.cpp index b9e3b444db2..a21130f1231 100644 --- a/backends/vulkan/runtime/vk_api/Device.cpp +++ b/backends/vulkan/runtime/vk_api/Device.cpp @@ -36,6 +36,12 @@ PhysicalDevice::PhysicalDevice(VkPhysicalDevice physical_device_handle) shader_float16_int8_types{ VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_FLOAT16_INT8_FEATURES_KHR}, #endif /* VK_KHR_shader_float16_int8 */ +#ifdef VK_KHR_shader_integer_dot_product + shader_int_dot_product_features{ + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_FEATURES_KHR}, + shader_int_dot_product_properties{ + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_PROPERTIES_KHR}, +#endif /* VK_KHR_shader_integer_dot_product */ queue_families{}, num_compute_queues(0), supports_int16_shader_types(false), @@ -77,6 +83,13 @@ PhysicalDevice::PhysicalDevice(VkPhysicalDevice physical_device_handle) extension_list_top = &shader_float16_int8_types; #endif /* VK_KHR_shader_float16_int8 */ +#ifdef VK_KHR_shader_integer_dot_product + shader_int_dot_product_features.pNext = extension_list_top; + extension_list_top = &shader_int_dot_product_features; + shader_int_dot_product_properties.pNext = extension_list_top; + extension_list_top = &shader_int_dot_product_properties; /* NOTE(review): this links a properties struct into the chain handed to vkGetPhysicalDeviceFeatures2 below; Vulkan properties structs are normally filled through vkGetPhysicalDeviceProperties2 -- confirm the ...Accelerated fields are actually populated */ +#endif /* VK_KHR_shader_integer_dot_product */ + features2.pNext = extension_list_top; vkGetPhysicalDeviceFeatures2(handle, &features2); diff --git a/backends/vulkan/runtime/vk_api/Device.h b/backends/vulkan/runtime/vk_api/Device.h index 3fdfcc04a49..f5b7154d260 100644 --- a/backends/vulkan/runtime/vk_api/Device.h +++ b/backends/vulkan/runtime/vk_api/Device.h @@ -44,6 
+44,12 @@ struct PhysicalDevice final { #ifdef VK_KHR_shader_float16_int8 VkPhysicalDeviceShaderFloat16Int8Features shader_float16_int8_types; #endif /* VK_KHR_shader_float16_int8 */ +#ifdef VK_KHR_shader_integer_dot_product + VkPhysicalDeviceShaderIntegerDotProductFeatures + 
shader_int_dot_product_features; + VkPhysicalDeviceShaderIntegerDotProductProperties + shader_int_dot_product_properties; +#endif /* VK_KHR_shader_integer_dot_product */ // Available GPU queues std::vector queue_families; diff --git a/backends/vulkan/runtime/vk_api/QueryPool.cpp b/backends/vulkan/runtime/vk_api/QueryPool.cpp index 2f6d433b887..e8b3ca55206 100644 --- a/backends/vulkan/runtime/vk_api/QueryPool.cpp +++ b/backends/vulkan/runtime/vk_api/QueryPool.cpp @@ -209,7 +209,7 @@ std::string QueryPool::generate_string_report() { std::stringstream ss; - int kernel_name_w = 40; + int kernel_name_w = 120; int global_size_w = 25; int local_size_w = 25; int duration_w = 25;