From 7cb027659b0beb553c8ecd5a78323df1fbad685d Mon Sep 17 00:00:00 2001
From: Stephen Jia
Date: Fri, 8 Aug 2025 08:04:45 -0700
Subject: [PATCH] [ET-VK] Allow `aten.cat.default` to handle any number of input tensors

## Context

Previously, I updated the implementation of `aten.cat.default` in D76305343 (https://github.com/pytorch/executorch/pull/11508) since the original implementation had a bug. The new implementation only supported up to 3 input tensors, but several models require up to 6 input tensors. This diff updates the `concat` op so that an arbitrary number of input tensors can be accepted.

## Changes

* Update the implementation of the concat shader so that it can be dispatched repeatedly, with each dispatch appending a batch of up to 3 input tensors to the output at a running offset. This allows any number of input tensors to be supported. A sketch of the batching scheme is shown below.
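At a high level, the host-side lowering now works as follows (a minimal standalone sketch, not the literal code in `Concat.cpp`; the sizes and the print statement are illustrative stand-ins for the `set_zero`, `concat`, and `update_concat_offset` dispatches):

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  // Extents of 6 hypothetical input tensors along concat_dim.
  const std::vector<int> concat_dim_sizes = {5, 3, 7, 2, 4, 6};
  const std::size_t batch_size = 3; // each concat dispatch handles up to 3 inputs

  // The set_zero dispatch initializes concat_offset to 0 on the GPU.
  int concat_offset = 0;

  for (std::size_t start = 0; start < concat_dim_sizes.size();
       start += batch_size) {
    const std::size_t end =
        std::min(start + batch_size, concat_dim_sizes.size());

    // One concat shader dispatch writes inputs [start, end) to the output,
    // beginning at concat_offset along concat_dim.
    std::printf(
        "concat dispatch: inputs [%zu, %zu) at offset %d\n",
        start, end, concat_offset);

    // The update_concat_offset dispatch then advances the offset by the
    // batch's total extent along concat_dim (skipped after the last batch).
    for (std::size_t i = start; i < end; ++i) {
      concat_offset += concat_dim_sizes[i];
    }
  }
  return 0;
}
```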
Differential Revision: [D79893084](https://our.internmc.facebook.com/intern/diff/D79893084/)

[ghstack-poisoned]
---
 backends/vulkan/op_registry.py | 8 -
 .../vulkan/runtime/api/containers/Tensor.cpp | 9 +-
 .../vulkan/runtime/graph/ops/ExecuteNode.h | 2 +-
 .../runtime/graph/ops/glsl/concat_buffer.glsl | 61 +++-
 .../graph/ops/glsl/concat_texture.glsl | 193 ++++++----
 .../runtime/graph/ops/glsl/concat_utils.glslh | 33 ++
 .../runtime/graph/ops/glsl/indexing_utils.h | 6 +
 .../runtime/graph/ops/glsl/set_zero.glsl | 33 ++
 .../runtime/graph/ops/glsl/set_zero.yaml | 8 +
 .../graph/ops/glsl/update_concat_offset.glsl | 42 +++
 .../graph/ops/glsl/update_concat_offset.yaml | 13 +
 .../vulkan/runtime/graph/ops/impl/Concat.cpp | 338 ++++++++++++++----
 backends/vulkan/test/op_tests/cases.py | 135 +++----
 backends/vulkan/test/test_vulkan_delegate.py | 144 ++++++++
 14 files changed, 790 insertions(+), 235 deletions(-)
 create mode 100644 backends/vulkan/runtime/graph/ops/glsl/concat_utils.glslh
 create mode 100644 backends/vulkan/runtime/graph/ops/glsl/set_zero.glsl
 create mode 100644 backends/vulkan/runtime/graph/ops/glsl/set_zero.yaml
 create mode 100644 backends/vulkan/runtime/graph/ops/glsl/update_concat_offset.glsl
 create mode 100644 backends/vulkan/runtime/graph/ops/glsl/update_concat_offset.yaml

diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 22a93ec0e2b..e3498cf1792 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -491,17 +491,9 @@ def register_view_ops(): # for both texture and buffer storage types.
@update_features(exir_ops.edge.aten.cat.default)
def register_cat_op():
-    def check_cat_node(node: torch.fx.Node) -> bool:
-        inputs = node.args[0]
-        if isinstance(inputs, (list, tuple)) and len(inputs) <= 3:
-            return True
-
-        return False
-
     return OpFeatures(
         inputs_storage=utils.ANY_STORAGE,
         supports_resize=True,
-        are_node_inputs_supported_fn=check_cat_node,
     )
diff --git a/backends/vulkan/runtime/api/containers/Tensor.cpp b/backends/vulkan/runtime/api/containers/Tensor.cpp index 64f330de59c..a3d9bd4aa34 100644 --- a/backends/vulkan/runtime/api/containers/Tensor.cpp +++ b/backends/vulkan/runtime/api/containers/Tensor.cpp @@ -517,6 +517,7 @@ void vTensorStorage::transition( vkapi::MemoryAccessFlags prev_access = last_access_.access; const bool prev_written = (prev_access & vkapi::MemoryAccessType::WRITE) != 0;
+  const bool cur_written = (cur_access & vkapi::MemoryAccessType::WRITE) != 0;
 VkImageLayout cur_layout = VK_IMAGE_LAYOUT_UNDEFINED; VkImageLayout new_layout = VK_IMAGE_LAYOUT_UNDEFINED; @@ -528,7 +529,13 @@ void vTensorStorage::transition( layout_changed = cur_layout != new_layout; }
-  if (prev_written || layout_changed) {
+  // RAW: need to make sure the current read sees previous writes
+  // WAW: need to make sure the current write occurs after the previous write
+  // so the final value is correct.
+  // WAR: need to make sure the previous read does not read the value from the
+  // current write.
+  // RAR: no need for synchronization
+  if (prev_written || cur_written || layout_changed) {
 VkPipelineStageFlags src_stage = vkapi::vk_stage(prev_stage); if (0u == src_stage) { src_stage = VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT; diff --git a/backends/vulkan/runtime/graph/ops/ExecuteNode.h b/backends/vulkan/runtime/graph/ops/ExecuteNode.h index 6a815b246ef..4ea1ba57796 100644 --- a/backends/vulkan/runtime/graph/ops/ExecuteNode.h +++ b/backends/vulkan/runtime/graph/ops/ExecuteNode.h @@ -43,7 +43,7 @@ class ExecuteNode { friend class ComputeGraph; public:
-  using ResizeFunction = const std::function<void(
+  using ResizeFunction = std::function<void(
       ComputeGraph*,
       const std::vector<ArgGroup>&,
       const std::vector<ValueRef>&)>;
diff --git a/backends/vulkan/runtime/graph/ops/glsl/concat_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/concat_buffer.glsl index 895cecb413a..e34ecaf8309 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/concat_buffer.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/concat_buffer.glsl @@ -20,10 +20,12 @@ layout(std430) buffer; #include "indexing_utils.h"
-${layout_declare_tensor(B, "w", "t_out", DTYPE, "buffer")}
+${layout_declare_tensor(B, "rw", "t_out", DTYPE, "buffer")}
 $for i in range(NUM_INPUTS):
-  ${layout_declare_tensor(B, "r", "t_in" + str(i + 1), DTYPE, "buffer")}
+  ${layout_declare_tensor(B, "r", "t_inp" + str(i), DTYPE, "buffer")}
+
+${layout_declare_tensor(B, "r", "t_concat_offset", "int", "buffer")}
 ${layout_declare_ubo(B, "int", "concat_dim")} @@ -31,8 +33,8 @@ ${layout_declare_ubo(B, "ivec4", "out_sizes")} ${layout_declare_ubo(B, "ivec4", "out_strides")} $for i in range(NUM_INPUTS):
-  ${layout_declare_ubo(B, "ivec4", "in" + str(i+1) + "_sizes")}
-  ${layout_declare_ubo(B, "ivec4", "in" + str(i+1) + "_strides")}
+  ${layout_declare_ubo(B, "ivec4", "inp" + str(i) + "_sizes")}
+  ${layout_declare_ubo(B, "ivec4", "inp" + str(i) + "_strides")}
 ${layout_declare_ubo(B, "int", "out_numel")} @@ -42,28 +44,53 @@ const lowp ivec4 out_dim_order = unhash_dim_order(out_layout); layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+#define NUM_INPUTS ${NUM_INPUTS}
+
+#include "concat_utils.glslh"
+
+/*
+ * This shader template concatenates up to NUM_INPUTS input
tensors to the
+ * output tensor along the concat_dim. Elements from the input tensors will
+ * be inserted along the output's concat_dim starting at concat_offset.
+ */
 void main() {
-  const int out_bufi = ivec3(gl_GlobalInvocationID).x;
-  if (out_bufi >= out_numel) {
+  const int tid = ivec3(gl_GlobalInvocationID).x;
+
+  // The 1-3 input tensors are interpreted as one concatenated tensor ("volume")
+  // along the concat_dim for the purposes of tensor indexing. Each thread is
+  // responsible for reading one item from this volume and writing it to the
+  // appropriate output location.
+  ivec4 inp_volume_sizes = out_sizes;
+  inp_volume_sizes[concat_dim] = total_concat_dim_numel();
+
+  // Account for zero-size input tensors
+  if (any(lessThanEqual(inp_volume_sizes, ivec4(0)))) {
+    return;
+  }
+
+  ivec4 inp_volume_tidx = nchwi_to_tidx(tid, inp_volume_sizes);
+
+  // bounds check
+  if (any(greaterThanEqual(inp_volume_tidx, inp_volume_sizes))) {
     return;
   }
-  // Convert buffer linear index to 4-D tensor index for output
-  const ivec4 out_tidx = bufi_to_tidx(out_bufi, out_strides, out_dim_order);
+  int concat_offset = t_concat_offset[0];
+
+  ivec4 out_tidx = inp_volume_tidx;
+  out_tidx[concat_dim] += concat_offset;
-  // Determine which input tensor to read from
-  ivec4 in_tidx = out_tidx;
+  const uint out_bufi = tidx_to_bufi(out_tidx, out_strides);
+  // Go through the list of input tensors and find the input from which this
+  // output element should be read.
   $for i in range(NUM_INPUTS):
-    // Check if the index at the concat dim is within bounds of the input tensor
-    // If so, read from that input tensor and write to output
-    if (in_tidx[concat_dim] < in${i+1}_sizes[concat_dim]) {
-      int in_bufi = tidx_to_bufi(in_tidx, in${i+1}_strides);
-      t_out[out_bufi] = t_in${i+1}[in_bufi];
+    if (inp_volume_tidx[concat_dim] < inp${i}_sizes[concat_dim]) {
+      int inp_bufi = tidx_to_bufi(inp_volume_tidx, inp${i}_strides);
+      t_out[out_bufi] = t_inp${i}[inp_bufi];
       return;
     }
-    // otherwise, decrement the index at the concat dim
     else {
-      in_tidx[concat_dim] -= in${i+1}_sizes[concat_dim];
+      inp_volume_tidx[concat_dim] -= inp${i}_sizes[concat_dim];
     }
 }
diff --git a/backends/vulkan/runtime/graph/ops/glsl/concat_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/concat_texture.glsl index dac6266bf67..afab0c524d6 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/concat_texture.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/concat_texture.glsl @@ -19,16 +19,18 @@ layout(std430) buffer; #include "indexing_utils.h"
-${layout_declare_tensor(B, "w", "t_out", DTYPE, "texture3d")}
+${layout_declare_tensor(B, "rw", "t_out", DTYPE, "texture3d")}
 $for i in range(NUM_INPUTS):
-  ${layout_declare_tensor(B, "r", "t_in" + str(i + 1), DTYPE, "texture3d")}
+  ${layout_declare_tensor(B, "r", "t_inp" + str(i), DTYPE, "texture3d")}
+
+${layout_declare_tensor(B, "r", "t_concat_offset", "int", "buffer")}
 ${layout_declare_ubo(B, "int", "concat_dim")} $in_metadata = "" $for i in range(NUM_INPUTS):
-  $in_metadata += "ivec4 in" + str(i + 1) + "_sizes;\n"
+  $in_metadata += "ivec4 inp" + str(i) + "_sizes;\n"
 layout(push_constant) uniform restrict Block { ivec4 out_sizes; @@ -40,90 +42,135 @@ const lowp ivec4 out_axis_map = unhash_axis_map(out_layout); const lowp int out_packed_dim = unhash_packed_dim(out_layout); $for i in range(NUM_INPUTS):
-  ${layout_declare_spec_const(C, "int", "in" + str(i+1) + "_layout", "DEFAULT_LAYOUT")}
-  const lowp ivec4 in${i+1}_axis_map = unhash_axis_map(in${i+1}_layout);
-  const lowp int in${i+1}_packed_dim =
unhash_packed_dim(in${i+1}_layout);
+  ${layout_declare_spec_const(C, "int", "inp" + str(i) + "_layout", "DEFAULT_LAYOUT")}
+  const lowp ivec4 inp${i}_axis_map = unhash_axis_map(inp${i}_layout);
+  const lowp int inp${i}_packed_dim = unhash_packed_dim(inp${i}_layout);
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
-// Check if we can use the fast path (no texel merging required)
-bool can_use_fast_path() {
-  // Fast path is possible when:
-  // 1. The concat dimension is not the packed dimension, or
-  // 2. The concat dimension is the packed dimension but both input tensors have dimensions
-  //    that are multiples of 4 along the packed dimension
-  if (concat_dim != out_packed_dim) {
-    return true;
-  }
-
-  // Check if all input tensors have dimensions that are multiples of 4 along the packed dimension
-  bool all_concat_dim_size_multiple_of_4 = true;
-  $for i in range(NUM_INPUTS):
-    all_concat_dim_size_multiple_of_4 =
-      all_concat_dim_size_multiple_of_4 &&
-      (in${i+1}_sizes[concat_dim] % 4 == 0);
+#define NUM_INPUTS ${NUM_INPUTS}
-  return all_concat_dim_size_multiple_of_4;
-}
+#include "concat_utils.glslh"
+/*
+ * This shader template concatenates up to NUM_INPUTS input tensors to the
+ * output tensor along the concat_dim. Elements from the input tensors will
+ * be inserted along the output's concat_dim starting at concat_offset.
+ *
+ * Each thread is responsible for writing out one output texel. The data
+ * required for the output texel may be read from multiple input texels of one
+ * input tensor.
+ */
 void main() {
-  const ivec3 lpos = ivec3(gl_GlobalInvocationID);
-  ivec4 out_tidx = lpos_to_tidx(lpos, out_sizes, out_axis_map.w, out_packed_dim);
-
-  if (any(greaterThanEqual(out_tidx, out_sizes))) {
+  const int tid = ivec3(gl_GlobalInvocationID).x;
+
+  // Sum of the sizes of all input tensors in this batch along the concat_dim
+  const int concat_numel = total_concat_dim_numel();
+
+  // The 1-3 input tensors are interpreted as one concatenated tensor ("volume")
+  // along the concat_dim for the purposes of tensor indexing. Each thread is
+  // responsible for writing out 4 elements along the packed dim of the output
+  // tensor by reading the source data from the input tensor(s).
+  ivec4 inp_volume_sizes = out_sizes;
+  inp_volume_sizes[concat_dim] = concat_numel;
+
+  // Reconstruct inp_volume_texel_sizes from Concat.cpp
+  ivec4 inp_volume_texel_sizes = inp_volume_sizes;
+  inp_volume_texel_sizes[out_packed_dim] = DIV_UP_4(
+      inp_volume_texel_sizes[out_packed_dim]
+  ) + 1;
+
+  // tensor index of the first element that will be read from the input volume
+  ivec4 inp_volume_start_tidx = nchwi_to_tidx(tid, inp_volume_texel_sizes);
+  inp_volume_start_tidx[out_packed_dim] = MUL_4(
+      inp_volume_start_tidx[out_packed_dim]
+  );
+
+  int concat_offset = t_concat_offset[0];
+
+  // tensor index of the first element that will be written to the output tensor
+  ivec4 out_write_start_tidx = inp_volume_start_tidx;
+  out_write_start_tidx[concat_dim] += concat_offset;
+
+  // To write to the desired output element, we will need to load the texel
+  // to which the element belongs. Calculate the tensor index of the first
+  // element of that texel.
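+  // Worked example (hypothetical values, assuming concat_dim == out_packed_dim
+  // and concat_offset == 6): a write start index of 6 aligns down to 4, so the
+  // output texel covering elements 4..7 is loaded below; components 0..1
+  // (elements 4..5, written by a previous batch) are preserved, and components
+  // 2..3 are filled from the input volume.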
+  ivec4 out_read_start_tidx = out_write_start_tidx;
+  out_read_start_tidx[out_packed_dim] = ALIGN_DOWN_4(
+      out_write_start_tidx[out_packed_dim]);
+
+  // bounds check
+  if (any(greaterThanEqual(out_read_start_tidx, out_sizes))) {
     return;
   }
-  if (can_use_fast_path()) {
-    // Fast path: No texel merging required
-    ivec4 in_tidx = out_tidx;
+  ivec3 out_pos = tidx_to_pos(
+      out_read_start_tidx,
+      out_sizes,
+      out_axis_map,
+      out_packed_dim
+  );
-    $for i in range(NUM_INPUTS):
-      // For each input tensor, check if the tensor index is within bounds. If
-      // so, read the texel from the input tensor and write it to the output
-      if (in_tidx[concat_dim] < in${i+1}_sizes[concat_dim]) {
-        const ivec3 in_pos = tidx_to_pos(in_tidx, in${i+1}_sizes, in${i+1}_axis_map, in${i+1}_packed_dim);
-        const VEC4_T in_texel = load_texel(t_in${i+1}, in_pos);
-        write_texel_lpos(t_out, lpos, in_texel, out_axis_map);
-        return;
-      }
-      // Otherwise, adjust the index along the concat dimension and try the next
-      // input tensor.
-      else {
-        in_tidx[concat_dim] -= in${i+1}_sizes[concat_dim];
-      }
-  }
-  else {
-    // Slow path: Texel merging required
-    VEC4_T out_texel = VEC4_T(0);
+  VEC4_T out_texel = imageLoad(t_out, out_pos);
-    // Process each element in the output texel individually
-    for (int texel_i = 0; texel_i < 4; ++texel_i) {
-      ivec4 curr_out_tidx = out_tidx;
-      curr_out_tidx[out_packed_dim] += texel_i;
-      // Skip if we're out of bounds
-      if (curr_out_tidx[out_packed_dim] >= out_sizes[out_packed_dim]) {
-        continue;
-      }
+  for (int comp = 0; comp < 4; ++comp) {
+    ivec4 out_tidx = out_read_start_tidx;
+    out_tidx[out_packed_dim] += comp;
-      ivec4 in_tidx = curr_out_tidx;
-      $for i in range(NUM_INPUTS):
-        // For each input tensor, check if the tensor index is within bounds. If
-        // so, read the corresponding texel element from the input tensor and
-        // write it to the output texel.
-        if (in_tidx[concat_dim] < in${i+1}_sizes[concat_dim]) {
-          const ivec4 in_posi = tidx_to_posi(in_tidx, in${i+1}_sizes, in${i+1}_axis_map, in${i+1}_packed_dim);
-          out_texel[texel_i] = load_texel(t_in${i+1}, in_posi.xyz)[in_posi.w];
-          continue;
-        }
-        // Otherwise, adjust the index along the concat dimension and try the
-        // next input tensor.
-        else {
-          in_tidx[concat_dim] -= in${i+1}_sizes[concat_dim];
-        }
+
+    // It's possible that the current texel element has been written to as part
+    // of the previous input batch; if so, then don't overwrite this texel
+    // element
+    if (out_tidx[concat_dim] < concat_offset) {
+      continue;
     }
-      write_texel_lpos(t_out, lpos, out_texel, out_axis_map);
+    // Calculate the tidx of the input volume that corresponds to this output
+    // element
+    ivec4 inp_volume_tidx = out_tidx;
+    inp_volume_tidx[concat_dim] -= concat_offset;
+
+    // Go through the list of input tensors, and figure out which input this
+    // output element should be read from.
    $for i in range(NUM_INPUTS):
+      if (inp_volume_tidx[concat_dim] < inp${i}_sizes[concat_dim]) {
+        // Special fast path case if, for the first output texel element, the
+        // corresponding input element is at the start of the texel it belongs
+        // to. In this case, the input texel can be written as-is to the output
+        // texel. Also require that the entire input texel is valid and does not
+        // contain any padding elements.
+        if (comp == 0 &&
+            out_tidx[out_packed_dim] % 4 == 0 &&
+            inp_volume_tidx[inp${i}_packed_dim] % 4 == 0 &&
+            inp_volume_tidx[inp${i}_packed_dim] + 3 < inp${i}_sizes[inp${i}_packed_dim]) {
+          const ivec3 in_pos = tidx_to_pos(
+              inp_volume_tidx,
+              inp${i}_sizes,
+              inp${i}_axis_map,
+              inp${i}_packed_dim);
+
+          out_texel = texelFetch(t_inp${i}, in_pos, 0);
+          break;
+        }
+
+        // Otherwise, locate the specific input element required
+        const ivec4 in_posi = tidx_to_posi(
+            inp_volume_tidx,
+            inp${i}_sizes,
+            inp${i}_axis_map,
+            inp${i}_packed_dim);
+
+        out_texel[comp] = texelFetch(t_inp${i}, in_posi.xyz, 0)[in_posi.w];
+        continue;
+      }
+      else {
+        inp_volume_tidx[concat_dim] -= inp${i}_sizes[concat_dim];
+      }
   }
+
+  imageStore(t_out, out_pos, out_texel);
 }
diff --git a/backends/vulkan/runtime/graph/ops/glsl/concat_utils.glslh b/backends/vulkan/runtime/graph/ops/glsl/concat_utils.glslh new file mode 100644 index 00000000000..000b86a7fce --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/concat_utils.glslh @@ -0,0 +1,33 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#ifndef CONCAT_UTILS_H
+#define CONCAT_UTILS_H
+
+
+/**********************************
+ * Concatenation utility functions
+ *
+ */
+
+/*
+ * Returns the total number of elements along the concatenation dim that will
+ * be concatenated in this input batch.
+ */
+$for N in range(1, 4):
+  #if NUM_INPUTS == ${N}
+  int total_concat_dim_numel() {
+    int total = 0;
+    $for i in range(N):
+      total += inp${i}_sizes[concat_dim];
+
+    return total;
+  }
+  #endif
+
+#endif // CONCAT_UTILS_H
diff --git a/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h b/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h index 72650bb7040..fdb6f514a3e 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h +++ b/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h @@ -68,6 +68,8 @@ */ #define mod4(x) ((x) & 3)
+#define ALIGN_DOWN_4(x) ((x) & ~3)
+
 #define ALIGN_UP_4(x) (((x) + 3) & ~3) #define DIV_UP_8(x) (((x) + 7) >> 3) @@ -110,6 +112,10 @@ ivec4 tidx_to_4bufi( return base_i + ivec4(0, 1, 2, 3) * strides[packed_dim]; }
+/*
+ * Given a buffer index into a contiguous tensor and the tensor's sizes, return
+ * the tensor index that corresponds to the buffer index.
+ */
 ivec4 nchwi_to_tidx(const int nchwi, const ivec4 sizes) { const int nchwi_div_x = nchwi / sizes.x; const int nchwi_div_y = nchwi_div_x / sizes.y; diff --git a/backends/vulkan/runtime/graph/ops/glsl/set_zero.glsl b/backends/vulkan/runtime/graph/ops/glsl/set_zero.glsl new file mode 100644 index 00000000000..d01780b9e30 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/set_zero.glsl @@ -0,0 +1,33 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#version 450 core
+
+#define PRECISION ${PRECISION}
+
+#define T ${buffer_scalar_type(DTYPE)}
+
+${define_active_storage_type("buffer")}
+${define_required_extensions(DTYPE)}
+
+layout(std430) buffer;
+
+${layout_declare_tensor(B, "w", "t_out", DTYPE, "buffer")}
+
+${layout_declare_ubo(B, "int", "out_numel")}
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+
+void main() {
+  const int out_bufi = ivec3(gl_GlobalInvocationID).x;
+  if (out_bufi >= out_numel) {
+    return;
+  }
+
+  t_out[out_bufi] = T(0);
+}
diff --git a/backends/vulkan/runtime/graph/ops/glsl/set_zero.yaml b/backends/vulkan/runtime/graph/ops/glsl/set_zero.yaml new file mode 100644 index 00000000000..cee87c468b1 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/set_zero.yaml @@ -0,0 +1,8 @@
+set_zero:
+  parameter_names_with_default_values:
+    DTYPE: float
+  generate_variant_forall:
+    DTYPE:
+      - VALUE: int32
+  shader_variants:
+    - NAME: set_zero
diff --git a/backends/vulkan/runtime/graph/ops/glsl/update_concat_offset.glsl b/backends/vulkan/runtime/graph/ops/glsl/update_concat_offset.glsl new file mode 100644 index 00000000000..ba02da1c301 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/update_concat_offset.glsl @@ -0,0 +1,42 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#version 450 core
+
+#define PRECISION ${PRECISION}
+
+#define T ${buffer_scalar_type(DTYPE)}
+
+${define_active_storage_type("buffer")}
+${define_required_extensions(DTYPE)}
+
+layout(std430) buffer;
+
+${layout_declare_tensor(B, "w", "concat_offset", DTYPE, "buffer")}
+
+${layout_declare_ubo(B, "int", "concat_dim")}
+
+$for i in range(NUM_INPUTS):
+  ${layout_declare_ubo(B, "ivec4", "in" + str(i+1) + "_sizes")}
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+
+void main() {
+  // Only one thread needs to update the offset
+  if (gl_GlobalInvocationID.x != 0) {
+    return;
+  }
+
+  // Sum up the sizes along the concat dimension for all input tensors
+  int total_size = 0;
+  $for i in range(NUM_INPUTS):
+    total_size += in${i+1}_sizes[concat_dim];
+
+  // Add to the current offset
+  concat_offset[0] += T(total_size);
+}
diff --git a/backends/vulkan/runtime/graph/ops/glsl/update_concat_offset.yaml b/backends/vulkan/runtime/graph/ops/glsl/update_concat_offset.yaml new file mode 100644 index 00000000000..35e8740e0a3 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/update_concat_offset.yaml @@ -0,0 +1,13 @@
+update_concat_offset:
+  parameter_names_with_default_values:
+    DTYPE: float
+    NUM_INPUTS: 2
+  generate_variant_forall:
+    DTYPE:
+      - VALUE: int32
+  shader_variants:
+    - NAME: update_concat_offset_1
+      NUM_INPUTS: 1
+    - NAME: update_concat_offset_2
+    - NAME: update_concat_offset_3
+      NUM_INPUTS: 3
diff --git a/backends/vulkan/runtime/graph/ops/impl/Concat.cpp b/backends/vulkan/runtime/graph/ops/impl/Concat.cpp index 315dabdb1d5..0a4acb6cef3 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Concat.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Concat.cpp @@ -19,15 +19,16 @@ namespace vkcompute { std::vector<int64_t> get_concat_sizes( ComputeGraph& graph,
-    const std::vector<ValueRef>& in_value_refs,
-    const int64_t dim) {
+    ValueRef all_input_refs,
+    const int64_t concat_dim) {
+  ValueListPtr in_value_refs = graph.get_value_list(all_input_refs);
   // Get the sizes of the first input tensor as a starting point
-
  std::vector<int64_t> new_out_sizes = graph.sizes_of(in_value_refs.at(0));
+  std::vector<int64_t> new_out_sizes = graph.sizes_of(in_value_refs->at(0));
   // Sum up the sizes along the concatenation dimension
-  for (size_t i = 1; i < in_value_refs.size(); ++i) {
-    const std::vector<int64_t> in_sizes = graph.sizes_of(in_value_refs.at(i));
-    new_out_sizes.at(dim) += in_sizes.at(dim);
+  for (size_t i = 1; i < in_value_refs->size(); ++i) {
+    const std::vector<int64_t> in_sizes = graph.sizes_of(in_value_refs->at(i));
+    new_out_sizes.at(concat_dim) += in_sizes.at(concat_dim);
   }
   return new_out_sizes;
@@ -37,24 +38,122 @@ void resize_concat_node( ComputeGraph* graph, const std::vector<ArgGroup>& args, const std::vector<ValueRef>& extra_args) {
-  // Extract relevant ValueRefs
-  const ValueRef out_ref = args.at(0).refs.at(0);
-  const std::vector<ValueRef>& in_value_refs = args.at(1).refs;
+  const ValueRef out = args.at(0).refs.at(0);
+  const ValueRef all_inputs = extra_args.at(0);
-  int64_t dim = graph->extract_scalar<int64_t>(extra_args.at(0));
+  int64_t concat_dim = graph->extract_scalar<int64_t>(extra_args.at(1));
-  // Normalize dim if negative
-  const int64_t ndim = graph->dim_of(out_ref);
-  if (dim < 0) {
-    dim += ndim;
+  // Normalize concat_dim if negative
+  const int64_t ndim = graph->dim_of(out);
+  if (concat_dim < 0) {
+    concat_dim += ndim;
   }
   // Calculate the new sizes
   std::vector<int64_t> new_out_sizes =
-      get_concat_sizes(*graph, in_value_refs, dim);
+      get_concat_sizes(*graph, all_inputs, concat_dim);
   // Resize the output tensor
-  graph->virtual_resize(out_ref, new_out_sizes);
+  graph->virtual_resize(out, new_out_sizes);
+}
+
+utils::uvec3 concat_pick_global_wg_size(
+    ComputeGraph* graph,
+    const vkapi::ShaderInfo& shader,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& extra_args) {
+  (void)shader;
+
+  const ValueRef out = args.at(0).refs.at(0);
+  const std::vector<ValueRef> inputs_in_batch = args.at(1).refs;
+
+  int64_t concat_dim = graph->extract_scalar<int64_t>(extra_args.at(1));
+
+  // Normalize concat_dim if negative
+  const int64_t ndim = graph->dim_of(out);
+  if (concat_dim < 0) {
+    concat_dim += ndim;
+  }
+
+  // The concat shader concatenates N input tensors at a time to the output
+  // tensor. Since the shader may need to be invoked multiple times to finish
+  // concatenation when the number of input tensors is >N, the global workgroup
+  // is based on the volume of input data being concatenated in this batch,
+  // as opposed to the overall size of the output tensor. Conceptually, the
+  // global work group size represents which elements of the output tensor will
+  // be written to during this dispatch.
+
+  uint32_t total_input_numel = 0;
+  int64_t concat_dim_numel = 0;
+  for (const ValueRef input : inputs_in_batch) {
+    total_input_numel += graph->numel_of(input);
+    concat_dim_numel += graph->size_at<int64_t>(concat_dim, input);
+  }
+
+  if (graph->is_buffer_storage(out)) {
+    return {total_input_numel, 1, 1};
+  }
+
+  // The texture implementation is similar, except each invocation writes out 4
+  // output elements along the packed dim (i.e. one texel). In this case, the
+  // global work group size represents the number of output texels that will be
+  // written to in this batch, rather than the number of output elements. Note
+  // that to update an element of the output, the entire texel that contains it
+  // will need to be loaded, updated, then written back.
+
+  std::vector<int64_t> inp_volume_sizes = graph->sizes_of(out);
+  inp_volume_sizes.at(concat_dim) = concat_dim_numel;
+
+  // Calculate what the image extents of a tensor with the input
+  // volume's sizes would be.
This produces the number of texels that would need to be
+  // written to.
+  const int32_t packed_dim = graph->packed_dim_of(out);
+  std::vector<int64_t> inp_volume_texel_sizes =
+      api::calculate_padded_sizes(inp_volume_sizes, packed_dim);
+  // If the concat_dim is the same as the packed dim, and the concat_offset for
+  // this input batch is not a multiple of 4, then the data from an input texel
+  // may be split up between two output texels. For example:
+  // I0 , I1 , I2 , I3
+  // O0 , O1 , O2 , X  |  X , X , X , X
+  // Therefore, 1 texel is added to the packed dim to account for this.
+  inp_volume_texel_sizes.at(3 - packed_dim) =
+      utils::div_up_4(inp_volume_texel_sizes.at(3 - packed_dim)) + 1;
+
+  const uint32_t inp_volume_texel_numel =
+      utils::multiply_integers(inp_volume_texel_sizes);
+
+  return {inp_volume_texel_numel, 1, 1};
+}
 void add_concat_node( @@ -67,10 +166,6 @@ { const ValueListPtr tensors = graph.get_value_list(tensors_ref);
-    VK_CHECK_COND(
-        tensors->size() <= 3,
-        "Currently only concatenation of <= 3 tensors is supported");
-
     for (const ValueRef in : *tensors) {
       in_value_refs.push_back(in);
     }
@@ -87,68 +182,161 @@ const int64_t dim_whcn = nchw_dim_to_whcn_dim(normalized_dim, ndim); const ValueRef dim_whcn_ref = graph.get_or_add_value_for_int(dim_whcn);
-  vkapi::ParamsBindList param_buffers = {
-      graph.get_or_create_int_param_buffer(dim_whcn_ref, 0)};
+  // Create a temporary tensor to hold the concat offset
+  TmpTensor concat_offset(
+      &graph, {1}, vkapi::kInt, utils::kBuffer, utils::kWidthPacked);
-  std::vector<PushConstantDataInfo> push_constants;
-  vkapi::SpecVarList spec_vars;
-
-  if (graph.is_buffer_storage(out)) {
-    param_buffers.append(graph.sizes_ubo(out));
-    param_buffers.append(graph.strides_ubo(out));
+  // Add node to set concat_offset to 0
+  {
+    std::string kernel_name = "set_zero";
+    add_dtype_suffix(kernel_name, graph.dtype_of(concat_offset));
+
+    vkapi::ParamsBindList param_buffers = {graph.numel_ubo(concat_offset)};
+
+    graph.execute_nodes().emplace_back(new DispatchNode(
+        graph,
        VK_KERNEL_FROM_STR(kernel_name),
+        {1, 1, 1},
+        {1, 1, 1},
+        // Inputs and Outputs
+        {{concat_offset, vkapi::kWrite}},
+        // Parameter buffers
+        param_buffers,
+        // Push Constants
+        {},
+        // Specialization Constants
+        {},
+        // Resize Args
+        {},
+        // Resizing Logic
+        nullptr));
+  }
-    for (const ValueRef in_ref : in_value_refs) {
-      param_buffers.append(graph.sizes_ubo(in_ref));
-      param_buffers.append(graph.strides_ubo(in_ref));
+  // Process inputs in batches of up to 3 tensors
+  const size_t batch_size = 3;
+  for (size_t batch_start = 0; batch_start < in_value_refs.size();
+       batch_start += batch_size) {
+    const size_t batch_end =
+        std::min(batch_start + batch_size, in_value_refs.size());
+    const size_t current_batch_size = batch_end - batch_start;
+
+    std::vector<ValueRef> batch_inputs;
+    for (size_t i = batch_start; i < batch_end; ++i) {
+      batch_inputs.push_back(in_value_refs.at(i));
     }
-    param_buffers.append(graph.numel_ubo(out));
-
-    spec_vars = {graph.hashed_layout_of(out)};
-  } else {
-    push_constants = {graph.sizes_pc_of(out)};
-
-    spec_vars = {graph.hashed_layout_of(out)};
-
-    for (const ValueRef in_ref : in_value_refs) {
-      push_constants.push_back(graph.sizes_pc_of(in_ref));
-      spec_vars.append(graph.hashed_layout_of(in_ref));
+    // Add concat node for this batch
+    {
+      vkapi::ParamsBindList param_buffers = {
+          graph.get_or_create_int_param_buffer(dim_whcn_ref, 0)};
+
+      std::vector<PushConstantDataInfo> push_constants;
+      vkapi::SpecVarList spec_vars;
+
+      if (graph.is_buffer_storage(out)) {
+        param_buffers.append(graph.sizes_ubo(out));
+        param_buffers.append(graph.strides_ubo(out));
+
+        for (const ValueRef in_ref : batch_inputs) {
+          param_buffers.append(graph.sizes_ubo(in_ref));
+          param_buffers.append(graph.strides_ubo(in_ref));
+        }
+
+        param_buffers.append(graph.numel_ubo(out));
+
+        spec_vars = {graph.hashed_layout_of(out)};
+      } else {
+        push_constants = {graph.sizes_pc_of(out)};
+
+        spec_vars = {graph.hashed_layout_of(out)};
+
+        for (const ValueRef in_ref : batch_inputs) {
+          push_constants.push_back(graph.sizes_pc_of(in_ref));
+          spec_vars.append(graph.hashed_layout_of(in_ref));
+        }
+      }
+
+      std::string kernel_name = "concat";
+      if (current_batch_size == 1) {
+        kernel_name += "_1";
+      } else if (current_batch_size == 2) {
+        kernel_name += "_2";
+      } else if (current_batch_size == 3) {
+        kernel_name += "_3";
+      }
+      if (graph.is_buffer_storage(out)) {
+        kernel_name += "_buffer";
+      } else {
+        kernel_name += "_texture3d";
+      }
+
+      add_dtype_suffix(kernel_name, graph.dtype_of(out));
+
+      DispatchNode::ResizeFunction resize_fn = nullptr;
+      if (batch_start == 0) {
+        resize_fn = resize_concat_node;
+      }
+      graph.execute_nodes().emplace_back(new DynamicDispatchNode(
+          graph,
+          VK_KERNEL_FROM_STR(kernel_name),
+          concat_pick_global_wg_size,
+          default_pick_local_wg_size,
+          // Inputs and Outputs
+          {{out, vkapi::kReadWrite},
+           {batch_inputs, vkapi::kRead},
+           {concat_offset, vkapi::kRead}},
+          // Parameter buffers
+          param_buffers,
+          // Push Constants
+          push_constants,
+          // Specialization Constants
+          spec_vars,
+          // Resize Args
+          {tensors_ref, dim_ref},
+          // Resizing Logic
+          resize_fn));
    }
-  }
-  std::string kernel_name = "concat";
-  if (in_value_refs.size() == 1) {
-    kernel_name += "_1";
-  } else if (in_value_refs.size() == 2) {
-    kernel_name += "_2";
-  } else if (in_value_refs.size() == 3) {
-    kernel_name += "_3";
-  }
-  if (graph.is_buffer_storage(out)) {
-    kernel_name += "_buffer";
-  } else {
-    kernel_name += "_texture3d";
+    // Add node to update concat_offset (except for the last batch)
+    if (batch_end < in_value_refs.size()) {
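+      // For example (hypothetical sizes): with 7 inputs, batches cover inputs
+      // [0, 3), [3, 6), and [6, 7); this offset update runs after the first
+      // two batches only, since nothing reads the offset after the last batch.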
      vkapi::ParamsBindList param_buffers = {
          graph.get_or_create_int_param_buffer(dim_whcn_ref, 0)};
+
+      for (const ValueRef in_ref : batch_inputs) {
+        param_buffers.append(graph.sizes_ubo(in_ref));
+      }
+
+      std::string kernel_name = "update_concat_offset";
+      if (current_batch_size == 1) {
+        kernel_name += "_1";
+      } else if (current_batch_size == 2) {
+        kernel_name += "_2";
+      } else if (current_batch_size == 3) {
+        kernel_name += "_3";
+      }
+      add_dtype_suffix(kernel_name, graph.dtype_of(concat_offset));
+
+      vkapi::SpecVarList spec_vars = {};
+
+      graph.execute_nodes().emplace_back(new DispatchNode(
+          graph,
+          VK_KERNEL_FROM_STR(kernel_name),
+          {1u, 1u, 1u},
+          {1u, 1u, 1u},
+          // Inputs and Outputs
+          {{concat_offset, vkapi::kWrite}},
+          // Parameter buffers
+          param_buffers,
+          // Push Constants
+          {},
+          // Specialization Constants
+          spec_vars,
+          // Resize Args
+          {},
+          // Resizing Logic
+          nullptr));
+    }
+  }
-
-  add_dtype_suffix(kernel_name, graph.dtype_of(out));
-
-  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
-      graph,
-      VK_KERNEL_FROM_STR(kernel_name),
-      default_pick_global_wg_size,
-      default_pick_local_wg_size,
-      // Inputs and Outputs
-      {{out, vkapi::kWrite}, {in_value_refs, vkapi::kRead}},
-      // Parameter buffers
-      param_buffers,
-      // Push Constants
-      push_constants,
-      // Specialization Constants
-      spec_vars,
-      // Resize Args
-      {dim_ref},
-      // Resizing Logic
-      resize_concat_node));
 }
 void cat_tensor(ComputeGraph& graph, const std::vector<ValueRef>& args) { diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index 5efcfc1ffb2..ff35188be3e 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -1233,66 +1233,81 @@ def get_repeat_interleave_inputs(): @register_test_suite("aten.cat.default") def get_cat_inputs(): # TensorList must be specified as list of tuples
-    test_suite = VkTestSuite(
-        [
-            # Cat on Height
-            ([(M, M, 3, 5), (M, M, 0, 5)], 2),
-            ([(S1, S1, 3, 5), (S1, S1, 0, 5)], 2),
-            ([(M, M, 3, 5), (M, M, 4, 5)], 2),
-            ([(S1, S1, 3, 5), (S1, S1, 4, 5)], 2),
-            ([(M2, 3, 5), (M2, 4, 5)], 1),
-            ([(S1, 3, 5), (S1, 4, 5)], 1),
-            ([(3, 5), (4, 5)], 0),
-            ([(3, 5), (4, 5), (1, 5)], 0),
-            (
-                [(3, 5)],
-                0,
-            ),
-            # Cat on Width
-            ([(M, M, 5, 3), (M, M, 5, 4)], 3),
-            ([(S1, S1, 5, 3), (S1, S1, 5, 4)], 3),
-            ([(M, 5, 3), (M, 5, 4)], 2),
-            ([(S1, 5, 3), (S1, 5, 4)], 2),
-            ([(5, 0), (5, 4)], 1),
-            ([(5, 3), (5, 4)], 1),
-            ([(5, 3), (5, 4), (5, 1)], 1),
-            (
-                [(5, 4)],
-                1,
-            ),
-            ([(5,), (6,)], 0),
-            # Cat on Batch
-            ([(M, S1, 5, 4), (M1, S1, 5, 4)], 0),
-            ([(S, S1, 5, 4), (S1, S1, 5, 4)], 0),
-            ([(S, M, 5, 4), (S1, M, 5, 4)], 0),
-            ([(S, XS, 5, 4), (S1, XS, 5, 4)], 0),
-            ([(S, S2, 5, 4), (S1, S2, 5, 4)], 0),
-            (
-                [
-                    (3, 1, 2, 5),
-                    (3, 1, 2, 5),
-                    (3, 1, 2, 5),
-                ],
-                0,
-            ),
-            # Cat on Channel
-            ([(M, 5, 4), (0, 5, 4), (M1, 5, 4)], 0),
-            ([(S, 5, 4), (0, 5, 4), (S2, 5, 4)], 0),
-            ([(M, 5, 4), (M1, 5, 4), (M2, 5, 4)], 0),
-            ([(S, 5, 4), (S1, 5, 4), (S2, 5, 4)], 0),
-            ([(XS, 5, 4), (XS, 5, 4), (S2, 5, 4)], 0),
-            ([(XS, S, 5, 4), (XS, S1, 5, 4), (XS, S2, 5, 4)], 1),
-            ([(XS, XS, 5, 4), (XS, XS, 5, 4), (XS, S2, 5, 4)], 1),
-            (
-                [
-                    (XS, 1, 2, 5),
-                    (XS, 1, 2, 5),
-                    (XS, 1, 2, 5),
-                ],
-                1,
-            ),
-        ]
-    )
+    suite_inputs = [
+        # Cat on Height
+        ([(M, M, 3, 5), (M, M, 0, 5)], 2),
+        ([(S1, S1, 3, 5), (S1, S1, 0, 5)], 2),
+        ([(M, M, 3, 5), (M, M, 4, 5)], 2),
+        ([(S1, S1, 3, 5), (S1, S1, 4, 5)], 2),
+        ([(M2, 3, 5), (M2, 4, 5)], 1),
+        ([(S1, 3, 5), (S1, 4, 5)], 1),
+        ([(3, 5), (4, 5)], 0),
+        ([(3, 5), (4, 5), (1, 5)], 0),
+        (
            [(3, 5)],
+            0,
+        ),
+        # Cat on Width
+        ([(M, M, 5, 3), (M, M, 5, 4)], 3),
+        ([(S1, S1, 5, 3), (S1, S1, 5, 4)], 3),
+        ([(M, 5, 3), (M, 5, 4)], 2),
+        ([(S1, 5, 3), (S1, 5, 4)], 2),
+        ([(5, 0), (5, 4)], 1),
+        ([(5, 3), (5, 4)], 1),
+        ([(5, 3), (5, 4), (5, 1)], 1),
+        (
+            [(5, 4)],
+            1,
+        ),
+        ([(5,), (6,)], 0),
+        # Cat on Batch
+        ([(M, S1, 5, 4), (M1, S1, 5, 4)], 0),
+        ([(S, S1, 5, 4), (S1, S1, 5, 4)], 0),
+        ([(S, M, 5, 4), (S1, M, 5, 4)], 0),
+        ([(S, XS, 5, 4), (S1, XS, 5, 4)], 0),
+        ([(S, S2, 5, 4), (S1, S2, 5, 4)], 0),
+        (
+            [
+                (3, 1, 2, 5),
+                (3, 1, 2, 5),
+                (3, 1, 2, 5),
+            ],
+            0,
+        ),
+        # Cat on Channel
+        ([(M, 5, 4), (0, 5, 4), (M1, 5, 4)], 0),
+        ([(S, 5, 4), (0, 5, 4), (S2, 5, 4)], 0),
+        ([(M, 5, 4), (M1, 5, 4), (M2, 5, 4)], 0),
+        ([(S, 5, 4), (S1, 5, 4), (S2, 5, 4)], 0),
+        ([(XS, 5, 4), (XS, 5, 4), (S2, 5, 4)], 0),
+        ([(XS, S, 5, 4), (XS, S1, 5, 4), (XS, S2, 5, 4)], 1),
+        ([(XS, XS, 5, 4), (XS, XS, 5, 4), (XS, S2, 5, 4)], 1),
+        (
+            [
+                (XS, 1, 2, 5),
+                (XS, 1, 2, 5),
+                (XS, 1, 2, 5),
+            ],
+            1,
+        ),
+    ]
+
+    high_number_cat_inputs = []
+    for num_input in [6, 9]:
+        odd_size = (3, 7, 29, 31)
+        even_size = (3, 8, 29, 32)
+        ones = (3, 1, 1, 1)
+
+        for input_size in [odd_size, even_size, ones]:
+            input_sizes = [input_size] * num_input
+            # Test cat on the width, height, channel, and batch dims
+            high_number_cat_inputs.append((input_sizes, 3))
+            high_number_cat_inputs.append((input_sizes, 2))
+            high_number_cat_inputs.append((input_sizes, 1))
+            high_number_cat_inputs.append((input_sizes, 0))
+
+    test_suite = VkTestSuite(suite_inputs + high_number_cat_inputs)
+
     test_suite.layouts = [ "utils::kWidthPacked", "utils::kChannelsPacked", diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py index 4799a22882d..6bf6a68090a 100644 --- a/backends/vulkan/test/test_vulkan_delegate.py +++ b/backends/vulkan/test/test_vulkan_delegate.py @@ -2150,3 +2150,147 @@ def forward(self, a, b): ) self.lower_module_and_test_output(custom_complex_module, sample_inputs)
+
+    def test_vulkan_backend_cat_width_dynamic_shapes(self):
+        class CatWidthModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x1, x2, x3, x4, x5, x6):
+                return torch.cat([x1, x2, x3, x4, x5, x6], dim=3)
+
+        cat_width_module = CatWidthModule()
+
+        # Create 6 tensors with different widths but same batch, channel, and height dimensions
+        sample_inputs = (
+            torch.randn(size=(2, 3, 4, 5), dtype=torch.float32),  # width=5
+            torch.randn(size=(2, 3, 4, 3), dtype=torch.float32),  # width=3
+            torch.randn(size=(2, 3, 4, 7), dtype=torch.float32),  # width=7
+            torch.randn(size=(2, 3, 4, 2), dtype=torch.float32),  # width=2
+            torch.randn(size=(2, 3, 4, 4), dtype=torch.float32),  # width=4
+            torch.randn(size=(2, 3, 4, 6), dtype=torch.float32),  # width=6
+        )
+
+        # Define dynamic shapes for the width dimension (dim=3) for each input
+        width1 = Dim("width1", min=1, max=10)
+        width2 = Dim("width2", min=1, max=10)
+        width3 = Dim("width3", min=1, max=10)
+        width4 = Dim("width4", min=1, max=10)
+        width5 = Dim("width5", min=1, max=10)
+        width6 = Dim("width6", min=1, max=10)
+
+        dynamic_shapes = {
+            "x1": {3: width1},
+            "x2": {3: width2},
+            "x3": {3: width3},
+            "x4": {3: width4},
+            "x5": {3: width5},
+            "x6": {3: width6},
+        }
+
+        # Create test inputs with different width combinations
+        test_inputs = [
+            (
+                torch.randn(2, 3, 4, 2),  # width=2
+                torch.randn(2, 3, 4, 1),  # width=1
+                torch.randn(2, 3, 4, 3),  # width=3
+                torch.randn(2, 3, 4, 1),  # width=1
+                torch.randn(2, 3, 4, 2),  # width=2
+                torch.randn(2, 3, 4,
4),  # width=4
+            ),
+            (
+                torch.randn(2, 3, 4, 8),  # width=8
+                torch.randn(2, 3, 4, 2),  # width=2
+                torch.randn(2, 3, 4, 1),  # width=1
+                torch.randn(2, 3, 4, 3),  # width=3
+                torch.randn(2, 3, 4, 5),  # width=5
+                torch.randn(2, 3, 4, 1),  # width=1
+            ),
+            (
+                torch.randn(2, 3, 4, 1),  # width=1
+                torch.randn(2, 3, 4, 9),  # width=9
+                torch.randn(2, 3, 4, 2),  # width=2
+                torch.randn(2, 3, 4, 4),  # width=4
+                torch.randn(2, 3, 4, 1),  # width=1
+                torch.randn(2, 3, 4, 3),  # width=3
+            ),
+        ]
+
+        self.lower_module_and_test_output(
+            cat_width_module,
+            sample_inputs,
+            dynamic_shapes=dynamic_shapes,
+            test_inputs=test_inputs,
+        )
+
+    def test_vulkan_backend_cat_channels_dynamic_shapes(self):
+        class CatChannelsModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x1, x2, x3, x4, x5, x6):
+                return torch.cat([x1, x2, x3, x4, x5, x6], dim=1)
+
+        cat_channels_module = CatChannelsModule()
+
+        # Create 6 sample tensors exported at the maximum channel count (8) but
+        # same batch, height, and width dimensions; the channel counts vary
+        # across the test inputs below
+        sample_inputs = (
+            torch.randn(size=(2, 8, 8, 6), dtype=torch.float32),  # channels=8
+            torch.randn(size=(2, 8, 8, 6), dtype=torch.float32),  # channels=8
+            torch.randn(size=(2, 8, 8, 6), dtype=torch.float32),  # channels=8
+            torch.randn(size=(2, 8, 8, 6), dtype=torch.float32),  # channels=8
+            torch.randn(size=(2, 8, 8, 6), dtype=torch.float32),  # channels=8
+            torch.randn(size=(2, 8, 8, 6), dtype=torch.float32),  # channels=8
+        )
+
+        # Define dynamic shapes for the channels dimension (dim=1) for each input
+        channels1 = Dim("channels1", min=1, max=8)
+        channels2 = Dim("channels2", min=1, max=8)
+        channels3 = Dim("channels3", min=1, max=8)
+        channels4 = Dim("channels4", min=1, max=8)
+        channels5 = Dim("channels5", min=1, max=8)
+        channels6 = Dim("channels6", min=1, max=8)
+
+        dynamic_shapes = {
+            "x1": {1: channels1},
+            "x2": {1: channels2},
+            "x3": {1: channels3},
+            "x4": {1: channels4},
+            "x5": {1: channels5},
+            "x6": {1: channels6},
+        }
+
+        # Create test inputs with different channel combinations
+        test_inputs = [
+            (
+                torch.randn(2, 1, 8, 6),  # channels=1
+                torch.randn(2, 2, 8, 6),  # channels=2
+                torch.randn(2, 1, 8, 6),  # channels=1
+                torch.randn(2, 3, 8, 6),  # channels=3
+                torch.randn(2, 1, 8, 6),  # channels=1
+                torch.randn(2, 2, 8, 6),  # channels=2
+            ),
+            (
+                torch.randn(2, 6, 8, 6),  # channels=6
+                torch.randn(2, 1, 8, 6),  # channels=1
+                torch.randn(2, 3, 8, 6),  # channels=3
+                torch.randn(2, 2, 8, 6),  # channels=2
+                torch.randn(2, 4, 8, 6),  # channels=4
+                torch.randn(2, 1, 8, 6),  # channels=1
+            ),
+            (
+                torch.randn(2, 2, 8, 6),  # channels=2
+                torch.randn(2, 7, 8, 6),  # channels=7
+                torch.randn(2, 1, 8, 6),  # channels=1
+                torch.randn(2, 1, 8, 6),  # channels=1
+                torch.randn(2, 3, 8, 6),  # channels=3
+                torch.randn(2, 2, 8, 6),  # channels=2
+            ),
+        ]
+
+        self.lower_module_and_test_output(
+            cat_channels_module,
+            sample_inputs,
+            dynamic_shapes=dynamic_shapes,
+            test_inputs=test_inputs,
+        )