Skip to content

Commit afff37c

Browse files
authored
[ET-VK] Parse required extensions of shaders and check capabilities during dispatch
Differential Revision: D67992067 Pull Request resolved: #7576
1 parent e1c0bcf commit afff37c

File tree

11 files changed

+128
-7
lines changed

11 files changed

+128
-7
lines changed

backends/vulkan/runtime/api/Context.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,27 @@ void Context::report_shader_dispatch_end() {
8787
}
8888
}
8989

90+
void Context::check_device_capabilities(const vkapi::ShaderInfo& shader) {
91+
if (shader.requires_shader_int16) {
92+
if (!adapter_p_->supports_int16_shader_types()) {
93+
throw vkapi::ShaderNotSupportedError(
94+
shader.kernel_name, vkapi::VulkanExtension::SHADER_INT16);
95+
}
96+
}
97+
if (shader.requires_16bit_storage) {
98+
if (!adapter_p_->supports_16bit_storage_buffers()) {
99+
throw vkapi::ShaderNotSupportedError(
100+
shader.kernel_name, vkapi::VulkanExtension::INT16_STORAGE);
101+
}
102+
}
103+
if (shader.requires_8bit_storage) {
104+
if (!adapter_p_->supports_8bit_storage_buffers()) {
105+
throw vkapi::ShaderNotSupportedError(
106+
shader.kernel_name, vkapi::VulkanExtension::INT8_STORAGE);
107+
}
108+
}
109+
}
110+
90111
vkapi::DescriptorSet Context::get_descriptor_set(
91112
const vkapi::ShaderInfo& shader_descriptor,
92113
const utils::uvec3& local_workgroup_size,

backends/vulkan/runtime/api/Context.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,8 @@ class Context final {
185185
}
186186
}
187187

188+
void check_device_capabilities(const vkapi::ShaderInfo& shader);
189+
188190
vkapi::DescriptorSet get_descriptor_set(
189191
const vkapi::ShaderInfo&,
190192
const utils::uvec3&,

backends/vulkan/runtime/gen_vulkan_spv.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -720,6 +720,10 @@ def maybe_replace_u16vecn(self, input_text: str) -> str:
720720
if "codegen-nosub" in input_text:
721721
return input_text
722722

723+
# Remove extension requirement so that generated ShaderInfo does not mark it
724+
input_text = input_text.replace(
725+
"#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require", ""
726+
)
723727
input_text = input_text.replace("u16vec", "ivec")
724728
input_text = input_text.replace("uint16_t", "int")
725729
return input_text
@@ -791,6 +795,9 @@ class ShaderInfo:
791795
weight_storage_type: str = ""
792796
bias_storage_type: str = ""
793797
register_for: Optional[Tuple[str, List[str]]] = None
798+
requires_shader_int16_ext: bool = False
799+
requires_16bit_storage_ext: bool = False
800+
requires_8bit_storage_ext: bool = False
794801

795802

796803
def getName(filePath: str) -> str:
@@ -858,6 +865,11 @@ def findRegisterFor(lineStr: str) -> Tuple[str, List[str]]:
858865
return (matches_list[0], matches_list[1:])
859866

860867

868+
def isExtensionRequireLine(lineStr: str) -> bool:
869+
extension_require_id = r"^#extension ([A-Za-z0-9_]+)\s*:\s*require"
870+
return re.search(extension_require_id, lineStr) is not None
871+
872+
861873
typeIdMapping = {
862874
r"image[123]D\b": "VK_DESCRIPTOR_TYPE_STORAGE_IMAGE",
863875
r"sampler[123]D\b": "VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER",
@@ -889,6 +901,13 @@ def getShaderInfo(srcFilePath: str) -> ShaderInfo:
889901
shader_info.bias_storage_type = getBiasStorageType(line)
890902
if isRegisterForLine(line):
891903
shader_info.register_for = findRegisterFor(line)
904+
if isExtensionRequireLine(line):
905+
if "GL_EXT_shader_explicit_arithmetic_types_int16" in line:
906+
shader_info.requires_shader_int16_ext = True
907+
if "GL_EXT_shader_16bit_storage" in line:
908+
shader_info.requires_16bit_storage_ext = True
909+
if "GL_EXT_shader_8bit_storage" in line:
910+
shader_info.requires_8bit_storage_ext = True
892911

893912
return shader_info
894913

@@ -952,12 +971,18 @@ def generateShaderInfoStr(shader_info: ShaderInfo, name: str, sizeBytes: int) ->
952971

953972
shader_info_layouts = "{{{}}}".format(",\n ".join(shader_info.layouts))
954973

974+
def to_cpp_str(val: bool):
975+
return "true" if val else "false"
976+
955977
shader_info_args = [
956978
f'"{name}"',
957979
f"{name}_bin",
958980
str(sizeBytes),
959981
shader_info_layouts,
960982
tile_size,
983+
to_cpp_str(shader_info.requires_shader_int16_ext),
984+
to_cpp_str(shader_info.requires_16bit_storage_ext),
985+
to_cpp_str(shader_info.requires_8bit_storage_ext),
961986
]
962987

963988
shader_info_str = textwrap.indent(

backends/vulkan/runtime/graph/ops/DispatchNode.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ void DispatchNode::encode(ComputeGraph* graph) {
5858
api::Context* const context = graph->context();
5959
vkapi::PipelineBarrier pipeline_barrier{};
6060

61+
context->check_device_capabilities(shader_);
62+
6163
std::unique_lock<std::mutex> cmd_lock = context->dispatch_lock();
6264

6365
std::array<uint8_t, kMaxPushConstantSize> push_constants_data;

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,6 @@ ${layout_declare_ubo(8, "float", "out_min", "float", "out_max")}
3434

3535
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
3636

37-
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
38-
3937
/*
4038
* Computes a depthwise convolution. Each shader invocation calculates the
4139
* output at a single output location.

backends/vulkan/runtime/vk_api/Adapter.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,9 @@ std::string Adapter::stringize() const {
256256
ss << " deviceType: " << device_type << std::endl;
257257
ss << " deviceName: " << properties.deviceName << std::endl;
258258

259+
#define PRINT_BOOL(value, name) \
260+
ss << " " << std::left << std::setw(36) << #name << value << std::endl;
261+
259262
#define PRINT_PROP(struct, name) \
260263
ss << " " << std::left << std::setw(36) << #name << struct.name \
261264
<< std::endl;
@@ -298,12 +301,13 @@ std::string Adapter::stringize() const {
298301
ss << " }" << std::endl;
299302
#endif /* VK_KHR_8bit_storage */
300303

301-
#ifdef VK_KHR_shader_float16_int8
302304
ss << " Shader 16bit and 8bit Features {" << std::endl;
305+
PRINT_BOOL(physical_device_.supports_int16_shader_types, shaderInt16)
306+
#ifdef VK_KHR_shader_float16_int8
303307
PRINT_PROP(physical_device_.shader_float16_int8_types, shaderFloat16);
304308
PRINT_PROP(physical_device_.shader_float16_int8_types, shaderInt8);
305-
ss << " }" << std::endl;
306309
#endif /* VK_KHR_shader_float16_int8 */
310+
ss << " }" << std::endl;
307311

308312
const VkPhysicalDeviceMemoryProperties& mem_props =
309313
physical_device_.memory_properties;

backends/vulkan/runtime/vk_api/Exception.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,5 +77,36 @@ Error::Error(SourceLocation source_location, const char* cond, std::string msg)
7777
what_ = oss.str();
7878
}
7979

80+
//
81+
// ShaderNotSupportedError
82+
//
83+
84+
std::ostream& operator<<(std::ostream& out, const VulkanExtension result) {
85+
switch (result) {
86+
case VulkanExtension::SHADER_INT16:
87+
out << "shaderInt16";
88+
break;
89+
case VulkanExtension::INT16_STORAGE:
90+
out << "VK_KHR_16bit_storage";
91+
break;
92+
case VulkanExtension::INT8_STORAGE:
93+
out << "VK_KHR_8bit_storage";
94+
break;
95+
}
96+
return out;
97+
}
98+
99+
ShaderNotSupportedError::ShaderNotSupportedError(
100+
std::string shader_name,
101+
VulkanExtension extension)
102+
: shader_name_(std::move(shader_name)), extension_{extension} {
103+
std::ostringstream oss;
104+
oss << "Shader " << shader_name_ << " ";
105+
oss << "not compatible with device. ";
106+
oss << "Missing support for extension or physical device feature: ";
107+
oss << extension_;
108+
what_ = oss.str();
109+
}
110+
80111
} // namespace vkapi
81112
} // namespace vkcompute

backends/vulkan/runtime/vk_api/Exception.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,5 +78,26 @@ class Error : public std::exception {
7878
}
7979
};
8080

81+
enum class VulkanExtension : uint8_t {
82+
SHADER_INT16,
83+
INT16_STORAGE,
84+
INT8_STORAGE,
85+
};
86+
87+
class ShaderNotSupportedError : public std::exception {
88+
public:
89+
ShaderNotSupportedError(std::string shader_name, VulkanExtension extension);
90+
91+
private:
92+
std::string shader_name_;
93+
VulkanExtension extension_;
94+
std::string what_;
95+
96+
public:
97+
const char* what() const noexcept override {
98+
return what_.c_str();
99+
}
100+
};
101+
81102
} // namespace vkapi
82103
} // namespace vkcompute

backends/vulkan/runtime/vk_api/Shader.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,20 @@ ShaderInfo::ShaderInfo(
2828
const uint32_t* const spirv_bin,
2929
const uint32_t size,
3030
std::vector<VkDescriptorType> layout,
31-
const utils::uvec3 tile_size)
31+
const utils::uvec3 tile_size,
32+
const bool requires_shader_int16_ext,
33+
const bool requires_16bit_storage_ext,
34+
const bool requires_8bit_storage_ext)
3235
: src_code{
3336
spirv_bin,
3437
size,
3538
},
3639
kernel_name{std::move(name)},
3740
kernel_layout{std::move(layout)},
38-
out_tile_size(tile_size) {
41+
out_tile_size(tile_size),
42+
requires_shader_int16(requires_shader_int16_ext),
43+
requires_16bit_storage(requires_16bit_storage_ext),
44+
requires_8bit_storage(requires_8bit_storage_ext) {
3945
}
4046

4147
bool operator==(const ShaderInfo& _1, const ShaderInfo& _2) {

backends/vulkan/runtime/vk_api/Shader.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,9 @@ struct ShaderInfo final {
6262

6363
// Shader Metadata
6464
utils::uvec3 out_tile_size{1u, 1u, 1u};
65+
bool requires_shader_int16 = false;
66+
bool requires_16bit_storage = false;
67+
bool requires_8bit_storage = false;
6568

6669
explicit ShaderInfo();
6770

@@ -70,7 +73,10 @@ struct ShaderInfo final {
7073
const uint32_t*,
7174
const uint32_t,
7275
std::vector<VkDescriptorType>,
73-
const utils::uvec3 tile_size);
76+
const utils::uvec3 tile_size,
77+
const bool requires_shader_int16_ext,
78+
const bool requires_16bit_storage_ext,
79+
const bool requires_8bit_storage_ext);
7480

7581
operator bool() const {
7682
return src_code.bin != nullptr;

0 commit comments

Comments
 (0)