Skip to content

Commit ff972c4

Browse files
authored
[ET-VK][ez] Add ability to check for dot product extension support + upgrade glslc (#13834)
## Motivation Prepare for shaders that will use accelerated int8 dot product GLSL extensions, i.e. `dotPacked4x8AccSatEXT` ## Changes * Query for support for the shader integer dot product extension when creating the VkPhysicalDevice * Request the shader integer dot product extension when creating VkDevice * Provide APIs to check if the extension is available in the current runtime. Differential Revision: [D81323427](https://our.internmc.facebook.com/intern/diff/D81323427/) [ghstack-poisoned]
1 parent ebdd12d commit ff972c4

File tree

14 files changed

+176
-5
lines changed

14 files changed

+176
-5
lines changed

.ci/scripts/setup-vulkan-linux-deps.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ install_vulkan_sdk() {
4343
export PATH="${PATH}:${_vulkan_sdk_dir}/${VULKAN_SDK_VERSION}/x86_64/bin/"
4444
}
4545

46-
VULKAN_SDK_VERSION="1.3.296.0"
46+
VULKAN_SDK_VERSION="1.4.321.1"
4747

4848
install_swiftshader
4949
install_vulkan_sdk "${VULKAN_SDK_VERSION}"

backends/vulkan/runtime/api/Context.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,12 @@ void Context::check_device_capabilities(const vkapi::ShaderInfo& shader) {
111111
shader.kernel_name, vkapi::VulkanExtension::INT8_STORAGE);
112112
}
113113
}
114+
if (shader.requires_integer_dot_product) {
115+
if (!adapter_p_->supports_int8_dot_product()) {
116+
throw vkapi::ShaderNotSupportedError(
117+
shader.kernel_name, vkapi::VulkanExtension::INTEGER_DOT_PRODUCT);
118+
}
119+
}
114120
}
115121

116122
vkapi::DescriptorSet Context::get_descriptor_set(

backends/vulkan/runtime/gen_vulkan_spv.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1103,6 +1103,7 @@ class ShaderInfo:
11031103
requires_shader_int16_ext: bool = False
11041104
requires_16bit_storage_ext: bool = False
11051105
requires_8bit_storage_ext: bool = False
1106+
requires_integer_dot_product_ext: bool = False
11061107

11071108

11081109
def getName(filePath: str) -> str:
@@ -1213,6 +1214,8 @@ def getShaderInfo(srcFilePath: str) -> ShaderInfo:
12131214
shader_info.requires_16bit_storage_ext = True
12141215
if "GL_EXT_shader_8bit_storage" in line:
12151216
shader_info.requires_8bit_storage_ext = True
1217+
if "GL_EXT_integer_dot_product" in line:
1218+
shader_info.requires_integer_dot_product_ext = True
12161219

12171220
return shader_info
12181221

@@ -1288,6 +1291,7 @@ def to_cpp_str(val: bool):
12881291
to_cpp_str(shader_info.requires_shader_int16_ext),
12891292
to_cpp_str(shader_info.requires_16bit_storage_ext),
12901293
to_cpp_str(shader_info.requires_8bit_storage_ext),
1294+
to_cpp_str(shader_info.requires_integer_dot_product_ext),
12911295
]
12921296

12931297
shader_info_str = textwrap.indent(

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,11 @@ ComputeGraph::ComputeGraph(GraphConfig config)
155155
config_.execute_threshold_node_count = 128;
156156
config_.execute_initial_threshold_node_count = 64;
157157
}
158+
159+
// Check if the underlying GPU can access accelerated integer dot product
160+
// instructions
161+
can_use_int8_dot_product_ =
162+
context_->adapter_ptr()->supports_int8_dot_product();
158163
}
159164

160165
ComputeGraph::~ComputeGraph() {

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,10 @@ class ComputeGraph final {
221221
// config.execute_threshold_node_count.
222222
size_t execute_threshold_node_count_ = 0;
223223

224+
// Whether the underlying GPU support accelerated integer dot product
225+
// extensions
226+
bool can_use_int8_dot_product_ = false;
227+
224228
public:
225229
//
226230
// Accessors
@@ -1013,6 +1017,10 @@ class ComputeGraph final {
10131017
return execute_count_;
10141018
}
10151019

1020+
inline bool can_use_int8_dot_product() const {
1021+
return can_use_int8_dot_product_;
1022+
}
1023+
10161024
/*
10171025
* Check whether the GPU supports 8 bit buffers.
10181026
*/

backends/vulkan/runtime/vk_api/Adapter.cpp

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,9 @@ VkDevice create_logical_device(
109109
#ifdef VK_KHR_shader_float16_int8
110110
VK_KHR_SHADER_FLOAT16_INT8_EXTENSION_NAME,
111111
#endif /* VK_KHR_shader_float16_int8 */
112+
#ifdef VK_KHR_shader_integer_dot_product
113+
VK_KHR_SHADER_INTEGER_DOT_PRODUCT_EXTENSION_NAME,
114+
#endif /* VK_KHR_shader_integer_dot_product */
112115
#if defined(VK_KHR_pipeline_executable_properties) && defined(VULKAN_DEBUG)
113116
VK_KHR_PIPELINE_EXECUTABLE_PROPERTIES_EXTENSION_NAME,
114117
#endif /* VK_KHR_pipeline_executable_properties */
@@ -160,6 +163,14 @@ VkDevice create_logical_device(
160163
extension_list_top = &shader_float16_int8_types;
161164
#endif /* VK_KHR_shader_float16_int8 */
162165

166+
#ifdef VK_KHR_shader_integer_dot_product
167+
VkPhysicalDeviceShaderIntegerDotProductFeaturesKHR
168+
shader_int_dot_product_features{
169+
physical_device.shader_int_dot_product_features};
170+
shader_int_dot_product_features.pNext = extension_list_top;
171+
extension_list_top = &shader_int_dot_product_features;
172+
#endif /* VK_KHR_shader_integer_dot_product */
173+
163174
device_create_info.pNext = extension_list_top;
164175

165176
VkDevice handle = nullptr;
@@ -401,6 +412,107 @@ std::string Adapter::stringize() const {
401412
#endif /* VK_KHR_shader_float16_int8 */
402413
ss << " }" << std::endl;
403414

415+
#ifdef VK_KHR_shader_integer_dot_product
416+
ss << " Shader Integer Dot Product Features {" << std::endl;
417+
PRINT_PROP(
418+
physical_device_.shader_int_dot_product_features,
419+
shaderIntegerDotProduct);
420+
ss << " }" << std::endl;
421+
422+
ss << " Shader Integer Dot Product Properties {" << std::endl;
423+
PRINT_PROP(
424+
physical_device_.shader_int_dot_product_properties,
425+
integerDotProduct8BitUnsignedAccelerated);
426+
PRINT_PROP(
427+
physical_device_.shader_int_dot_product_properties,
428+
integerDotProduct8BitSignedAccelerated);
429+
PRINT_PROP(
430+
physical_device_.shader_int_dot_product_properties,
431+
integerDotProduct8BitMixedSignednessAccelerated);
432+
PRINT_PROP(
433+
physical_device_.shader_int_dot_product_properties,
434+
integerDotProduct4x8BitPackedUnsignedAccelerated);
435+
PRINT_PROP(
436+
physical_device_.shader_int_dot_product_properties,
437+
integerDotProduct4x8BitPackedSignedAccelerated);
438+
PRINT_PROP(
439+
physical_device_.shader_int_dot_product_properties,
440+
integerDotProduct4x8BitPackedMixedSignednessAccelerated);
441+
PRINT_PROP(
442+
physical_device_.shader_int_dot_product_properties,
443+
integerDotProduct16BitUnsignedAccelerated);
444+
PRINT_PROP(
445+
physical_device_.shader_int_dot_product_properties,
446+
integerDotProduct16BitSignedAccelerated);
447+
PRINT_PROP(
448+
physical_device_.shader_int_dot_product_properties,
449+
integerDotProduct16BitMixedSignednessAccelerated);
450+
PRINT_PROP(
451+
physical_device_.shader_int_dot_product_properties,
452+
integerDotProduct32BitUnsignedAccelerated);
453+
PRINT_PROP(
454+
physical_device_.shader_int_dot_product_properties,
455+
integerDotProduct32BitSignedAccelerated);
456+
PRINT_PROP(
457+
physical_device_.shader_int_dot_product_properties,
458+
integerDotProduct32BitMixedSignednessAccelerated);
459+
PRINT_PROP(
460+
physical_device_.shader_int_dot_product_properties,
461+
integerDotProduct64BitUnsignedAccelerated);
462+
PRINT_PROP(
463+
physical_device_.shader_int_dot_product_properties,
464+
integerDotProduct64BitSignedAccelerated);
465+
PRINT_PROP(
466+
physical_device_.shader_int_dot_product_properties,
467+
integerDotProduct64BitMixedSignednessAccelerated);
468+
PRINT_PROP(
469+
physical_device_.shader_int_dot_product_properties,
470+
integerDotProductAccumulatingSaturating8BitUnsignedAccelerated);
471+
PRINT_PROP(
472+
physical_device_.shader_int_dot_product_properties,
473+
integerDotProductAccumulatingSaturating8BitSignedAccelerated);
474+
PRINT_PROP(
475+
physical_device_.shader_int_dot_product_properties,
476+
integerDotProductAccumulatingSaturating8BitMixedSignednessAccelerated);
477+
PRINT_PROP(
478+
physical_device_.shader_int_dot_product_properties,
479+
integerDotProductAccumulatingSaturating4x8BitPackedUnsignedAccelerated);
480+
PRINT_PROP(
481+
physical_device_.shader_int_dot_product_properties,
482+
integerDotProductAccumulatingSaturating4x8BitPackedSignedAccelerated);
483+
PRINT_PROP(
484+
physical_device_.shader_int_dot_product_properties,
485+
integerDotProductAccumulatingSaturating4x8BitPackedMixedSignednessAccelerated);
486+
PRINT_PROP(
487+
physical_device_.shader_int_dot_product_properties,
488+
integerDotProductAccumulatingSaturating16BitUnsignedAccelerated);
489+
PRINT_PROP(
490+
physical_device_.shader_int_dot_product_properties,
491+
integerDotProductAccumulatingSaturating16BitSignedAccelerated);
492+
PRINT_PROP(
493+
physical_device_.shader_int_dot_product_properties,
494+
integerDotProductAccumulatingSaturating16BitMixedSignednessAccelerated);
495+
PRINT_PROP(
496+
physical_device_.shader_int_dot_product_properties,
497+
integerDotProductAccumulatingSaturating32BitUnsignedAccelerated);
498+
PRINT_PROP(
499+
physical_device_.shader_int_dot_product_properties,
500+
integerDotProductAccumulatingSaturating32BitSignedAccelerated);
501+
PRINT_PROP(
502+
physical_device_.shader_int_dot_product_properties,
503+
integerDotProductAccumulatingSaturating32BitMixedSignednessAccelerated);
504+
PRINT_PROP(
505+
physical_device_.shader_int_dot_product_properties,
506+
integerDotProductAccumulatingSaturating64BitUnsignedAccelerated);
507+
PRINT_PROP(
508+
physical_device_.shader_int_dot_product_properties,
509+
integerDotProductAccumulatingSaturating64BitSignedAccelerated);
510+
PRINT_PROP(
511+
physical_device_.shader_int_dot_product_properties,
512+
integerDotProductAccumulatingSaturating64BitMixedSignednessAccelerated);
513+
ss << " }" << std::endl;
514+
#endif /* VK_KHR_shader_integer_dot_product */
515+
404516
const VkPhysicalDeviceMemoryProperties& mem_props =
405517
physical_device_.memory_properties;
406518

backends/vulkan/runtime/vk_api/Adapter.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,15 @@ class Adapter final {
212212
#endif /* VK_KHR_shader_float16_int8 */
213213
}
214214

215+
inline bool supports_int8_dot_product() {
216+
#ifdef VK_KHR_shader_integer_dot_product
217+
return physical_device_.shader_int_dot_product_features
218+
.shaderIntegerDotProduct == VK_TRUE;
219+
#else
220+
return false;
221+
#endif /* VK_KHR_shader_integer_dot_product */
222+
}
223+
215224
inline bool supports_int16_shader_types() {
216225
return physical_device_.supports_int16_shader_types;
217226
}

backends/vulkan/runtime/vk_api/Device.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,12 @@ PhysicalDevice::PhysicalDevice(VkPhysicalDevice physical_device_handle)
3636
shader_float16_int8_types{
3737
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_FLOAT16_INT8_FEATURES_KHR},
3838
#endif /* VK_KHR_shader_float16_int8 */
39+
#ifdef VK_KHR_shader_integer_dot_product
40+
shader_int_dot_product_features{
41+
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_FEATURES_KHR},
42+
shader_int_dot_product_properties{
43+
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_PROPERTIES_KHR},
44+
#endif
3945
queue_families{},
4046
num_compute_queues(0),
4147
supports_int16_shader_types(false),
@@ -77,6 +83,13 @@ PhysicalDevice::PhysicalDevice(VkPhysicalDevice physical_device_handle)
7783
extension_list_top = &shader_float16_int8_types;
7884
#endif /* VK_KHR_shader_float16_int8 */
7985

86+
#ifdef VK_KHR_shader_integer_dot_product
87+
shader_int_dot_product_features.pNext = extension_list_top;
88+
extension_list_top = &shader_int_dot_product_features;
89+
shader_int_dot_product_properties.pNext = extension_list_top;
90+
extension_list_top = &shader_int_dot_product_properties;
91+
#endif /* VK_KHR_shader_integer_dot_product */
92+
8093
features2.pNext = extension_list_top;
8194

8295
vkGetPhysicalDeviceFeatures2(handle, &features2);

backends/vulkan/runtime/vk_api/Device.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,12 @@ struct PhysicalDevice final {
4444
#ifdef VK_KHR_shader_float16_int8
4545
VkPhysicalDeviceShaderFloat16Int8Features shader_float16_int8_types;
4646
#endif /* VK_KHR_shader_float16_int8 */
47+
#ifdef VK_KHR_shader_integer_dot_product
48+
VkPhysicalDeviceShaderIntegerDotProductFeatures
49+
shader_int_dot_product_features;
50+
VkPhysicalDeviceShaderIntegerDotProductProperties
51+
shader_int_dot_product_properties;
52+
#endif /* VK_KHR_shader_integer_dot_product */
4753

4854
// Available GPU queues
4955
std::vector<VkQueueFamilyProperties> queue_families;

backends/vulkan/runtime/vk_api/Exception.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,9 @@ std::ostream& operator<<(std::ostream& out, const VulkanExtension result) {
9292
case VulkanExtension::INT8_STORAGE:
9393
out << "VK_KHR_8bit_storage";
9494
break;
95+
case VulkanExtension::INTEGER_DOT_PRODUCT:
96+
out << "VK_KHR_shader_integer_dot_product";
97+
break;
9598
}
9699
return out;
97100
}

0 commit comments

Comments
 (0)