diff --git a/backends/vulkan/runtime/api/Context.cpp b/backends/vulkan/runtime/api/Context.cpp index 9517941f364..f425859935d 100644 --- a/backends/vulkan/runtime/api/Context.cpp +++ b/backends/vulkan/runtime/api/Context.cpp @@ -87,6 +87,27 @@ void Context::report_shader_dispatch_end() { } } +void Context::check_device_capabilities(const vkapi::ShaderInfo& shader) { + if (shader.requires_shader_int16) { + if (!adapter_p_->supports_int16_shader_types()) { + throw vkapi::ShaderNotSupportedError( + shader.kernel_name, vkapi::VulkanExtension::SHADER_INT16); + } + } + if (shader.requires_16bit_storage) { + if (!adapter_p_->supports_16bit_storage_buffers()) { + throw vkapi::ShaderNotSupportedError( + shader.kernel_name, vkapi::VulkanExtension::INT16_STORAGE); + } + } + if (shader.requires_8bit_storage) { + if (!adapter_p_->supports_8bit_storage_buffers()) { + throw vkapi::ShaderNotSupportedError( + shader.kernel_name, vkapi::VulkanExtension::INT8_STORAGE); + } + } +} + vkapi::DescriptorSet Context::get_descriptor_set( const vkapi::ShaderInfo& shader_descriptor, const utils::uvec3& local_workgroup_size, diff --git a/backends/vulkan/runtime/api/Context.h b/backends/vulkan/runtime/api/Context.h index 300fd3995dd..0c199c24cc4 100644 --- a/backends/vulkan/runtime/api/Context.h +++ b/backends/vulkan/runtime/api/Context.h @@ -185,6 +185,8 @@ class Context final { } } + void check_device_capabilities(const vkapi::ShaderInfo& shader); + vkapi::DescriptorSet get_descriptor_set( const vkapi::ShaderInfo&, const utils::uvec3&, diff --git a/backends/vulkan/runtime/gen_vulkan_spv.py b/backends/vulkan/runtime/gen_vulkan_spv.py index 7d004547a8e..7d3d2d52950 100644 --- a/backends/vulkan/runtime/gen_vulkan_spv.py +++ b/backends/vulkan/runtime/gen_vulkan_spv.py @@ -720,6 +720,10 @@ def maybe_replace_u16vecn(self, input_text: str) -> str: if "codegen-nosub" in input_text: return input_text + # Remove extension requirement so that generated ShaderInfo does not mark it + input_text = input_text.replace( + "#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require", "" + ) input_text = input_text.replace("u16vec", "ivec") input_text = input_text.replace("uint16_t", "int") return input_text @@ -791,6 +795,9 @@ class ShaderInfo: weight_storage_type: str = "" bias_storage_type: str = "" register_for: Optional[Tuple[str, List[str]]] = None + requires_shader_int16_ext: bool = False + requires_16bit_storage_ext: bool = False + requires_8bit_storage_ext: bool = False def getName(filePath: str) -> str: @@ -858,6 +865,11 @@ def findRegisterFor(lineStr: str) -> Tuple[str, List[str]]: return (matches_list[0], matches_list[1:]) +def isExtensionRequireLine(lineStr: str) -> bool: + extension_require_id = r"^#extension ([A-Za-z0-9_]+)\s*:\s*require" + return re.search(extension_require_id, lineStr) is not None + + typeIdMapping = { r"image[123]D\b": "VK_DESCRIPTOR_TYPE_STORAGE_IMAGE", r"sampler[123]D\b": "VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER", @@ -889,6 +901,13 @@ def getShaderInfo(srcFilePath: str) -> ShaderInfo: shader_info.bias_storage_type = getBiasStorageType(line) if isRegisterForLine(line): shader_info.register_for = findRegisterFor(line) + if isExtensionRequireLine(line): + if "GL_EXT_shader_explicit_arithmetic_types_int16" in line: + shader_info.requires_shader_int16_ext = True + if "GL_EXT_shader_16bit_storage" in line: + shader_info.requires_16bit_storage_ext = True + if "GL_EXT_shader_8bit_storage" in line: + shader_info.requires_8bit_storage_ext = True return shader_info @@ -952,12 +971,18 @@ def generateShaderInfoStr(shader_info: ShaderInfo, name: str, sizeBytes: int) -> shader_info_layouts = "{{{}}}".format(",\n ".join(shader_info.layouts)) + def to_cpp_str(val: bool): + return "true" if val else "false" + shader_info_args = [ f'"{name}"', f"{name}_bin", str(sizeBytes), shader_info_layouts, tile_size, + to_cpp_str(shader_info.requires_shader_int16_ext), + to_cpp_str(shader_info.requires_16bit_storage_ext), + to_cpp_str(shader_info.requires_8bit_storage_ext), ] shader_info_str = textwrap.indent( diff --git a/backends/vulkan/runtime/graph/ops/DispatchNode.cpp b/backends/vulkan/runtime/graph/ops/DispatchNode.cpp index a163a0d7aea..63b8798f2c1 100644 --- a/backends/vulkan/runtime/graph/ops/DispatchNode.cpp +++ b/backends/vulkan/runtime/graph/ops/DispatchNode.cpp @@ -58,6 +58,8 @@ void DispatchNode::encode(ComputeGraph* graph) { api::Context* const context = graph->context(); vkapi::PipelineBarrier pipeline_barrier{}; + context->check_device_capabilities(shader_); + std::unique_lock cmd_lock = context->dispatch_lock(); std::array push_constants_data; diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl index 20fb9374bec..4a8d7418691 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl @@ -34,8 +34,6 @@ ${layout_declare_ubo(8, "float", "out_min", "float", "out_max")} layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; -#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require - /* * Computes a depthwise convolution. Each shader invocation calculates the * output at a single output location. diff --git a/backends/vulkan/runtime/vk_api/Adapter.cpp b/backends/vulkan/runtime/vk_api/Adapter.cpp index 5805d476a38..ec30650ba06 100644 --- a/backends/vulkan/runtime/vk_api/Adapter.cpp +++ b/backends/vulkan/runtime/vk_api/Adapter.cpp @@ -256,6 +256,9 @@ std::string Adapter::stringize() const { ss << " deviceType: " << device_type << std::endl; ss << " deviceName: " << properties.deviceName << std::endl; +#define PRINT_BOOL(value, name) \ + ss << " " << std::left << std::setw(36) << #name << value << std::endl; + #define PRINT_PROP(struct, name) \ ss << " " << std::left << std::setw(36) << #name << struct.name \ << std::endl; @@ -298,12 +301,13 @@ std::string Adapter::stringize() const { ss << " }" << std::endl; #endif /* VK_KHR_8bit_storage */ -#ifdef VK_KHR_shader_float16_int8 ss << " Shader 16bit and 8bit Features {" << std::endl; + PRINT_BOOL(physical_device_.supports_int16_shader_types, shaderInt16) +#ifdef VK_KHR_shader_float16_int8 PRINT_PROP(physical_device_.shader_float16_int8_types, shaderFloat16); PRINT_PROP(physical_device_.shader_float16_int8_types, shaderInt8); - ss << " }" << std::endl; #endif /* VK_KHR_shader_float16_int8 */ + ss << " }" << std::endl; const VkPhysicalDeviceMemoryProperties& mem_props = physical_device_.memory_properties; diff --git a/backends/vulkan/runtime/vk_api/Exception.cpp b/backends/vulkan/runtime/vk_api/Exception.cpp index e330c1c079d..d26fbd8cb22 100644 --- a/backends/vulkan/runtime/vk_api/Exception.cpp +++ b/backends/vulkan/runtime/vk_api/Exception.cpp @@ -77,5 +77,36 @@ Error::Error(SourceLocation source_location, const char* cond, std::string msg) what_ = oss.str(); } +// +// ShaderNotSupportedError +// + +std::ostream& operator<<(std::ostream& out, const VulkanExtension result) { + switch (result) { + case VulkanExtension::SHADER_INT16: + out << "shaderInt16"; + break; + case VulkanExtension::INT16_STORAGE: + out << "VK_KHR_16bit_storage"; + break; + case VulkanExtension::INT8_STORAGE: + out << "VK_KHR_8bit_storage"; + break; + } + return out; +} + +ShaderNotSupportedError::ShaderNotSupportedError( + std::string shader_name, + VulkanExtension extension) + : shader_name_(std::move(shader_name)), extension_{extension} { + std::ostringstream oss; + oss << "Shader " << shader_name_ << " "; + oss << "not compatible with device. "; + oss << "Missing support for extension or physical device feature: "; + oss << extension_; + what_ = oss.str(); +} + } // namespace vkapi } // namespace vkcompute diff --git a/backends/vulkan/runtime/vk_api/Exception.h b/backends/vulkan/runtime/vk_api/Exception.h index ec2f2956a88..a65afb1bcc5 100644 --- a/backends/vulkan/runtime/vk_api/Exception.h +++ b/backends/vulkan/runtime/vk_api/Exception.h @@ -78,5 +78,26 @@ class Error : public std::exception { } }; +enum class VulkanExtension : uint8_t { + SHADER_INT16, + INT16_STORAGE, + INT8_STORAGE, +}; + +class ShaderNotSupportedError : public std::exception { + public: + ShaderNotSupportedError(std::string shader_name, VulkanExtension extension); + + private: + std::string shader_name_; + VulkanExtension extension_; + std::string what_; + + public: + const char* what() const noexcept override { + return what_.c_str(); + } +}; + } // namespace vkapi } // namespace vkcompute diff --git a/backends/vulkan/runtime/vk_api/Shader.cpp b/backends/vulkan/runtime/vk_api/Shader.cpp index 29774e2f404..e560f37868e 100644 --- a/backends/vulkan/runtime/vk_api/Shader.cpp +++ b/backends/vulkan/runtime/vk_api/Shader.cpp @@ -28,14 +28,20 @@ ShaderInfo::ShaderInfo( const uint32_t* const spirv_bin, const uint32_t size, std::vector layout, - const utils::uvec3 tile_size) + const utils::uvec3 tile_size, + const bool requires_shader_int16_ext, + const bool requires_16bit_storage_ext, + const bool requires_8bit_storage_ext) : src_code{ spirv_bin, size, }, kernel_name{std::move(name)}, kernel_layout{std::move(layout)}, - out_tile_size(tile_size) { + out_tile_size(tile_size), + requires_shader_int16(requires_shader_int16_ext), + requires_16bit_storage(requires_16bit_storage_ext), + requires_8bit_storage(requires_8bit_storage_ext) { } bool operator==(const ShaderInfo& _1, const ShaderInfo& _2) { diff --git a/backends/vulkan/runtime/vk_api/Shader.h b/backends/vulkan/runtime/vk_api/Shader.h index 1e3b2a799f2..d9fec65febc 100644 --- a/backends/vulkan/runtime/vk_api/Shader.h +++ b/backends/vulkan/runtime/vk_api/Shader.h @@ -62,6 +62,9 @@ struct ShaderInfo final { // Shader Metadata utils::uvec3 out_tile_size{1u, 1u, 1u}; + bool requires_shader_int16 = false; + bool requires_16bit_storage = false; + bool requires_8bit_storage = false; explicit ShaderInfo(); @@ -70,7 +73,10 @@ struct ShaderInfo final { const uint32_t*, const uint32_t, std::vector, - const utils::uvec3 tile_size); + const utils::uvec3 tile_size, + const bool requires_shader_int16_ext, + const bool requires_16bit_storage_ext, + const bool requires_8bit_storage_ext); operator bool() const { return src_code.bin != nullptr; diff --git a/backends/vulkan/test/op_tests/utils/gen_correctness_base.py b/backends/vulkan/test/op_tests/utils/gen_correctness_base.py index 3d9aa6aa80b..d7e38969452 100644 --- a/backends/vulkan/test/op_tests/utils/gen_correctness_base.py +++ b/backends/vulkan/test/op_tests/utils/gen_correctness_base.py @@ -45,8 +45,13 @@ class GeneratedOpsTest_{op_name} : public ::testing::Test {{ test_suite_template = """ TEST_P(GeneratedOpsTest_{op_name}, {case_name}) {{ {create_ref_data} +try {{ {create_and_check_out} }} +catch (const vkcompute::vkapi::ShaderNotSupportedError& e) {{ + GTEST_SKIP() << e.what(); +}} +}} """