From 83715d2e2f47d321f081d9e765e99dfb24a081d5 Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Fri, 8 Aug 2025 12:47:44 -0700 Subject: [PATCH 1/2] [ET-VK] Allow `aten.cat.default` to handle any number of input tensors Pull Request resolved: https://github.com/pytorch/executorch/pull/13226 ## Context Previously, I updated the implementation of `aten.cat.default` in D76305343 (https://github.com/pytorch/executorch/pull/11508) since the original implementation had a bug. The new implementation only supported up to 3 input tensors, but several models require the need for up to 6 input tensors. This diff updates the capabilities of the `concat` op so that any arbitrary number of input tensors may be accepted. ## Changes * Update implementation of the concat shader to be able to be called repeatedly, allowing support for any number of input tensors. Differential Revision: [D79893084](https://our.internmc.facebook.com/intern/diff/D79893084/) ghstack-source-id: 301766080 --- backends/vulkan/op_registry.py | 8 - .../vulkan/runtime/api/containers/Tensor.cpp | 9 +- .../vulkan/runtime/graph/ops/ExecuteNode.h | 2 +- .../runtime/graph/ops/glsl/concat_buffer.glsl | 61 +++- .../graph/ops/glsl/concat_texture.glsl | 193 ++++++---- .../runtime/graph/ops/glsl/concat_utils.glslh | 33 ++ .../runtime/graph/ops/glsl/indexing_utils.h | 6 + .../runtime/graph/ops/glsl/set_zero.glsl | 33 ++ .../runtime/graph/ops/glsl/set_zero.yaml | 8 + .../graph/ops/glsl/update_concat_offset.glsl | 42 +++ .../graph/ops/glsl/update_concat_offset.yaml | 13 + .../vulkan/runtime/graph/ops/impl/Concat.cpp | 338 ++++++++++++++---- backends/vulkan/test/op_tests/cases.py | 135 +++---- backends/vulkan/test/test_vulkan_delegate.py | 144 ++++++++ 14 files changed, 790 insertions(+), 235 deletions(-) create mode 100644 backends/vulkan/runtime/graph/ops/glsl/concat_utils.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/set_zero.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/set_zero.yaml create mode 100644 backends/vulkan/runtime/graph/ops/glsl/update_concat_offset.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/update_concat_offset.yaml diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 22a93ec0e2b..e3498cf1792 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -491,17 +491,9 @@ def register_view_ops(): # for both texture and buffer storage types. @update_features(exir_ops.edge.aten.cat.default) def register_cat_op(): - def check_cat_node(node: torch.fx.Node) -> bool: - inputs = node.args[0] - if isinstance(inputs, (list, tuple)) and len(inputs) <= 3: - return True - - return False - return OpFeatures( inputs_storage=utils.ANY_STORAGE, supports_resize=True, - are_node_inputs_supported_fn=check_cat_node, ) diff --git a/backends/vulkan/runtime/api/containers/Tensor.cpp b/backends/vulkan/runtime/api/containers/Tensor.cpp index 64f330de59c..a3d9bd4aa34 100644 --- a/backends/vulkan/runtime/api/containers/Tensor.cpp +++ b/backends/vulkan/runtime/api/containers/Tensor.cpp @@ -517,6 +517,7 @@ void vTensorStorage::transition( vkapi::MemoryAccessFlags prev_access = last_access_.access; const bool prev_written = (prev_access & vkapi::MemoryAccessType::WRITE) != 0; + const bool cur_written = (cur_access & vkapi::MemoryAccessType::WRITE) != 0; VkImageLayout cur_layout = VK_IMAGE_LAYOUT_UNDEFINED; VkImageLayout new_layout = VK_IMAGE_LAYOUT_UNDEFINED; @@ -528,7 +529,13 @@ void vTensorStorage::transition( layout_changed = cur_layout != new_layout; } - if (prev_written || layout_changed) { + // RAW: need to make sure current read sees previous writes + // WAW: need to make sure the current write occurs after previous write so + // the final value is correct. + // WAR: need to make sure previous read does not read the value from the + // current write. + // RAR: no need for synchronization + if (prev_written || cur_written || layout_changed) { VkPipelineStageFlags src_stage = vkapi::vk_stage(prev_stage); if (0u == src_stage) { src_stage = VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT; diff --git a/backends/vulkan/runtime/graph/ops/ExecuteNode.h b/backends/vulkan/runtime/graph/ops/ExecuteNode.h index 6a815b246ef..4ea1ba57796 100644 --- a/backends/vulkan/runtime/graph/ops/ExecuteNode.h +++ b/backends/vulkan/runtime/graph/ops/ExecuteNode.h @@ -43,7 +43,7 @@ class ExecuteNode { friend class ComputeGraph; public: - using ResizeFunction = const std::function&, const std::vector&)>; diff --git a/backends/vulkan/runtime/graph/ops/glsl/concat_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/concat_buffer.glsl index 895cecb413a..e34ecaf8309 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/concat_buffer.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/concat_buffer.glsl @@ -20,10 +20,12 @@ layout(std430) buffer; #include "indexing_utils.h" -${layout_declare_tensor(B, "w", "t_out", DTYPE, "buffer")} +${layout_declare_tensor(B, "rw", "t_out", DTYPE, "buffer")} $for i in range(NUM_INPUTS): - ${layout_declare_tensor(B, "r", "t_in" + str(i + 1), DTYPE, "buffer")} + ${layout_declare_tensor(B, "r", "t_inp" + str(i), DTYPE, "buffer")} + +${layout_declare_tensor(B, "r", "t_concat_offset", "int", "buffer")} ${layout_declare_ubo(B, "int", "concat_dim")} @@ -31,8 +33,8 @@ ${layout_declare_ubo(B, "ivec4", "out_sizes")} ${layout_declare_ubo(B, "ivec4", "out_strides")} $for i in range(NUM_INPUTS): - ${layout_declare_ubo(B, "ivec4", "in" + str(i+1) + "_sizes")} - ${layout_declare_ubo(B, "ivec4", "in" + str(i+1) + "_strides")} + ${layout_declare_ubo(B, "ivec4", "inp" + str(i) + "_sizes")} + ${layout_declare_ubo(B, "ivec4", "inp" + str(i) + "_strides")} ${layout_declare_ubo(B, "int", "out_numel")} @@ -42,28 +44,53 @@ const lowp ivec4 out_dim_order = unhash_dim_order(out_layout); layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +#define NUM_INPUTS ${NUM_INPUTS} + +#include "concat_utils.glslh" + +/* + * This shader template concatenates up to NUM_INPUT input tensors to the + * output tensor along the concat_dim. Elements from the input tensor will + * be inserted along the output's concat_dim starting at concat_offset. + */ void main() { - const int out_bufi = ivec3(gl_GlobalInvocationID).x; - if (out_bufi >= out_numel) { + const int tid = ivec3(gl_GlobalInvocationID).x; + + // The 1-3 input tensors are interpreted as one concatenated tensor ("volume") + // along the concat_dim for the purposes of tensor indexing. Each thread is + // responsible for reading one item from this volume and writing it to the + // appropriate output location. + ivec4 inp_volume_sizes = out_sizes; + inp_volume_sizes[concat_dim] = total_concat_dim_numel(); + + // Account for 0 size input tensors + if (any(lessThanEqual(inp_volume_sizes, ivec4(0)))) { + return; + } + + ivec4 inp_volume_tidx = nchwi_to_tidx(tid, inp_volume_sizes); + + // bounds check + if (any(greaterThanEqual(inp_volume_tidx, inp_volume_sizes))) { return; } - // Convert buffer linear index to 4-D tensor index for output - const ivec4 out_tidx = bufi_to_tidx(out_bufi, out_strides, out_dim_order); + int concat_offset = t_concat_offset[0]; + + ivec4 out_tidx = inp_volume_tidx; + out_tidx[concat_dim] += concat_offset; - // Determine which input tensor to read from - ivec4 in_tidx = out_tidx; + const uint out_bufi = tidx_to_bufi(out_tidx, out_strides); + // Go through the list of input tensors, and find which input this output + // element should be read from. $for i in range(NUM_INPUTS): - // Check if the index at the concat dim is within bounds of the input tensor - // If so, read from that input tensor and write to output - if (in_tidx[concat_dim] < in${i+1}_sizes[concat_dim]) { - int in_bufi = tidx_to_bufi(in_tidx, in${i+1}_strides); - t_out[out_bufi] = t_in${i+1}[in_bufi]; + if (inp_volume_tidx[concat_dim] < inp${i}_sizes[concat_dim]) { + int inp_bufi = tidx_to_bufi(inp_volume_tidx, inp${i}_strides); + t_out[out_bufi] = t_inp${i}[inp_bufi]; return; } - // otherwise, decrement the index at the concat dim else { - in_tidx[concat_dim] -= in${i+1}_sizes[concat_dim]; + inp_volume_tidx[concat_dim] -= inp${i}_sizes[concat_dim]; } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/concat_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/concat_texture.glsl index dac6266bf67..afab0c524d6 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/concat_texture.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/concat_texture.glsl @@ -19,16 +19,18 @@ layout(std430) buffer; #include "indexing_utils.h" -${layout_declare_tensor(B, "w", "t_out", DTYPE, "texture3d")} +${layout_declare_tensor(B, "rw", "t_out", DTYPE, "texture3d")} $for i in range(NUM_INPUTS): - ${layout_declare_tensor(B, "r", "t_in" + str(i + 1), DTYPE, "texture3d")} + ${layout_declare_tensor(B, "r", "t_inp" + str(i), DTYPE, "texture3d")} + +${layout_declare_tensor(B, "r", "t_concat_offset", "int", "buffer")} ${layout_declare_ubo(B, "int", "concat_dim")} $in_metadata = "" $for i in range(NUM_INPUTS): - $in_metadata += "ivec4 in" + str(i + 1) + "_sizes;\n" + $in_metadata += "ivec4 inp" + str(i) + "_sizes;\n" layout(push_constant) uniform restrict Block { ivec4 out_sizes; @@ -40,90 +42,135 @@ const lowp ivec4 out_axis_map = unhash_axis_map(out_layout); const lowp int out_packed_dim = unhash_packed_dim(out_layout); $for i in range(NUM_INPUTS): - ${layout_declare_spec_const(C, "int", "in" + str(i+1) + "_layout", "DEFAULT_LAYOUT")} - const lowp ivec4 in${i+1}_axis_map = unhash_axis_map(in${i+1}_layout); - const lowp int in${i+1}_packed_dim = unhash_packed_dim(in${i+1}_layout); + ${layout_declare_spec_const(C, "int", "inp" + str(i) + "_layout", "DEFAULT_LAYOUT")} + const lowp ivec4 inp${i}_axis_map = unhash_axis_map(inp${i}_layout); + const lowp int inp${i}_packed_dim = unhash_packed_dim(inp${i}_layout); layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; -// Check if we can use the fast path (no texel merging required) -bool can_use_fast_path() { - // Fast path is possible when: - // 1. The concat dimension is not the packed dimension, or - // 2. The concat dimension is the packed dimension but both input tensors have dimensions - // that are multiples of 4 along the packed dimension - if (concat_dim != out_packed_dim) { - return true; - } - - // Check if all input tensors have dimensions that are multiples of 4 along the packed dimension - bool all_concat_dim_size_multiple_of_4 = true; - $for i in range(NUM_INPUTS): - all_concat_dim_size_multiple_of_4 = - all_concat_dim_size_multiple_of_4 && - (in${i+1}_sizes[concat_dim] % 4 == 0); +#define NUM_INPUTS ${NUM_INPUTS} - return all_concat_dim_size_multiple_of_4; -} +#include "concat_utils.glslh" +/* + * This shader template concatenates up to NUM_INPUT input tensors to the + * output tensor along the concat_dim. Elements from the input tensor will + * be inserted along the output's concat_dim starting at concat_offset. + * + * Each thread is responsible for writing out one output texel. The data + * required for the output texel may be read from multiple input texels of one + * input tensor. + */ void main() { - const ivec3 lpos = ivec3(gl_GlobalInvocationID); - ivec4 out_tidx = lpos_to_tidx(lpos, out_sizes, out_axis_map.w, out_packed_dim); - - if (any(greaterThanEqual(out_tidx, out_sizes))) { + const int tid = ivec3(gl_GlobalInvocationID).x; + + // Sum of the sizes of all input tensors along the concat_dim + const int concat_numel = total_concat_dim_numel(); + + // The 1-3 input tensors are interpreted as one concatenated tensor ("volume") + // along the concat_dim for the purposes of tensor indexing. Each thread is + // responsible for writing out 4 elements along the packed dim of the output + // tensor by reading the source data from the input tensor(s). + ivec4 inp_volume_sizes = out_sizes; + inp_volume_sizes[concat_dim] = total_concat_dim_numel(); + + // Reconstruct inp_volume_texel_sizes from Concat.cpp + ivec4 inp_volume_texel_sizes = inp_volume_sizes; + inp_volume_texel_sizes[out_packed_dim] = DIV_UP_4( + inp_volume_texel_sizes[out_packed_dim] + ) + 1; + + // tensor index of the first element that will be read from the input volume + ivec4 inp_volume_start_tidx = nchwi_to_tidx(tid, inp_volume_texel_sizes); + inp_volume_start_tidx[out_packed_dim] = MUL_4( + inp_volume_start_tidx[out_packed_dim] + ); + + int concat_offset = t_concat_offset[0]; + + // tensor index of the first element that will be written to the output tensor + ivec4 out_write_start_tidx = inp_volume_start_tidx; + out_write_start_tidx[concat_dim] += concat_offset; + + // To write to the the desired output element, we will need to load the texel + // to which the element belongs. Calculate the tensor index of the first + // element of that texel. + ivec4 out_read_start_tidx = out_write_start_tidx; + out_read_start_tidx[out_packed_dim] = ALIGN_DOWN_4( + out_write_start_tidx[out_packed_dim]); + + // bounds check + if (any(greaterThanEqual(out_read_start_tidx, out_sizes))) { return; } - if (can_use_fast_path()) { - // Fast path: No texel merging required - ivec4 in_tidx = out_tidx; + ivec3 out_pos = tidx_to_pos( + out_read_start_tidx, + out_sizes, + out_axis_map, + out_packed_dim + ); - $for i in range(NUM_INPUTS): - // For each input tensor, check if the tensor index is within bounds. If - // so, read the texel from the input tensor and write it to the output - if (in_tidx[concat_dim] < in${i+1}_sizes[concat_dim]) { - const ivec3 in_pos = tidx_to_pos(in_tidx, in${i+1}_sizes, in${i+1}_axis_map, in${i+1}_packed_dim); - const VEC4_T in_texel = load_texel(t_in${i+1}, in_pos); - write_texel_lpos(t_out, lpos, in_texel, out_axis_map); - return; - } - // Otherwise, adjust the index along the concat dimension and try the next - // input tensor. - else { - in_tidx[concat_dim] -= in${i+1}_sizes[concat_dim]; - } - } - else { - // Slow path: Texel merging required - VEC4_T out_texel = VEC4_T(0); + VEC4_T out_texel = imageLoad(t_out, out_pos); - // Process each element in the output texel individually - for (int texel_i = 0; texel_i < 4; ++texel_i) { - ivec4 curr_out_tidx = out_tidx; - curr_out_tidx[out_packed_dim] += texel_i; + VEC4_T test_texel = VEC4_T(-1.0); - // Skip if we're out of bounds - if (curr_out_tidx[out_packed_dim] >= out_sizes[out_packed_dim]) { - continue; - } + for (int comp = 0; comp < 4; ++comp) { + ivec4 out_tidx = out_read_start_tidx; + out_tidx[out_packed_dim] += comp; - ivec4 in_tidx = curr_out_tidx; - $for i in range(NUM_INPUTS): - // For each input tensor, check if the tensor index is within bounds. If - // so, read the corresponding texel element from the input tensor and - // write it to the output texel. - if (in_tidx[concat_dim] < in${i+1}_sizes[concat_dim]) { - const ivec4 in_posi = tidx_to_posi(in_tidx, in${i+1}_sizes, in${i+1}_axis_map, in${i+1}_packed_dim); - out_texel[texel_i] = load_texel(t_in${i+1}, in_posi.xyz)[in_posi.w]; - continue; - } - // Otherwise, adjust the index along the concat dimension and try the - // next input tensor. - else { - in_tidx[concat_dim] -= in${i+1}_sizes[concat_dim]; - } + + // It's possible that the current texel element has been written to as part + // of the previous input batch; if so, then don't overwrite this texel + // element + if (out_tidx[concat_dim] < concat_offset) { + test_texel[comp] = -5.0; + continue; } - write_texel_lpos(t_out, lpos, out_texel, out_axis_map); + // Calculate the tidx of the input volume that corresponds to this output + // element + ivec4 inp_volume_tidx = out_tidx; + inp_volume_tidx[concat_dim] -= concat_offset; + + // go through the list of input tensors, and figure out which input this + // output element should be read from. + $for i in range(NUM_INPUTS): + if (inp_volume_tidx[concat_dim] < inp${i}_sizes[concat_dim]) { + // Special fast path case if, for the first output texel element, the + // corresponding input element is at the start of the texel it belongs + // to. In this case, the input texel can be written as-is to the output + // texel. Also require that The entire input texel is valid and does not + // contain any padding elements. + if (comp == 0 && + out_tidx[out_packed_dim] % 4 == 0 && + inp_volume_tidx[inp${i}_packed_dim] % 4 == 0 && + inp_volume_tidx[inp${i}_packed_dim] + 3 < inp${i}_sizes[inp${i}_packed_dim]) { + const ivec3 in_pos = tidx_to_pos( + inp_volume_tidx, + inp${i}_sizes, + inp${i}_axis_map, + inp${i}_packed_dim); + + out_texel = texelFetch(t_inp${i}, in_pos, 0); + break; + } + + // Otherwise, locate the specific input element required + const ivec4 in_posi = tidx_to_posi( + inp_volume_tidx, + inp${i}_sizes, + inp${i}_axis_map, + inp${i}_packed_dim); + + out_texel[comp] = texelFetch(t_inp${i}, in_posi.xyz, 0)[in_posi.w]; + test_texel[comp] = out_texel[comp]; + continue; + } + else { + inp_volume_tidx[concat_dim] -= inp${i}_sizes[concat_dim]; + } } + + imageStore(t_out, out_pos, out_texel); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/concat_utils.glslh b/backends/vulkan/runtime/graph/ops/glsl/concat_utils.glslh new file mode 100644 index 00000000000..000b86a7fce --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/concat_utils.glslh @@ -0,0 +1,33 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef CONCAT_UTILS_H +#define CONCAT_UTILS_H + + +/********************************** + * Concatenation utililty functions + * + */ + +/* + * Returns the total number of elements along the concatenation dim that will + * be concatenated in this input batch. + */ +$for N in range(1, 4): + #if NUM_INPUTS == ${N} + int total_concat_dim_numel() { + int total = 0; + $for i in range(N): + total += inp${i}_sizes[concat_dim]; + + return total; + } + #endif + +#endif // CONCAT_UTILS_H diff --git a/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h b/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h index 72650bb7040..fdb6f514a3e 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h +++ b/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h @@ -68,6 +68,8 @@ */ #define mod4(x) ((x) & 3) +#define ALIGN_DOWN_4(x) ((x) & ~3) + #define ALIGN_UP_4(x) (((x) + 3) & ~3) #define DIV_UP_8(x) (((x) + 7) >> 3) @@ -110,6 +112,10 @@ ivec4 tidx_to_4bufi( return base_i + ivec4(0, 1, 2, 3) * strides[packed_dim]; } +/* + * Given a buffer index to a contiguous tensor and the tensor's sizes, return + * the tensor index that corresponds to the buffer index. + */ ivec4 nchwi_to_tidx(const int nchwi, const ivec4 sizes) { const int nchwi_div_x = nchwi / sizes.x; const int nchwi_div_y = nchwi_div_x / sizes.y; diff --git a/backends/vulkan/runtime/graph/ops/glsl/set_zero.glsl b/backends/vulkan/runtime/graph/ops/glsl/set_zero.glsl new file mode 100644 index 00000000000..d01780b9e30 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/set_zero.glsl @@ -0,0 +1,33 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define T ${buffer_scalar_type(DTYPE)} + +${define_active_storage_type("buffer")} +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_out", DTYPE, "buffer")} + +${layout_declare_ubo(B, "int", "out_numel")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const int out_bufi = ivec3(gl_GlobalInvocationID).x; + if (out_bufi >= out_numel) { + return; + } + + t_out[out_bufi] = T(0); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/set_zero.yaml b/backends/vulkan/runtime/graph/ops/glsl/set_zero.yaml new file mode 100644 index 00000000000..cee87c468b1 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/set_zero.yaml @@ -0,0 +1,8 @@ +set_zero: + parameter_names_with_default_values: + DTYPE: float + generate_variant_forall: + DTYPE: + - VALUE: int32 + shader_variants: + - NAME: set_zero diff --git a/backends/vulkan/runtime/graph/ops/glsl/update_concat_offset.glsl b/backends/vulkan/runtime/graph/ops/glsl/update_concat_offset.glsl new file mode 100644 index 00000000000..ba02da1c301 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/update_concat_offset.glsl @@ -0,0 +1,42 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define T ${buffer_scalar_type(DTYPE)} + +${define_active_storage_type("buffer")} +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "concat_offset", DTYPE, "buffer")} + +${layout_declare_ubo(B, "int", "concat_dim")} + +$for i in range(NUM_INPUTS): + ${layout_declare_ubo(B, "ivec4", "in" + str(i+1) + "_sizes")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + // Only one thread needs to update the offset + if (gl_GlobalInvocationID.x != 0) { + return; + } + + // Sum up the sizes along the concat dimension for all input tensors + int total_size = 0; + $for i in range(NUM_INPUTS): + total_size += in${i+1}_sizes[concat_dim]; + + // Add to the current offset + concat_offset[0] += T(total_size); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/update_concat_offset.yaml b/backends/vulkan/runtime/graph/ops/glsl/update_concat_offset.yaml new file mode 100644 index 00000000000..35e8740e0a3 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/update_concat_offset.yaml @@ -0,0 +1,13 @@ +update_concat_offset: + parameter_names_with_default_values: + DTYPE: float + NUM_INPUTS: 2 + generate_variant_forall: + DTYPE: + - VALUE: int32 + shader_variants: + - NAME: update_concat_offset_1 + NUM_INPUTS: 1 + - NAME: update_concat_offset_2 + - NAME: update_concat_offset_3 + NUM_INPUTS: 3 diff --git a/backends/vulkan/runtime/graph/ops/impl/Concat.cpp b/backends/vulkan/runtime/graph/ops/impl/Concat.cpp index 315dabdb1d5..0a4acb6cef3 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Concat.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Concat.cpp @@ -19,15 +19,16 @@ namespace vkcompute { std::vector get_concat_sizes( ComputeGraph& graph, - const std::vector& in_value_refs, - const int64_t dim) { + ValueRef all_input_refs, + const int64_t concat_dim) { + ValueListPtr in_value_refs = graph.get_value_list(all_input_refs); // Get the sizes of the first input tensor as a starting point - std::vector new_out_sizes = graph.sizes_of(in_value_refs.at(0)); + std::vector new_out_sizes = graph.sizes_of(in_value_refs->at(0)); // Sum up the sizes along the concatenation dimension - for (size_t i = 1; i < in_value_refs.size(); ++i) { - const std::vector in_sizes = graph.sizes_of(in_value_refs.at(i)); - new_out_sizes.at(dim) += in_sizes.at(dim); + for (size_t i = 1; i < in_value_refs->size(); ++i) { + const std::vector in_sizes = graph.sizes_of(in_value_refs->at(i)); + new_out_sizes.at(concat_dim) += in_sizes.at(concat_dim); } return new_out_sizes; @@ -37,24 +38,122 @@ void resize_concat_node( ComputeGraph* graph, const std::vector& args, const std::vector& extra_args) { - // Extract relevant ValueRefs - const ValueRef out_ref = args.at(0).refs.at(0); - const std::vector& in_value_refs = args.at(1).refs; + const ValueRef out = args.at(0).refs.at(0); + const ValueRef all_inputs = extra_args.at(0); - int64_t dim = graph->extract_scalar(extra_args.at(0)); + int64_t concat_dim = graph->extract_scalar(extra_args.at(1)); - // Normalize dim if negative - const int64_t ndim = graph->dim_of(out_ref); - if (dim < 0) { - dim += ndim; + // Normalize concat_dim if negative + const int64_t ndim = graph->dim_of(out); + if (concat_dim < 0) { + concat_dim += ndim; } // Calculate the new sizes std::vector new_out_sizes = - get_concat_sizes(*graph, in_value_refs, dim); + get_concat_sizes(*graph, all_inputs, concat_dim); // Resize the output tensor - graph->virtual_resize(out_ref, new_out_sizes); + graph->virtual_resize(out, new_out_sizes); +} + +utils::uvec3 concat_pick_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& extra_args) { + (void)shader; + (void)extra_args; + + const ValueRef out = args.at(0).refs.at(0); + const std::vector inputs_in_batch = args.at(1).refs; + + int64_t concat_dim = graph->extract_scalar(extra_args.at(1)); + + // Normalize concat_dim if negative + const int64_t ndim = graph->dim_of(out); + if (concat_dim < 0) { + concat_dim += ndim; + } + + // The concat shader concatenates N input tensors at a time to the output + // tensor. Since the shader may need to be invoked multiple times to finish + // concatenation when the number of input tensors is >N, the global workgroup + // is based on the volume of input data being concatenated in this batch, + // as opposed to the overall size of the output tensor. Conceptually, the + // global work group size represents which elements of the output tensor will + // be written to during this dispatch. + + uint32_t total_input_numel = 0; + int64_t concat_dim_numel = 0; + for (const ValueRef input : inputs_in_batch) { + total_input_numel += graph->numel_of(input); + concat_dim_numel += graph->size_at(concat_dim, input); + } + + if (graph->is_buffer_storage(out)) { + return {total_input_numel, 1, 1}; + } + + // The texture implementation is similar, except each invocation writes out 4 + // output elements along the packed dim (i.e. one texel). In this case, the + // global work group size represents the number of output texels that will be + // written to in this batch, rather than the number of output elements. Note + // that to update an element of the output, the entire texel that contains it + // will need to be loaded, updated, then written back. + + std::vector inp_volume_sizes = graph->sizes_of(out); + inp_volume_sizes.at(concat_dim) = concat_dim_numel; + + // Calculate what the image extents would be of a tensor with the input + // volume's sizes. This produces the number of texels that would need to be + // written to. + const int32_t packed_dim = graph->packed_dim_of(out); + std::vector inp_volume_texel_sizes = + api::calculate_padded_sizes(inp_volume_sizes, packed_dim); + // If the concat_dim is the same as the packed dim, and the concat_offset for + // this input batch is not a multiple of 4, then the data from an input texel + // may be split up between two output texels. For example: + // I0 , I1 , I2 , I2 + // O0 , O1 , O2 , X | X , X , X , X + // Therefore, 1 texel is added to the packed dim to account for this. + inp_volume_texel_sizes.at(3 - packed_dim) = + utils::div_up_4(inp_volume_texel_sizes.at(3 - packed_dim)) + 1; + + const uint32_t inp_volume_texel_numel = + utils::multiply_integers(inp_volume_texel_sizes); + + return {inp_volume_texel_numel, 1, 1}; + + // The texture implementation is similar, expect each thread is responsible + // for writing out an entire output texel. Therefore, the overall global work + // group size will be the concatenation of the texture extents of the input + // tensors in this batch. + + // One complication is when the previous concatenation batch does not write + // up to a texel boundary. An example is if the previous concatenation batch + // only wrote 7 elements along the concatenation dim. The first input element + // would then have to be inserted at the last element of the final texel + // written by the previous batch. To account for this, initialize the + // workgroup size at the concatenation dim to 1 (need to read N total texels + // along the concat dim for input tensors + up to 1 texel from the output + // tensor). + + // The axis along which to concatenate the input texture extents + int64_t extent_concat_axis = nchw_dim_to_whcn_dim(concat_dim, ndim); + // For batch concatenation, the concat axis is the batch-concatenation axis + if (concat_dim == 4) { + extent_concat_axis = graph->concat_dim_of(out); + } + + utils::uvec3 global_workgroup_size = graph->create_global_wg_size(out); + global_workgroup_size[concat_dim] = 0; + for (const ValueRef input : inputs_in_batch) { + utils::uvec3 texture_extents = graph->logical_limits_of(input); + global_workgroup_size[extent_concat_axis] += texture_extents[concat_dim]; + } + + return global_workgroup_size; } void add_concat_node( @@ -67,10 +166,6 @@ void add_concat_node( { const ValueListPtr tensors = graph.get_value_list(tensors_ref); - VK_CHECK_COND( - tensors->size() <= 3, - "Currently only concatenation of <= 3 tensors is supported"); - for (const ValueRef in : *tensors) { in_value_refs.push_back(in); } @@ -87,68 +182,161 @@ void add_concat_node( const int64_t dim_whcn = nchw_dim_to_whcn_dim(normalized_dim, ndim); const ValueRef dim_whcn_ref = graph.get_or_add_value_for_int(dim_whcn); - vkapi::ParamsBindList param_buffers = { - graph.get_or_create_int_param_buffer(dim_whcn_ref, 0)}; + // Create a temporary tensor to hold the concat offset + TmpTensor concat_offset( + &graph, {1}, vkapi::kInt, utils::kBuffer, utils::kWidthPacked); - std::vector push_constants; - vkapi::SpecVarList spec_vars; - - if (graph.is_buffer_storage(out)) { - param_buffers.append(graph.sizes_ubo(out)); - param_buffers.append(graph.strides_ubo(out)); + // Add node to set concat_offset to 0 + { + std::string kernel_name = "set_zero"; + add_dtype_suffix(kernel_name, graph.dtype_of(concat_offset)); + + vkapi::ParamsBindList param_buffers = {graph.numel_ubo(concat_offset)}; + + graph.execute_nodes().emplace_back(new DispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + {1, 1, 1}, + {1, 1, 1}, + // Inputs and Outputs + {{concat_offset, vkapi::kWrite}}, + // Parameter buffers + param_buffers, + // Push Constants + {}, + // Specialization Constants + {}, + // Resize Args + {}, + // Resizing Logic + nullptr)); + } - for (const ValueRef in_ref : in_value_refs) { - param_buffers.append(graph.sizes_ubo(in_ref)); - param_buffers.append(graph.strides_ubo(in_ref)); + // Process inputs in batches of up to 3 tensors + const size_t batch_size = 3; + for (size_t batch_start = 0; batch_start < in_value_refs.size(); + batch_start += batch_size) { + const size_t batch_end = + std::min(batch_start + batch_size, in_value_refs.size()); + const size_t current_batch_size = batch_end - batch_start; + + std::vector batch_inputs; + for (size_t i = batch_start; i < batch_end; ++i) { + batch_inputs.push_back(in_value_refs.at(i)); } - param_buffers.append(graph.numel_ubo(out)); - - spec_vars = {graph.hashed_layout_of(out)}; - } else { - push_constants = {graph.sizes_pc_of(out)}; - - spec_vars = {graph.hashed_layout_of(out)}; - - for (const ValueRef in_ref : in_value_refs) { - push_constants.push_back(graph.sizes_pc_of(in_ref)); - spec_vars.append(graph.hashed_layout_of(in_ref)); + // Add concat node for this batch + { + vkapi::ParamsBindList param_buffers = { + graph.get_or_create_int_param_buffer(dim_whcn_ref, 0)}; + + std::vector push_constants; + vkapi::SpecVarList spec_vars; + + if (graph.is_buffer_storage(out)) { + param_buffers.append(graph.sizes_ubo(out)); + param_buffers.append(graph.strides_ubo(out)); + + for (const ValueRef in_ref : batch_inputs) { + param_buffers.append(graph.sizes_ubo(in_ref)); + param_buffers.append(graph.strides_ubo(in_ref)); + } + + param_buffers.append(graph.numel_ubo(out)); + + spec_vars = {graph.hashed_layout_of(out)}; + } else { + push_constants = {graph.sizes_pc_of(out)}; + + spec_vars = {graph.hashed_layout_of(out)}; + + for (const ValueRef in_ref : batch_inputs) { + push_constants.push_back(graph.sizes_pc_of(in_ref)); + spec_vars.append(graph.hashed_layout_of(in_ref)); + } + } + + std::string kernel_name = "concat"; + if (current_batch_size == 1) { + kernel_name += "_1"; + } else if (current_batch_size == 2) { + kernel_name += "_2"; + } else if (current_batch_size == 3) { + kernel_name += "_3"; + } + if (graph.is_buffer_storage(out)) { + kernel_name += "_buffer"; + } else { + kernel_name += "_texture3d"; + } + + add_dtype_suffix(kernel_name, graph.dtype_of(out)); + + DispatchNode::ResizeFunction resize_fn = nullptr; + if (batch_start == 0) { + resize_fn = resize_concat_node; + } + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + concat_pick_global_wg_size, + default_pick_local_wg_size, + // Inputs and Outputs + {{out, vkapi::kReadWrite}, + {batch_inputs, vkapi::kRead}, + {concat_offset, vkapi::kRead}}, + // Parameter buffers + param_buffers, + // Push Constants + push_constants, + // Specialization Constants + spec_vars, + // Resize Args + {tensors_ref, dim_ref}, + // Resizing Logic + resize_fn)); } - } - std::string kernel_name = "concat"; - if (in_value_refs.size() == 1) { - kernel_name += "_1"; - } else if (in_value_refs.size() == 2) { - kernel_name += "_2"; - } else if (in_value_refs.size() == 3) { - kernel_name += "_3"; - } - if (graph.is_buffer_storage(out)) { - kernel_name += "_buffer"; - } else { - kernel_name += "_texture3d"; + // Add node to update concat_offset (except for the last batch) + if (batch_end < in_value_refs.size()) { + vkapi::ParamsBindList param_buffers = { + graph.get_or_create_int_param_buffer(dim_whcn_ref, 0)}; + + for (const ValueRef in_ref : batch_inputs) { + param_buffers.append(graph.sizes_ubo(in_ref)); + } + + std::string kernel_name = "update_concat_offset"; + if (current_batch_size == 1) { + kernel_name += "_1"; + } else if (current_batch_size == 2) { + kernel_name += "_2"; + } else if (current_batch_size == 3) { + kernel_name += "_3"; + } + add_dtype_suffix(kernel_name, graph.dtype_of(concat_offset)); + + vkapi::SpecVarList spec_vars = {}; + + graph.execute_nodes().emplace_back(new DispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + {1u, 1u, 1u}, + {1u, 1u, 1u}, + // Inputs and Outputs + {{concat_offset, vkapi::kWrite}}, + // Parameter buffers + param_buffers, + // Push Constants + {}, + // Specialization Constants + spec_vars, + // Resize Args + {}, + // Resizing Logic + nullptr)); + } } - - add_dtype_suffix(kernel_name, graph.dtype_of(out)); - - graph.execute_nodes().emplace_back(new DynamicDispatchNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - default_pick_global_wg_size, - default_pick_local_wg_size, - // Inputs and Outputs - {{out, vkapi::kWrite}, {in_value_refs, vkapi::kRead}}, - // Parameter buffers - param_buffers, - // Push Constants - push_constants, - // Specialization Constants - spec_vars, - // Resize Args - {dim_ref}, - // Resizing Logic - resize_concat_node)); } void cat_tensor(ComputeGraph& graph, const std::vector& args) { diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index 5efcfc1ffb2..ff35188be3e 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -1233,66 +1233,81 @@ def get_repeat_interleave_inputs(): @register_test_suite("aten.cat.default") def get_cat_inputs(): # TensorList must be specified as list of tuples - test_suite = VkTestSuite( - [ - # Cat on Height - ([(M, M, 3, 5), (M, M, 0, 5)], 2), - ([(S1, S1, 3, 5), (S1, S1, 0, 5)], 2), - ([(M, M, 3, 5), (M, M, 4, 5)], 2), - ([(S1, S1, 3, 5), (S1, S1, 4, 5)], 2), - ([(M2, 3, 5), (M2, 4, 5)], 1), - ([(S1, 3, 5), (S1, 4, 5)], 1), - ([(3, 5), (4, 5)], 0), - ([(3, 5), (4, 5), (1, 5)], 0), - ( - [(3, 5)], - 0, - ), - # Cat on Width - ([(M, M, 5, 3), (M, M, 5, 4)], 3), - ([(S1, S1, 5, 3), (S1, S1, 5, 4)], 3), - ([(M, 5, 3), (M, 5, 4)], 2), - ([(S1, 5, 3), (S1, 5, 4)], 2), - ([(5, 0), (5, 4)], 1), - ([(5, 3), (5, 4)], 1), - ([(5, 3), (5, 4), (5, 1)], 1), - ( - [(5, 4)], - 1, - ), - ([(5,), (6,)], 0), - # Cat on Batch - ([(M, S1, 5, 4), (M1, S1, 5, 4)], 0), - ([(S, S1, 5, 4), (S1, S1, 5, 4)], 0), - ([(S, M, 5, 4), (S1, M, 5, 4)], 0), - ([(S, XS, 5, 4), (S1, XS, 5, 4)], 0), - ([(S, S2, 5, 4), (S1, S2, 5, 4)], 0), - ( - [ - (3, 1, 2, 5), - (3, 1, 2, 5), - (3, 1, 2, 5), - ], - 0, - ), - # Cat on Channel - ([(M, 5, 4), (0, 5, 4), (M1, 5, 4)], 0), - ([(S, 5, 4), (0, 5, 4), (S2, 5, 4)], 0), - ([(M, 5, 4), (M1, 5, 4), (M2, 5, 4)], 0), - ([(S, 5, 4), (S1, 5, 4), (S2, 5, 4)], 0), - ([(XS, 5, 4), (XS, 5, 4), (S2, 5, 4)], 0), - ([(XS, S, 5, 4), (XS, S1, 5, 4), (XS, S2, 5, 4)], 1), - ([(XS, XS, 5, 4), (XS, XS, 5, 4), (XS, S2, 5, 4)], 1), - ( - [ - (XS, 1, 2, 5), - (XS, 1, 2, 5), - (XS, 1, 2, 5), - ], - 1, - ), - ] - ) + suite_inputs = [ + # Cat on Height + ([(M, M, 3, 5), (M, M, 0, 5)], 2), + ([(S1, S1, 3, 5), (S1, S1, 0, 5)], 2), + ([(M, M, 3, 5), (M, M, 4, 5)], 2), + ([(S1, S1, 3, 5), (S1, S1, 4, 5)], 2), + ([(M2, 3, 5), (M2, 4, 5)], 1), + ([(S1, 3, 5), (S1, 4, 5)], 1), + ([(3, 5), (4, 5)], 0), + ([(3, 5), (4, 5), (1, 5)], 0), + ( + [(3, 5)], + 0, + ), + # Cat on Width + ([(M, M, 5, 3), (M, M, 5, 4)], 3), + ([(S1, S1, 5, 3), (S1, S1, 5, 4)], 3), + ([(M, 5, 3), (M, 5, 4)], 2), + ([(S1, 5, 3), (S1, 5, 4)], 2), + ([(5, 0), (5, 4)], 1), + ([(5, 3), (5, 4)], 1), + ([(5, 3), (5, 4), (5, 1)], 1), + ( + [(5, 4)], + 1, + ), + ([(5,), (6,)], 0), + # Cat on Batch + ([(M, S1, 5, 4), (M1, S1, 5, 4)], 0), + ([(S, S1, 5, 4), (S1, S1, 5, 4)], 0), + ([(S, M, 5, 4), (S1, M, 5, 4)], 0), + ([(S, XS, 5, 4), (S1, XS, 5, 4)], 0), + ([(S, S2, 5, 4), (S1, S2, 5, 4)], 0), + ( + [ + (3, 1, 2, 5), + (3, 1, 2, 5), + (3, 1, 2, 5), + ], + 0, + ), + # Cat on Channel + ([(M, 5, 4), (0, 5, 4), (M1, 5, 4)], 0), + ([(S, 5, 4), (0, 5, 4), (S2, 5, 4)], 0), + ([(M, 5, 4), (M1, 5, 4), (M2, 5, 4)], 0), + ([(S, 5, 4), (S1, 5, 4), (S2, 5, 4)], 0), + ([(XS, 5, 4), (XS, 5, 4), (S2, 5, 4)], 0), + ([(XS, S, 5, 4), (XS, S1, 5, 4), (XS, S2, 5, 4)], 1), + ([(XS, XS, 5, 4), (XS, XS, 5, 4), (XS, S2, 5, 4)], 1), + ( + [ + (XS, 1, 2, 5), + (XS, 1, 2, 5), + (XS, 1, 2, 5), + ], + 1, + ), + ] + + high_number_cat_inputs = [] + for num_input in [6, 9]: + odd_size = (3, 7, 29, 31) + even_size = (3, 8, 29, 32) + ones = (3, 1, 1, 1) + + for input_size in [odd_size, even_size, ones]: + input_sizes = [input_size] * num_input + # Test cat on height, width, and batch dim + high_number_cat_inputs.append((input_sizes, 3)) + high_number_cat_inputs.append((input_sizes, 2)) + high_number_cat_inputs.append((input_sizes, 1)) + high_number_cat_inputs.append((input_sizes, 0)) + + test_suite = VkTestSuite(suite_inputs + high_number_cat_inputs) + test_suite.layouts = [ "utils::kWidthPacked", "utils::kChannelsPacked", diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py index 4799a22882d..6bf6a68090a 100644 --- a/backends/vulkan/test/test_vulkan_delegate.py +++ b/backends/vulkan/test/test_vulkan_delegate.py @@ -2150,3 +2150,147 @@ def forward(self, a, b): ) self.lower_module_and_test_output(custom_complex_module, sample_inputs) + + def test_vulkan_backend_cat_width_dynamic_shapes(self): + class CatWidthModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x1, x2, x3, x4, x5, x6): + return torch.cat([x1, x2, x3, x4, x5, x6], dim=3) + + cat_width_module = CatWidthModule() + + # Create 6 tensors with different widths but same batch, channel, and height dimensions + sample_inputs = ( + torch.randn(size=(2, 3, 4, 5), dtype=torch.float32), # width=5 + torch.randn(size=(2, 3, 4, 3), dtype=torch.float32), # width=3 + torch.randn(size=(2, 3, 4, 7), dtype=torch.float32), # width=7 + torch.randn(size=(2, 3, 4, 2), dtype=torch.float32), # width=2 + torch.randn(size=(2, 3, 4, 4), dtype=torch.float32), # width=4 + torch.randn(size=(2, 3, 4, 6), dtype=torch.float32), # width=6 + ) + + # Define dynamic shapes for the width dimension (dim=3) for each input + width1 = Dim("width1", min=1, max=10) + width2 = Dim("width2", min=1, max=10) + width3 = Dim("width3", min=1, max=10) + width4 = Dim("width4", min=1, max=10) + width5 = Dim("width5", min=1, max=10) + width6 = Dim("width6", min=1, max=10) + + dynamic_shapes = { + "x1": {3: width1}, + "x2": {3: width2}, + "x3": {3: width3}, + "x4": {3: width4}, + "x5": {3: width5}, + "x6": {3: width6}, + } + + # Create test inputs with different width combinations + test_inputs = [ + ( + torch.randn(2, 3, 4, 2), # width=2 + torch.randn(2, 3, 4, 1), # width=1 + torch.randn(2, 3, 4, 3), # width=3 + torch.randn(2, 3, 4, 1), # width=1 + torch.randn(2, 3, 4, 2), # width=2 + torch.randn(2, 3, 4, 4), # width=4 + ), + ( + torch.randn(2, 3, 4, 8), # width=8 + torch.randn(2, 3, 4, 2), # width=2 + torch.randn(2, 3, 4, 1), # width=1 + torch.randn(2, 3, 4, 3), # width=3 + torch.randn(2, 3, 4, 5), # width=5 + torch.randn(2, 3, 4, 1), # width=1 + ), + ( + torch.randn(2, 3, 4, 1), # width=1 + torch.randn(2, 3, 4, 9), # width=9 + torch.randn(2, 3, 4, 2), # width=2 + torch.randn(2, 3, 4, 4), # width=4 + torch.randn(2, 3, 4, 1), # width=1 + torch.randn(2, 3, 4, 3), # width=3 + ), + ] + + self.lower_module_and_test_output( + cat_width_module, + sample_inputs, + dynamic_shapes=dynamic_shapes, + test_inputs=test_inputs, + ) + + def test_vulkan_backend_cat_channels_dynamic_shapes(self): + class CatChannelsModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x1, x2, x3, x4, x5, x6): + return torch.cat([x1, x2, x3, x4, x5, x6], dim=1) + + cat_channels_module = CatChannelsModule() + + # Create 6 tensors with different channel counts but same batch, height, and width dimensions + sample_inputs = ( + torch.randn(size=(2, 8, 8, 6), dtype=torch.float32), # channels=4 + torch.randn(size=(2, 8, 8, 6), dtype=torch.float32), # channels=2 + torch.randn(size=(2, 8, 8, 6), dtype=torch.float32), # channels=6 + torch.randn(size=(2, 8, 8, 6), dtype=torch.float32), # channels=1 + torch.randn(size=(2, 8, 8, 6), dtype=torch.float32), # channels=3 + torch.randn(size=(2, 8, 8, 6), dtype=torch.float32), # channels=5 + ) + + # Define dynamic shapes for the channels dimension (dim=1) for each input + channels1 = Dim("channels1", min=1, max=8) + channels2 = Dim("channels2", min=1, max=8) + channels3 = Dim("channels3", min=1, max=8) + channels4 = Dim("channels4", min=1, max=8) + channels5 = Dim("channels5", min=1, max=8) + channels6 = Dim("channels6", min=1, max=8) + + dynamic_shapes = { + "x1": {1: channels1}, + "x2": {1: channels2}, + "x3": {1: channels3}, + "x4": {1: channels4}, + "x5": {1: channels5}, + "x6": {1: channels6}, + } + + # Create test inputs with different channel combinations + test_inputs = [ + ( + torch.randn(2, 1, 8, 6), # channels=1 + torch.randn(2, 2, 8, 6), # channels=2 + torch.randn(2, 1, 8, 6), # channels=1 + torch.randn(2, 3, 8, 6), # channels=3 + torch.randn(2, 1, 8, 6), # channels=1 + torch.randn(2, 2, 8, 6), # channels=2 + ), + ( + torch.randn(2, 6, 8, 6), # channels=6 + torch.randn(2, 1, 8, 6), # channels=1 + torch.randn(2, 3, 8, 6), # channels=3 + torch.randn(2, 2, 8, 6), # channels=2 + torch.randn(2, 4, 8, 6), # channels=4 + torch.randn(2, 1, 8, 6), # channels=1 + ), + ( + torch.randn(2, 2, 8, 6), # channels=2 + torch.randn(2, 7, 8, 6), # channels=7 + torch.randn(2, 1, 8, 6), # channels=1 + torch.randn(2, 1, 8, 6), # channels=1 + torch.randn(2, 3, 8, 6), # channels=3 + torch.randn(2, 2, 8, 6), # channels=2 + ), + ] + + self.lower_module_and_test_output( + cat_channels_module, + sample_inputs, + dynamic_shapes=dynamic_shapes, + test_inputs=test_inputs, + ) From 72c636729b3420a9501fedfab9e89eb3227e083b Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Fri, 8 Aug 2025 12:47:50 -0700 Subject: [PATCH 2/2] [ET-VK][ez] Fix registration for convolution operator Pull Request resolved: https://github.com/pytorch/executorch/pull/13227 ## Context Update the registration of the convolution operator to indicate that the weight tensor is prepacked and should not undergo normal texture limits checking. The current registration may cause valid convolution operators to not be partitioned since the export logic will think the weight tensor is non representable using channels packed textures. An example weight size would be something like [256, 256, 1, 1] which would result in a texture with extents [1, 1, 16384] which may exceed texture limits on some machines. Differential Revision: [D79893086](https://our.internmc.facebook.com/intern/diff/D79893086/) ghstack-source-id: 301766081 --- backends/vulkan/op_registry.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index e3498cf1792..675143cd7fd 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -435,7 +435,19 @@ def register_2d_pool_op(): ) def register_convolution_op(): return OpFeatures( - inputs_storage=utils.CHANNELS_PACKED_TEXTURE, + inputs_storage=[ + utils.CHANNELS_PACKED_TEXTURE, # input + utils.NO_STORAGE, # weight (prepacked) + utils.NO_STORAGE, # bias (prepacked) + utils.NO_STORAGE, # stride (non tensor) + utils.NO_STORAGE, # padding (non tensor) + utils.NO_STORAGE, # dilation (non tensor) + utils.NO_STORAGE, # transposed (non tensor) + utils.NO_STORAGE, # output_padding (non tensor) + utils.NO_STORAGE, # groups (non tensor) + utils.NO_STORAGE, # output_min (non tensor) + utils.NO_STORAGE, # output_max (non tensor) + ], supports_resize=True, supports_prepacking=True, )