From 2fdccb30817351f8fb1f9bb28663ddf395ed9474 Mon Sep 17 00:00:00 2001 From: Andrea Nicastro Date: Fri, 30 May 2025 01:03:49 -0700 Subject: [PATCH] Where layer (#11181) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/11181 The `where` layer was missing and this diff adds it along with the bool tensor support. Reviewed By: SS-JIA Differential Revision: D74175287 --- backends/vulkan/runtime/gen_vulkan_spv.py | 18 ++- .../graph/ops/glsl/buffer_to_nchw.yaml | 1 + .../runtime/graph/ops/glsl/image_to_nchw.yaml | 1 + .../graph/ops/glsl/nchw_to_buffer.yaml | 1 + .../runtime/graph/ops/glsl/nchw_to_image.yaml | 1 + .../vulkan/runtime/graph/ops/glsl/where.glsl | 111 +++++++++++++++ .../vulkan/runtime/graph/ops/glsl/where.yaml | 12 ++ .../vulkan/runtime/graph/ops/impl/Where.cpp | 126 ++++++++++++++++++ .../graph/ops/utils/ShaderNameUtils.cpp | 1 + backends/vulkan/runtime/vk_api/Types.h | 2 +- backends/vulkan/test/op_tests/cases.py | 25 ++++ .../op_tests/utils/gen_correctness_base.py | 8 +- .../test/op_tests/utils/gen_correctness_vk.py | 2 + 13 files changed, 301 insertions(+), 8 deletions(-) create mode 100644 backends/vulkan/runtime/graph/ops/glsl/where.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/where.yaml create mode 100644 backends/vulkan/runtime/graph/ops/impl/Where.cpp diff --git a/backends/vulkan/runtime/gen_vulkan_spv.py b/backends/vulkan/runtime/gen_vulkan_spv.py index 2b15b2b7d0a..5c59f13fc24 100644 --- a/backends/vulkan/runtime/gen_vulkan_spv.py +++ b/backends/vulkan/runtime/gen_vulkan_spv.py @@ -62,6 +62,7 @@ "uint": "uimage3D", "int8": "iimage3D", "uint8": "uimage3D", + "bool": "uimage3D", }, 2: { "float": "image2D", @@ -70,6 +71,7 @@ "uint": "uimage2D", "int8": "iimage2D", "uint8": "uimage2D", + "bool": "uimage2D", }, }, "SAMPLER_T": { @@ -80,6 +82,7 @@ "uint": "usampler3D", "int8": "isampler3D", "uint8": "usampler3D", + "bool": "usampler3D", }, 2: { "float": "sampler2D", @@ -88,6 +91,7 @@ "uint": "usampler2D", "int8": "isampler2D", "uint8": "usampler2D", + "bool": "usampler2D", }, }, "IMAGE_FORMAT": { @@ -97,6 +101,7 @@ "uint": "rgba32ui", "int8": "rgba8i", "uint8": "rgba8ui", + "bool": "rgba8ui", }, } @@ -115,7 +120,8 @@ def buffer_scalar_type(dtype: str) -> str: return "float16_t" elif dtype[-1] == "8": return dtype + "_t" - + elif dtype == "bool": + return "uint8_t" return dtype @@ -135,17 +141,19 @@ def buffer_gvec_type(dtype: str, n: int) -> str: return f"i8vec{n}" elif dtype == "uint8": return f"u8vec{n}" + elif dtype == "bool": + return f"u8vec{n}" raise AssertionError(f"Invalid dtype: {dtype}") def texel_type(dtype: str) -> str: image_format = TYPE_MAPPINGS["IMAGE_FORMAT"][dtype] - if image_format[-1] == "f": + if image_format[-1:] == "f": return "vec4" - elif image_format[-2] == "ui": + elif image_format[-2:] == "ui": return "uvec4" - elif image_format[-1] == "i": + elif image_format[-1:] == "i": return "ivec4" raise AssertionError(f"Invalid image format: {image_format}") @@ -360,7 +368,7 @@ def define_required_extensions(dtypes: Union[str, List[str]]): elif dtype == "int16" or dtype == "uint16": nbit = "16bit" glsl_type = "int16" - elif dtype == "int8" or dtype == "uint8": + elif dtype == "int8" or dtype == "uint8" or dtype == "bool": nbit = "8bit" glsl_type = "int8" diff --git a/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.yaml b/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.yaml index 653bda9ccc0..25b3657c2eb 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.yaml @@ -14,5 +14,6 @@ buffer_to_nchw: - VALUE: float - VALUE: int - VALUE: int8 + - VALUE: uint8 shader_variants: - NAME: buffer_to_nchw diff --git a/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml b/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml index 8fc9340d9d0..c1045d93afc 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml @@ -15,6 +15,7 @@ image_to_nchw: - VALUE: float - VALUE: int - VALUE: int8 + - VALUE: uint8 shader_variants: - NAME: image_to_nchw_texture3d - NAME: image_to_nchw_texture2d diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.yaml index 6292ef93337..a85c1ec6c65 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.yaml @@ -14,5 +14,6 @@ nchw_to_buffer: - VALUE: float - VALUE: int - VALUE: int8 + - VALUE: uint8 shader_variants: - NAME: nchw_to_buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.yaml b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.yaml index f44e1f74bfe..9d17ff5f645 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.yaml @@ -15,6 +15,7 @@ nchw_to_image: - VALUE: float - VALUE: int - VALUE: int8 + - VALUE: uint8 shader_variants: - NAME: nchw_to_image_texture3d - NAME: nchw_to_image_texture2d diff --git a/backends/vulkan/runtime/graph/ops/glsl/where.glsl b/backends/vulkan/runtime/graph/ops/glsl/where.glsl new file mode 100644 index 00000000000..5df813d1241 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/where.glsl @@ -0,0 +1,111 @@ +// where.glsl + +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#version 450 core + +#define PRECISION ${PRECISION} + +#define VEC4_T ${texel_load_type(DTYPE, STORAGE)} +#define T ${buffer_scalar_type(DTYPE)} +#define COND_T ${buffer_scalar_type("bool")} + +${define_active_storage_type(STORAGE)} +${define_required_extensions(DTYPE)} +${define_required_extensions("bool")} + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)} +${layout_declare_tensor(B, "r", "t_condition", "bool", STORAGE)} +${layout_declare_tensor(B, "r", "t_self", DTYPE, STORAGE)} +${layout_declare_tensor(B, "r", "t_other", DTYPE, STORAGE)} + + +#include "indexing_utils.h" + +$if STORAGE == "buffer": + ${layout_declare_ubo(B, "int", "out_numl")} + ${layout_declare_ubo(B, "ivec4", "out_strides")} + ${layout_declare_ubo(B, "ivec4", "cond_strides")} + ${layout_declare_ubo(B, "ivec4", "self_strides")} + ${layout_declare_ubo(B, "ivec4", "other_strides")} + + ${layout_declare_spec_const(C, "int", "out_packed_dim", "DEFAULT_LAYOUT")} + ${layout_declare_spec_const(C, "int", "cond_packed_dim", "DEFAULT_LAYOUT")} + ${layout_declare_spec_const(C, "int", "self_packed_dim", "DEFAULT_LAYOUT")} + ${layout_declare_spec_const(C, "int", "other_packed_dim", "DEFAULT_LAYOUT")} +$else: + ${layout_declare_ubo(B, "ivec3", "out_limits")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +#ifdef USING_BUFFER + +void main() { + int out_bufi = int(gl_GlobalInvocationID.x); + // ivec4 tidx = ivec4(gl_GlobalInvocationID, 0); + // int out_bufi = tidx_to_bufi(tidx, out_strides); + // int cond_bufi = tidx_to_bufi(tidx, cond_strides); + // int self_bufi = tidx_to_bufi(tidx, self_strides); + // int other_bufi = tidx_to_bufi(tidx, other_strides); + if (out_bufi >= out_numl) { + return; + } + + const ivec4 out_tidx = bufi_to_tidx(out_bufi, out_strides, out_packed_dim); + out_bufi = tidx_to_bufi(out_tidx, out_strides); + + const ivec4 cond_tidx = bufi_to_tidx(out_bufi, out_strides, out_packed_dim); + const int cond_bufi = tidx_to_bufi(cond_tidx, cond_strides); + + const ivec4 self_tidx = bufi_to_tidx(out_bufi, out_strides, out_packed_dim); + const int self_bufi = tidx_to_bufi(self_tidx, self_strides); + + const ivec4 other_tidx = bufi_to_tidx(out_bufi, out_strides, out_packed_dim); + const int other_bufi = tidx_to_bufi(other_tidx, other_strides); + + COND_T cond = t_condition[cond_bufi] ; + T v_self = t_self[self_bufi]; + T v_other = t_other[other_bufi]; + + if (cond > 0) { + t_out[out_bufi] = v_self; + } else { + t_out[out_bufi] = v_other; + } +} + +#else // !USING_BUFFER + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + + if (any(greaterThanEqual(pos, out_limits))) { + return; + } + + vec4 cond = load_texel(t_condition, pos); + VEC4_T selftex = load_texel(t_self, pos); + VEC4_T othertex = load_texel(t_other, pos); + + VEC4_T outtex; + + for (int idx = 0; idx < 4; ++idx) { + if (cond[idx] == 1) { + outtex[idx] = selftex[idx]; + } else { + outtex[idx] = othertex[idx]; + } + } + write_texel(t_out, pos, outtex); +} + #endif // !USING_BUFFER diff --git a/backends/vulkan/runtime/graph/ops/glsl/where.yaml b/backends/vulkan/runtime/graph/ops/glsl/where.yaml new file mode 100644 index 00000000000..edbd843a336 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/where.yaml @@ -0,0 +1,12 @@ +where: + parameter_names_with_default_values: + DTYPE: float + generate_variant_forall: + STORAGE: + - VALUE: texture3d + - VALUE: buffer + DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: where diff --git a/backends/vulkan/runtime/graph/ops/impl/Where.cpp b/backends/vulkan/runtime/graph/ops/impl/Where.cpp new file mode 100644 index 00000000000..a3be34830d3 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Where.cpp @@ -0,0 +1,126 @@ +// Where.cpp + +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +namespace vkcompute { + +void resize_where_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& extra_args) { + (void)extra_args; + vTensorPtr out = graph->get_tensor(args[0].refs[0]); + vTensorPtr in = graph->get_tensor(args[1].refs[0]); + + std::vector in_sizes = in->sizes(); + out->virtual_resize(in_sizes); +} + +void add_where_texture_node( + ComputeGraph& graph, + const ValueRef cond, + const ValueRef self, + const ValueRef other, + const ValueRef out) { + std::string kernel_name = "where"; + + add_storage_type_suffix(kernel_name, graph.storage_type_of(out)); + add_dtype_suffix(kernel_name, graph.dtype_of(out)); + + const utils::uvec3 global_wg_size = graph.create_global_wg_size(out); + const utils::uvec3 local_wg_size = graph.create_local_wg_size(global_wg_size); + + graph.execute_nodes().emplace_back(new DispatchNode( + graph, + // Shader + VK_KERNEL_FROM_STR(kernel_name), + // Workgroup sizes + global_wg_size, + local_wg_size, + // Inputs and Outputs + {{out, vkapi::kWrite}, {{cond, self, other}, vkapi::kRead}}, + // Parameter buffers + {graph.logical_limits_ubo(self)}, + // Push Constants + {}, + // Specialization Constants + {graph.packed_dim_of(out)}, + // Resize Arguments + {}, + // Resizing Logic + resize_where_node)); +} + +void add_where_buffer_node( + ComputeGraph& graph, + const ValueRef cond, + const ValueRef self, + const ValueRef other, + const ValueRef out) { + std::string kernel_name = "where"; + + add_storage_type_suffix(kernel_name, graph.storage_type_of(out)); + add_dtype_suffix(kernel_name, graph.dtype_of(out)); + + const utils::uvec3 global_wg_size = graph.create_global_wg_size(out); + const utils::uvec3 local_wg_size = graph.create_local_wg_size(global_wg_size); + + vkapi::ParamsBindList ubos = { + graph.numel_ubo(out), + graph.strides_ubo(out), + graph.strides_ubo(cond), + graph.strides_ubo(self), + graph.strides_ubo(other)}; + + graph.execute_nodes().emplace_back(new DispatchNode( + graph, + // Shader + VK_KERNEL_FROM_STR(kernel_name), + // Workgroup sizes + global_wg_size, + local_wg_size, + // Inputs and Outputs + {{out, vkapi::kWrite}, {{cond, self, other}, vkapi::kRead}}, + // Parameter buffers + ubos, + // Push Constants + {}, + // Specialization Constants + {graph.packed_dim_of(out), + graph.packed_dim_of(cond), + graph.packed_dim_of(self), + graph.packed_dim_of(other)}, + // Resize Arguments + {}, + // Resizing Logic + resize_where_node)); +} + +void where(ComputeGraph& graph, const std::vector& args) { + int args_i = 0; + const ValueRef cond = args[args_i++]; + const ValueRef self = args[args_i++]; + const ValueRef other = args[args_i++]; + const ValueRef out = args[args_i++]; + if (graph.is_buffer_storage(out)) { + add_where_buffer_node(graph, cond, self, other, out); + } else { + add_where_texture_node(graph, cond, self, other, out); + } +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(aten.where.self, where); +} + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp b/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp index 469c2ed8280..e1ac4e9d40a 100644 --- a/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp +++ b/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp @@ -49,6 +49,7 @@ void add_dtype_suffix(std::string& kernel_name, const vkapi::ScalarType dtype) { break; case vkapi::kByte: case vkapi::kQUInt8: + case vkapi::kBool: kernel_name += "_uint8"; break; default: diff --git a/backends/vulkan/runtime/vk_api/Types.h b/backends/vulkan/runtime/vk_api/Types.h index 7191409c215..6531bf4710c 100644 --- a/backends/vulkan/runtime/vk_api/Types.h +++ b/backends/vulkan/runtime/vk_api/Types.h @@ -27,7 +27,7 @@ _(uint8_t, VK_FORMAT_R8G8B8A8_UINT, Byte) \ _(int8_t, VK_FORMAT_R8G8B8A8_SINT, Char) \ _(int32_t, VK_FORMAT_R32G32B32A32_SINT, Int) \ - _(bool, VK_FORMAT_R8G8B8A8_SINT, Bool) \ + _(uint8_t, VK_FORMAT_R8G8B8A8_UINT, Bool) \ _(uint16_t, VK_FORMAT_R16G16B16A16_SFLOAT, Half) \ _(float, VK_FORMAT_FLOAT4, Float) \ _(int8_t, VK_FORMAT_R8G8B8A8_SINT, QInt8) \ diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index 4a12f16bbf9..fc45187ea10 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -1349,3 +1349,28 @@ def get_flip_inputs(): test_suite = VkTestSuite([tuple(tc) for tc in test_cases]) return test_suite + + +@register_test_suite("aten.where.self") +def get_where_inputs(): + Test = namedtuple("Where", ["condition", "self", "other"]) + Test.__new__.__defaults__ = (None, None, None) + + test_cases = [ + Test(condition=[11], self=[11], other=[11]), + Test(condition=[10, 9], self=[10, 9], other=[10, 9]), + Test(condition=[10, 5, 3], self=[10, 5, 3], other=[10, 5, 3]), + Test(condition=[2, 10, 5, 3], self=[2, 10, 5, 3], other=[2, 10, 5, 3]), + ] + + test_suite = VkTestSuite([tuple(tc) for tc in test_cases]) + test_suite.arg_dtype["condition"] = "at::kBool" + test_suite.layouts = [ + "utils::kWidthPacked", + "utils::kHeightPacked", + "utils::kChannelsPacked", + ] + test_suite.storage_types = ["utils::kTexture3D", "utils::kBuffer"] + test_suite.atol = "1e-4" + test_suite.rtol = "1e-4" + return test_suite diff --git a/backends/vulkan/test/op_tests/utils/gen_correctness_base.py b/backends/vulkan/test/op_tests/utils/gen_correctness_base.py index e6ce135736b..5be4ddba6bf 100644 --- a/backends/vulkan/test/op_tests/utils/gen_correctness_base.py +++ b/backends/vulkan/test/op_tests/utils/gen_correctness_base.py @@ -282,12 +282,16 @@ def generate_suite_cpp(self) -> str: at::ScalarType dtype = at::kFloat, float low = 0.0, float high = 1.0) {{ - if (high == 1.0 && low == 0.0) - return at::rand(sizes, at::device(at::kCPU).dtype(dtype)); if (dtype == at::kChar) return at::randint(high, sizes, at::device(at::kCPU).dtype(dtype)); + if (dtype == at::kBool) + return at::rand(sizes, at::device(at::kCPU)) > 0.5; + + if (high == 1.0 && low == 0.0) + return at::rand(sizes, at::device(at::kCPU).dtype(dtype)); + return at::rand(sizes, at::device(at::kCPU).dtype(dtype)) * (high - low) + low; }} diff --git a/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py b/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py index 6c165a777db..ce6ab32ce60 100644 --- a/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py +++ b/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py @@ -119,6 +119,8 @@ def gen_parameterization(self) -> str: return vkapi::kInt; case c10::kChar: return vkapi::kChar; + case c10::kBool: + return vkapi::kBool; default: VK_THROW("Unsupported at::ScalarType!"); }