diff --git a/backends/vulkan/runtime/gen_vulkan_spv.py b/backends/vulkan/runtime/gen_vulkan_spv.py index 5c59f13fc24..a137a7d538f 100644 --- a/backends/vulkan/runtime/gen_vulkan_spv.py +++ b/backends/vulkan/runtime/gen_vulkan_spv.py @@ -56,52 +56,97 @@ TYPE_MAPPINGS: Dict[str, Any] = { "IMAGE_T": { 3: { + "double": "image3D", "float": "image3D", "half": "image3D", - "int": "iimage3D", - "uint": "uimage3D", + # integer dtypes "int8": "iimage3D", "uint8": "uimage3D", + "int16": "iimage3D", + "uint16": "uimage3D", + "int32": "iimage3D", + "uint32": "uimage3D", + "int64": "iimage3D", + "uint64": "uimage3D", + # common dtype aliases "bool": "uimage3D", + "int": "iimage3D", + "uint": "uimage3D", }, 2: { + "double": "image2D", "float": "image2D", "half": "image2D", - "int": "iimage2D", - "uint": "uimage2D", + # integer dtypes "int8": "iimage2D", "uint8": "uimage2D", + "int16": "iimage2D", + "uint16": "uimage2D", + "int32": "iimage2D", + "uint32": "uimage2D", + "int64": "iimage2D", + "uint64": "uimage2D", + # common dtype aliases "bool": "uimage2D", + "int": "iimage2D", + "uint": "uimage2D", }, }, "SAMPLER_T": { 3: { + "double": "sampler3D", "float": "sampler3D", "half": "sampler3D", - "int": "isampler3D", - "uint": "usampler3D", + # integer dtypes "int8": "isampler3D", "uint8": "usampler3D", + "int16": "isampler3D", + "uint16": "usampler3D", + "int32": "isampler3D", + "uint32": "usampler3D", + "int64": "isampler3D", + "uint64": "usampler3D", + # common dtype aliases "bool": "usampler3D", + "int": "isampler3D", + "uint": "usampler3D", }, 2: { + "double": "sampler2D", "float": "sampler2D", "half": "sampler2D", - "int": "isampler2D", - "uint": "usampler2D", + # integer dtypes "int8": "isampler2D", "uint8": "usampler2D", + "int16": "isampler2D", + "uint16": "usampler2D", + "int32": "isampler2D", + "uint32": "usampler2D", + "int64": "isampler2D", + "uint64": "usampler2D", + # common dtype aliases "bool": "usampler2D", + "int": "isampler2D", + "uint": "usampler2D", }, }, "IMAGE_FORMAT": { + "double": "rgba32f", "float": "rgba32f", "half": "rgba16f", - "int": "rgba32i", - "uint": "rgba32ui", + # integer dtypes "int8": "rgba8i", "uint8": "rgba8ui", + "int16": "rgba16i", + "uint16": "rgba16ui", + "int32": "rgba32i", + "uint32": "rgba32ui", + "int64": "rgba32i", + "uint64": "rgba32ui", + # common dtype aliases "bool": "rgba8ui", + "int": "rgba32i", + "uint": "rgba32ui", }, } @@ -118,10 +163,18 @@ def define_variable(name: str) -> str: def buffer_scalar_type(dtype: str) -> str: if dtype == "half": return "float16_t" - elif dtype[-1] == "8": - return dtype + "_t" + elif dtype == "float": + return "float" + elif dtype == "double": + return "float64_t" + # integer dtype alias conversion elif dtype == "bool": return "uint8_t" + # we don't want to append _t for int32 or uint32 as int is already 32bit + elif dtype == "int32" or dtype == "uint32": + return "int" if dtype == "int32" else "uint" + elif dtype[-1].isdigit(): + return dtype + "_t" return dtype @@ -129,22 +182,28 @@ def buffer_gvec_type(dtype: str, n: int) -> str: if n == 1: return buffer_scalar_type(dtype) - if dtype == "float": - return f"vec{n}" - if dtype == "uint": - return f"uvec{n}" - elif dtype == "half": - return f"f16vec{n}" - elif dtype == "int": - return f"ivec{n}" - elif dtype == "int8": - return f"i8vec{n}" - elif dtype == "uint8": - return f"u8vec{n}" - elif dtype == "bool": - return f"u8vec{n}" - - raise AssertionError(f"Invalid dtype: {dtype}") + dtype_map = { + "half": f"f16vec{n}", + "float": f"vec{n}", + "double": f"vec{n}", # No 64bit image format support in GLSL + "int8": f"i8vec{n}", + "uint8": f"u8vec{n}", + "int16": f"i16vec{n}", + "uint16": f"u16vec{n}", + "int32": f"ivec{n}", + "int": f"ivec{n}", + "uint32": f"uvec{n}", + "uint": f"uvec{n}", + "int64": f"ivec{n}", # No 64bit image format support in GLSL + "uint64": f"uvec{n}", # No 64bit image format support in GLSL + "bool": f"u8vec{n}", + } + + vector_type = dtype_map.get(dtype) + if vector_type is None: + raise AssertionError(f"Invalid dtype: {dtype}") + + return vector_type def texel_type(dtype: str) -> str: @@ -365,15 +424,22 @@ def define_required_extensions(dtypes: Union[str, List[str]]): if dtype == "half": nbit = "16bit" glsl_type = "float16" - elif dtype == "int16" or dtype == "uint16": - nbit = "16bit" - glsl_type = "int16" - elif dtype == "int8" or dtype == "uint8" or dtype == "bool": + elif dtype == "double": + # We only need to allow float64_t type usage + glsl_type = "float64" + elif dtype in ["int8", "uint8", "bool"]: nbit = "8bit" glsl_type = "int8" + elif dtype in ["int16", "uint16"]: + nbit = "16bit" + glsl_type = "int16" + elif dtype in ["int64", "uint64"]: + # We only need to allow int64_t and uint64_t type usage + glsl_type = "int64" - if nbit is not None and glsl_type is not None: + if nbit is not None: out_str += f"#extension GL_EXT_shader_{nbit}_storage : require\n" + if glsl_type is not None: out_str += f"#extension GL_EXT_shader_explicit_arithmetic_types_{glsl_type} : require\n" return out_str @@ -629,6 +695,10 @@ def generateVariantCombinations( elif "VALUE" in value: suffix = value.get("SUFFIX", value["VALUE"]) + if value["VALUE"] in ["int", "uint"]: + raise ValueError( + f"Use int32 or uint32 instead of {value['VALUE']}" + ) param_values.append((param_name, suffix, value["VALUE"])) else: diff --git a/backends/vulkan/runtime/graph/ops/glsl/arange.yaml b/backends/vulkan/runtime/graph/ops/glsl/arange.yaml index e3df8bf73a1..37b2027db85 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/arange.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/arange.yaml @@ -7,13 +7,13 @@ arange: parameter_names_with_default_values: NDIM: 3 - DTYPE: int + DTYPE: int32 STORAGE: texture3d PACKING: C_packed generate_variant_forall: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: arange diff --git a/backends/vulkan/runtime/graph/ops/glsl/avg_pool2d.yaml b/backends/vulkan/runtime/graph/ops/glsl/avg_pool2d.yaml index eddddec0d8d..b1e16dec8d6 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/avg_pool2d.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/avg_pool2d.yaml @@ -13,6 +13,6 @@ avg_pool2d: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: avg_pool2d diff --git a/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml b/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml index c0efdd81eb9..accfcf53599 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml @@ -17,7 +17,7 @@ binary_op: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: binary_add - NAME: binary_sub diff --git a/backends/vulkan/runtime/graph/ops/glsl/buffer_to_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/buffer_to_buffer.yaml index 9abd9c1deac..e8bb86dbf6a 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/buffer_to_buffer.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/buffer_to_buffer.yaml @@ -12,8 +12,9 @@ buffer_to_buffer: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: double - VALUE: int8 - VALUE: uint8 + - VALUE: int32 shader_variants: - NAME: buffer_to_buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.yaml b/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.yaml index e48eab63a64..679e686dc2f 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.yaml @@ -13,9 +13,10 @@ buffer_to_nchw: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: double - VALUE: int8 - VALUE: uint8 + - VALUE: int32 shader_variants: - NAME: buffer_to_nchw - NAME: buffer_to_nchw_no_pc diff --git a/backends/vulkan/runtime/graph/ops/glsl/copy_channel_offset.yaml b/backends/vulkan/runtime/graph/ops/glsl/copy_channel_offset.yaml index 414bf8191b9..984d9a09d43 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/copy_channel_offset.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/copy_channel_offset.yaml @@ -7,6 +7,6 @@ copy_channel_offset: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: copy_channel_offset diff --git a/backends/vulkan/runtime/graph/ops/glsl/copy_offset.yaml b/backends/vulkan/runtime/graph/ops/glsl/copy_offset.yaml index 87df7bf9dc1..09f5ca36ea4 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/copy_offset.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/copy_offset.yaml @@ -7,7 +7,7 @@ copy_offset: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 - VALUE: int8 - VALUE: uint8 STORAGE: diff --git a/backends/vulkan/runtime/graph/ops/glsl/copy_packed_dim_offset.yaml b/backends/vulkan/runtime/graph/ops/glsl/copy_packed_dim_offset.yaml index e872d64e3c3..6e55876cb28 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/copy_packed_dim_offset.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/copy_packed_dim_offset.yaml @@ -7,6 +7,6 @@ copy_packed_dim_offset: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: copy_packed_dim_offset diff --git a/backends/vulkan/runtime/graph/ops/glsl/embedding.yaml b/backends/vulkan/runtime/graph/ops/glsl/embedding.yaml index 5ffe37265b1..0e7b491c433 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/embedding.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/embedding.yaml @@ -7,6 +7,6 @@ embedding: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: embedding diff --git a/backends/vulkan/runtime/graph/ops/glsl/flip.yaml b/backends/vulkan/runtime/graph/ops/glsl/flip.yaml index 646fd05e420..f5e7c874773 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/flip.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/flip.yaml @@ -6,8 +6,9 @@ flip: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: double - VALUE: int8 - VALUE: uint8 + - VALUE: int32 shader_variants: - NAME: flip diff --git a/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml b/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml index 804ce19bdb8..646d8f1be81 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml @@ -14,9 +14,10 @@ image_to_nchw: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: double - VALUE: int8 - VALUE: uint8 + - VALUE: int32 shader_variants: - NAME: image_to_nchw_texture3d - NAME: image_to_nchw_texture2d diff --git a/backends/vulkan/runtime/graph/ops/glsl/index_select.yaml b/backends/vulkan/runtime/graph/ops/glsl/index_select.yaml index 5a6c525993e..abef2225cd9 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/index_select.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/index_select.yaml @@ -7,6 +7,6 @@ index_select: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: index_select diff --git a/backends/vulkan/runtime/graph/ops/glsl/index_select_channel.yaml b/backends/vulkan/runtime/graph/ops/glsl/index_select_channel.yaml index 66cb7ec3f89..a306e3ce47d 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/index_select_channel.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/index_select_channel.yaml @@ -7,6 +7,6 @@ index_select_channel: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: index_select_channel diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.yaml index 486d710cf55..99e41a0ab6f 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.yaml @@ -13,9 +13,10 @@ nchw_to_buffer: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: double - VALUE: int8 - VALUE: uint8 + - VALUE: int32 shader_variants: - NAME: nchw_to_buffer - NAME: nchw_to_buffer_no_pc diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl index 4674822ce6a..f3f604e10cd 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl @@ -87,5 +87,9 @@ void main() { return; } - write_texel(t_out, lpos_to_pos(lpos, axis_map), read_texel(tidx)); + $if DTYPE == "double" and DTYPE == "int64": + VEC4_T texel = read_texel(tidx); + write_texel(t_out, lpos_to_pos(lpos, axis_map), texel); + $else: + write_texel(t_out, lpos_to_pos(lpos, axis_map), read_texel(tidx)); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.yaml b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.yaml index 7e52ec10376..85119c8d508 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.yaml @@ -14,9 +14,10 @@ nchw_to_image: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: double - VALUE: int8 - VALUE: uint8 + - VALUE: int32 shader_variants: - NAME: nchw_to_image_texture3d - NAME: nchw_to_image_texture2d diff --git a/backends/vulkan/runtime/graph/ops/glsl/no_op.yaml b/backends/vulkan/runtime/graph/ops/glsl/no_op.yaml index e64e1bd260a..bfeaba2496b 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/no_op.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/no_op.yaml @@ -12,7 +12,7 @@ no_op: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 - VALUE: int8 - VALUE: uint8 STORAGE: diff --git a/backends/vulkan/runtime/graph/ops/glsl/permute.yaml b/backends/vulkan/runtime/graph/ops/glsl/permute.yaml index f678aeedf6e..a90ddcb41ce 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/permute.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/permute.yaml @@ -7,6 +7,6 @@ permute: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: permute diff --git a/backends/vulkan/runtime/graph/ops/glsl/repeat.yaml b/backends/vulkan/runtime/graph/ops/glsl/repeat.yaml index 526980a0f41..f40d94142e1 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/repeat.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/repeat.yaml @@ -7,7 +7,7 @@ repeat: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 - VALUE: int8 - VALUE: uint8 shader_variants: diff --git a/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml b/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml index f13393ce6c7..47f538aee6c 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml @@ -15,9 +15,9 @@ unary_op: OPERATOR: abs(X) - NAME: clamp OPERATOR: clamp(X, A, B) - - NAME: clamp_int + - NAME: clamp_int32 OPERATOR: clamp(X, A, B) - DTYPE: int + DTYPE: int32 - NAME: cos OPERATOR: cos(X) - NAME: exp diff --git a/backends/vulkan/runtime/graph/ops/glsl/view.yaml b/backends/vulkan/runtime/graph/ops/glsl/view.yaml index ba11a2496a0..33364a25225 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/view.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/view.yaml @@ -7,6 +7,6 @@ view: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: view diff --git a/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp b/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp index e1ac4e9d40a..6388a8ad091 100644 --- a/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp +++ b/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp @@ -34,24 +34,42 @@ void add_storage_type_suffix( void add_dtype_suffix(std::string& kernel_name, const vkapi::ScalarType dtype) { switch (dtype) { + case vkapi::kDouble: + kernel_name += "_double"; + break; case vkapi::kFloat: kernel_name += "_float"; break; case vkapi::kHalf: kernel_name += "_half"; break; - case vkapi::kInt: - kernel_name += "_int"; - break; case vkapi::kChar: case vkapi::kQInt8: kernel_name += "_int8"; break; case vkapi::kByte: - case vkapi::kQUInt8: case vkapi::kBool: + case vkapi::kQUInt8: kernel_name += "_uint8"; break; + case vkapi::kShort: + kernel_name += "_int16"; + break; + case vkapi::kUInt16: + kernel_name += "_uint16"; + break; + case vkapi::kInt: + kernel_name += "_int32"; + break; + case vkapi::kUInt: + kernel_name += "_uint32"; + break; + case vkapi::kLong: + kernel_name += "_int64"; + break; + case vkapi::kUInt64: + kernel_name += "_uint64"; + break; default: break; } diff --git a/backends/vulkan/runtime/vk_api/Types.h b/backends/vulkan/runtime/vk_api/Types.h index f25fe95d72b..b3309aa6c69 100644 --- a/backends/vulkan/runtime/vk_api/Types.h +++ b/backends/vulkan/runtime/vk_api/Types.h @@ -30,11 +30,17 @@ #define VK_FORALL_SCALAR_TYPES(_) \ _(uint8_t, VK_FORMAT_R8G8B8A8_UINT, Byte) \ - _(int8_t, VK_FORMAT_R8G8B8A8_SINT, Char) \ - _(int32_t, VK_FORMAT_R32G32B32A32_SINT, Int) \ _(uint8_t, VK_FORMAT_R8G8B8A8_UINT, Bool) \ + _(int8_t, VK_FORMAT_R8G8B8A8_SINT, Char) \ _(uint16_t, VK_FORMAT_R16G16B16A16_SFLOAT, Half) \ + _(uint16_t, VK_FORMAT_R16G16B16A16_UINT, UInt16) \ + _(int16_t, VK_FORMAT_R16G16B16A16_SINT, Short) \ + _(uint32_t, VK_FORMAT_R32G32B32A32_UINT, UInt) \ + _(int32_t, VK_FORMAT_R32G32B32A32_SINT, Int) \ + _(uint64_t, VK_FORMAT_R64G64B64A64_UINT, UInt64) \ + _(int64_t, VK_FORMAT_R64G64B64A64_SINT, Long) \ _(float, VK_FORMAT_FLOAT4, Float) \ + _(double, VK_FORMAT_R64G64B64A64_SFLOAT, Double) \ _(int8_t, VK_FORMAT_R8G8B8A8_SINT, QInt8) \ _(uint8_t, VK_FORMAT_R8G8B8A8_UINT, QUInt8) \ _(int32_t, VK_FORMAT_R32G32B32A32_SINT, QInt32) @@ -86,17 +92,29 @@ inline VkFormat to_vkformat(const ScalarType t) { */ inline ScalarType element_scalartype(const VkFormat vkformat) { switch (vkformat) { + case VK_FORMAT_R64G64B64A64_SFLOAT: + return kDouble; + case VK_FORMAT_R32G32B32A32_SFLOAT: + return kFloat; + case VK_FORMAT_R16G16B16A16_SFLOAT: + return kHalf; case VK_FORMAT_R8G8B8A8_SINT: return kChar; case VK_FORMAT_R8G8B8A8_UINT: case VK_FORMAT_R8G8B8A8_UNORM: return kByte; + case VK_FORMAT_R16G16B16A16_SINT: + return kShort; + case VK_FORMAT_R16G16B16A16_UINT: + return kUInt16; case VK_FORMAT_R32G32B32A32_SINT: return kInt; - case VK_FORMAT_R32G32B32A32_SFLOAT: - return kFloat; - case VK_FORMAT_R16G16B16A16_SFLOAT: - return kHalf; + case VK_FORMAT_R32G32B32A32_UINT: + return kUInt; + case VK_FORMAT_R64G64B64A64_SINT: + return kLong; + case VK_FORMAT_R64G64B64A64_UINT: + return kUInt64; default: VK_THROW("No corresponding scalar type for unknown VkFormat: ", vkformat); } diff --git a/backends/vulkan/test/glsl/all_shaders.yaml b/backends/vulkan/test/glsl/all_shaders.yaml index 37403c97ac8..4ef934eb105 100644 --- a/backends/vulkan/test/glsl/all_shaders.yaml +++ b/backends/vulkan/test/glsl/all_shaders.yaml @@ -51,7 +51,7 @@ idx_fill_texture: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 - VALUE: int8 shader_variants: - NAME: idx_fill_texture diff --git a/backends/vulkan/test/op_tests/utils/gen_benchmark_vk.py b/backends/vulkan/test/op_tests/utils/gen_benchmark_vk.py index 65bb959f6d1..a054fdf1a19 100644 --- a/backends/vulkan/test/op_tests/utils/gen_benchmark_vk.py +++ b/backends/vulkan/test/op_tests/utils/gen_benchmark_vk.py @@ -177,6 +177,8 @@ def generate_benchmark_fixture(self) -> str: vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) {{ switch (at_scalartype) {{ + case c10::kDouble: + return vkapi::kDouble; case c10::kFloat: return vkapi::kFloat; case c10::kHalf: @@ -187,6 +189,8 @@ def generate_benchmark_fixture(self) -> str: return vkapi::kInt; case c10::kChar: return vkapi::kChar; + case c10::kBool: + return vkapi::kBool; default: VK_THROW("Unsupported at::ScalarType!"); }} diff --git a/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py b/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py index 4f0d2ff11ef..e7cf5ba92a5 100644 --- a/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py +++ b/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py @@ -110,6 +110,8 @@ def gen_parameterization(self) -> str: vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) { switch (at_scalartype) { + case c10::kDouble: + return vkapi::kDouble; case c10::kFloat: return vkapi::kFloat; case c10::kHalf: diff --git a/backends/vulkan/tools/gpuinfo/glsl/warp_size.yaml b/backends/vulkan/tools/gpuinfo/glsl/warp_size.yaml index a00bba2bc5a..69587bd38d0 100644 --- a/backends/vulkan/tools/gpuinfo/glsl/warp_size.yaml +++ b/backends/vulkan/tools/gpuinfo/glsl/warp_size.yaml @@ -6,7 +6,7 @@ warp_size: parameter_names_with_default_values: - DTYPE: int + DTYPE: int32 STORAGE: buffer generate_variant_forall: METHOD: