From e196ed9340cef36fcd1bcc2188f8fefdb16b4ee2 Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Fri, 30 May 2025 08:26:09 -0700 Subject: [PATCH] [ET-VK] Add reshape functions for transformers related operators ## Changes * Implement resize functions for several operators used in Transformers models ## Motivation Be able to support batched prefill for llama models. Differential Revision: [D75686049](https://our.internmc.facebook.com/intern/diff/D75686049/) [ghstack-poisoned] --- backends/vulkan/op_registry.py | 25 ++-- .../vulkan/runtime/graph/ops/impl/Permute.cpp | 114 ++++++++++++------ .../vulkan/runtime/graph/ops/impl/Permute.h | 6 +- .../graph/ops/impl/RotaryEmbedding.cpp | 16 ++- .../vulkan/runtime/graph/ops/impl/Squeeze.cpp | 21 ++-- .../runtime/graph/ops/impl/Unsqueeze.cpp | 11 +- .../graph/ops/impl/utils/TensorUtils.cpp | 7 ++ .../graph/ops/impl/utils/TensorUtils.h | 6 + 8 files changed, 141 insertions(+), 65 deletions(-) diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 0486110ced6..2aa940dcc4b 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -500,7 +500,12 @@ def register_sdpa_with_kv_cache_op(features: OpFeatures): return features -@update_features(["llama::update_cache", "llama::custom_sdpa"]) +@update_features( + [ + "llama::update_cache", + "llama::custom_sdpa", + ] +) def register_sdpa_ops(features: OpFeatures): features.resize_fn = False features.buffer_impl = False @@ -520,8 +525,17 @@ def register_rotary_emb_op(features: OpFeatures): return features -@update_features(exir_ops.edge.aten.view_copy.default) -def register_view_op(features: OpFeatures): +@update_features( + [ + exir_ops.edge.aten.clone.default, + exir_ops.edge.aten.permute.default, + exir_ops.edge.aten.permute_copy.default, + exir_ops.edge.aten.select_copy.int, + exir_ops.edge.aten.slice_copy.Tensor, + exir_ops.edge.aten.view_copy.default, + ] +) +def register_view_ops(features: OpFeatures): features.texture_impl = TextureImplFeatures( valid_packed_dims=all_packed_dims, ) @@ -538,10 +552,8 @@ def register_view_op(features: OpFeatures): # Indexing and lookup exir_ops.edge.aten.flip.default, exir_ops.edge.aten.index_select.default, - exir_ops.edge.aten.select_copy.int, # Tensor creation exir_ops.edge.aten.arange.start_step, - exir_ops.edge.aten.clone.default, exir_ops.edge.aten.constant_pad_nd.default, exir_ops.edge.aten.full.default, exir_ops.edge.aten.full_like.default, @@ -564,12 +576,9 @@ def register_ported_op(features: OpFeatures): # Ops ported from PyTorch Vulkan backend. These ops are in a separate registry becasue they support all packed dimensions @update_features( [ - # Indexing and lookup - exir_ops.edge.aten.slice_copy.Tensor, # Shape Manipulation exir_ops.edge.aten.squeeze_copy.dims, exir_ops.edge.aten.unsqueeze_copy.default, - exir_ops.edge.aten.permute_copy.default, # Tensor combination exir_ops.edge.aten.cat.default, exir_ops.edge.aten.repeat.default, diff --git a/backends/vulkan/runtime/graph/ops/impl/Permute.cpp b/backends/vulkan/runtime/graph/ops/impl/Permute.cpp index 8e2c72d7627..8a602fc63bd 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Permute.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Permute.cpp @@ -25,10 +25,11 @@ using utils::uvec4; namespace { void check_args( - const api::vTensor& in, - const std::vector& permute_dims, - const api::vTensor& out) { - VK_CHECK_COND(check_same_packed_dim(in, out)); + ComputeGraph& graph, + const ValueRef in, + const ValueRef permute_dims, + const ValueRef out) { + VK_CHECK_COND(check_same_packed_dim(graph, in, out)); // This implementation doesn't not requires the input tensor to have the same // dim size as the argument. The code will work as long as the input tensor's @@ -38,40 +39,93 @@ void check_args( } // namespace +void resize_permute_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& extra_args) { + ValueRef out = args[0].refs[0]; + ValueRef in = args[1].refs[0]; + + std::vector in_sizes = graph->sizes_of(in); + std::vector out_sizes = graph->sizes_of(out); + + std::vector permute_dims = + graph->extract_int_or_symint_list(extra_args[0]); + + if (in_sizes.size() == out_sizes.size() && + in_sizes.size() == permute_dims.size()) { + std::vector new_out_sizes(out_sizes.size(), 1); + int64_t out_ndim = std::max(in_sizes.size(), out_sizes.size()); + for (int i = 0; i < out_ndim; i++) { + int64_t permute_dim = permute_dims.at(i); + new_out_sizes.at(i) = in_sizes.at(permute_dim); + } + graph->virtual_resize(out, new_out_sizes); + } + // Case where permute is being used to implement squeeze + else if ( + in_sizes.size() > out_sizes.size() && + in_sizes.size() == permute_dims.size()) { + std::vector new_out_sizes(out_sizes.size(), 1); + int offset = in_sizes.size() - out_sizes.size(); + for (int i = 0; i < out_sizes.size(); i++) { + int64_t permute_dim = permute_dims.at(i + offset); + new_out_sizes.at(i) = in_sizes.at(permute_dim); + } + graph->virtual_resize(out, new_out_sizes); + } + // Case where Permute is being used to implement unsqueeze + else if ( + in_sizes.size() < out_sizes.size() && + out_sizes.size() == permute_dims.size()) { + std::vector new_out_sizes(out_sizes.size(), 1); + int offset = out_sizes.size() - in_sizes.size(); + for (int i = 0; i < out_sizes.size(); i++) { + int64_t permute_dim = permute_dims.at(i) - offset; + if (permute_dim >= 0) { + new_out_sizes.at(i) = in_sizes.at(permute_dim); + } + } + graph->virtual_resize(out, new_out_sizes); + } else { + VK_THROW("Invalid permute dims"); + } +} + void add_permute_node( ComputeGraph& graph, - ValueRef in, - const std::vector& permute_dims, - ValueRef out) { - vTensorPtr t_in = graph.get_tensor(in); - vTensorPtr t_out = graph.get_tensor(out); - - check_args(*t_in, permute_dims, *t_out); + const ValueRef in, + const ValueRef permute_dims, + const ValueRef out) { + check_args(graph, in, permute_dims, out); ivec4 out_dims{0, 1, 2, 3}; // Special cases of squeeze/unsqueeze. Because the input dim size can be // different with output dim size. So pick t_in->dim() if squeeze, and // t_out->dim() if unsqueeze to create parameter for permute. - int64_t out_ndim = std::max(t_in->dim(), t_out->dim()); + int64_t out_ndim = std::max(graph.dim_of(in), graph.dim_of(out)); std::vector seen(out_ndim); - for (int i = 0; i < out_ndim; i++) { - int64_t permute_dim = permute_dims[i]; - VK_CHECK_COND( - !seen[permute_dim], "Argument dim ", permute_dim, " is repeated"); - seen[permute_dim] = true; - - out_dims[(4u - out_ndim) + i] = permute_dim + (4 - out_ndim); + { + IntListPtr permute_dims_ptr = graph.get_int_list(permute_dims); + for (int i = 0; i < out_ndim; i++) { + int64_t permute_dim = permute_dims_ptr->at(i); + VK_CHECK_COND( + !seen[permute_dim], "Argument dim ", permute_dim, " is repeated"); + seen[permute_dim] = true; + + out_dims[(4u - out_ndim) + i] = permute_dim + (4 - out_ndim); + } } std::string kernel_name = "permute"; kernel_name.reserve(kShaderNameReserve); - add_dtype_suffix(kernel_name, *t_out); + add_dtype_suffix(kernel_name, graph.dtype_of(out)); - int32_t out_channels = dim_at(t_out->sizes()); - int32_t in_channels = dim_at(t_in->sizes()); + int32_t out_channels = dim_at(graph.sizes_of(out)); + int32_t in_channels = dim_at(graph.sizes_of(in)); - const auto packed_dim = graph.packed_dim_of(in); + const int32_t packed_dim = graph.packed_dim_of(in); ivec2 channel_info = {out_channels, in_channels}; if (packed_dim == WHCN::kChannelsDim) { channel_info[0] = utils::align_up_4(channel_info[0]); @@ -95,19 +149,9 @@ void add_permute_node( // Specialization Constants spec_vars, // Resize Args - {}, + {permute_dims}, // Resizing Logic - nullptr)); -} - -void add_permute_node( - ComputeGraph& graph, - ValueRef in, - ValueRef permute_dims_ref, - ValueRef out) { - IntListPtr permute_dims = graph.get_int_list(permute_dims_ref); - - add_permute_node(graph, in, *permute_dims, out); + resize_permute_node)); } void permute(ComputeGraph& graph, const std::vector& args) { diff --git a/backends/vulkan/runtime/graph/ops/impl/Permute.h b/backends/vulkan/runtime/graph/ops/impl/Permute.h index 941a8896fe2..0f17a4a26b0 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Permute.h +++ b/backends/vulkan/runtime/graph/ops/impl/Permute.h @@ -18,8 +18,8 @@ namespace vkcompute { void add_permute_node( ComputeGraph& graph, - ValueRef in, - const std::vector& permute_dims, - ValueRef out); + const ValueRef in, + const ValueRef permute_dims, + const ValueRef out); } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp b/backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp index ee40a043ee5..c1387580611 100644 --- a/backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp @@ -17,12 +17,18 @@ void resize_rotary_embedding_node( const std::vector& args, const std::vector& extra_args) { (void)extra_args; - vTensorPtr out = graph->get_tensor(args[0].refs[0]); - vTensorPtr in = graph->get_tensor(args[1].refs[0]); - std::vector in_sizes = in->sizes(); - // UNCOMMENT BELOW IF NEEDED - // out->virtual_resize(in_sizes); + // Get output tensors (xq_out and xk_out) + vTensorPtr xq_out = graph->get_tensor(args[0].refs[0]); + vTensorPtr xk_out = graph->get_tensor(args[0].refs[1]); + + // Get input tensors (xq and xk) + vTensorPtr xq = graph->get_tensor(args[1].refs[0]); + vTensorPtr xk = graph->get_tensor(args[1].refs[1]); + + // Resize output tensors to match input tensors + xq_out->virtual_resize(xq->sizes()); + xk_out->virtual_resize(xk->sizes()); } void add_rotary_embedding_node( diff --git a/backends/vulkan/runtime/graph/ops/impl/Squeeze.cpp b/backends/vulkan/runtime/graph/ops/impl/Squeeze.cpp index b212d24f06b..fef43e99f84 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Squeeze.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Squeeze.cpp @@ -20,25 +20,26 @@ void add_squeeze_copy_dims_node( ValueRef in, ValueRef dims_ref, ValueRef out) { - vTensorPtr t_in = graph.get_tensor(in); - vTensorPtr t_out = graph.get_tensor(out); + int64_t in_dim = graph.dim_of(in); + std::vector in_sizes = graph.sizes_of(in); + std::vector out_sizes = graph.sizes_of(in); - IntListPtr dims = graph.get_int_list(dims_ref); + std::vector dims = graph.extract_int_or_symint_list(dims_ref); std::vector squeeze_dims; // Filter out edge cases that we don't need squeeze: // 1. The size of squeeze dim is larger than 1. // 2. Squeeze outter most dim // For these cases, just pass input to output via clone. - for (int i = 0; i < dims->size(); ++i) { - if (dims->at(i) != 0 && t_in->sizes().at(dims->at(i)) == 1) { - squeeze_dims.push_back(dims->at(i)); + for (int i = 0; i < dims.size(); ++i) { + if (dims.at(i) != 0 && in_sizes.at(dims.at(i)) == 1) { + squeeze_dims.push_back(dims.at(i)); } } if (squeeze_dims.size() == 0) { add_clone_node(graph, in, out); } else { - std::vector permute_dims(t_in->dim()); - for (int i = 0; i < t_in->dim(); ++i) { + std::vector permute_dims(in_dim); + for (int i = 0; i < in_dim; ++i) { permute_dims.at(i) = i; } for (auto& elem : squeeze_dims) { @@ -48,7 +49,9 @@ void add_squeeze_copy_dims_node( std::rotate(permute_dims.begin(), it, it + 1); } - add_permute_node(graph, in, permute_dims, out); + ValueRef permute_dims_ref = + graph.add_scalar_list(std::vector(permute_dims)); + add_permute_node(graph, in, permute_dims_ref, out); } } diff --git a/backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp b/backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp index c8ada796e8e..cf96ad53063 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp @@ -19,14 +19,13 @@ void add_unsqueeze_node( ValueRef in, ValueRef dim_ref, ValueRef out) { - vTensorPtr t_in = graph.get_tensor(in); - vTensorPtr t_out = graph.get_tensor(out); + int64_t in_dim = graph.dim_of(in); + int64_t out_dim = graph.dim_of(out); VK_CHECK_COND( - t_in->dim() < 4, "Cannot unsqueeze a tensor with more than 3 dimensions"); + in_dim < 4, "Cannot unsqueeze a tensor with more than 3 dimensions"); int64_t dim = graph.extract_scalar(dim_ref); - int64_t out_dim = t_out->dim(); std::vector permute_dims(out_dim); for (int i = 1; i <= dim; i++) { @@ -38,7 +37,9 @@ void add_unsqueeze_node( permute_dims[i] = i; } - add_permute_node(graph, in, permute_dims, out); + ValueRef permute_dims_ref = + graph.add_scalar_list(std::vector(permute_dims)); + add_permute_node(graph, in, permute_dims_ref, out); } void unsqueeze(ComputeGraph& graph, const std::vector& args) { diff --git a/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.cpp b/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.cpp index 9d010c794ec..2bcf2a3842f 100644 --- a/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.cpp @@ -57,6 +57,13 @@ bool check_same_packed_dim(const api::vTensor& t1, const api::vTensor& t2) { return t1.packed_dim() == t2.packed_dim(); } +bool check_same_packed_dim( + ComputeGraph& graph, + const ValueRef in, + const ValueRef out) { + return graph.packed_dim_of(in) == graph.packed_dim_of(out); +} + bool check_same_packed_dim( const api::vTensor& t1, const api::vTensor& t2, diff --git a/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h b/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h index c9eeb0efe08..3b61083069e 100644 --- a/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h +++ b/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h @@ -9,6 +9,7 @@ #pragma once #include +#include namespace vkcompute { @@ -38,6 +39,11 @@ bool check_packed_dim_is(const api::vTensor& t, const int32_t packed_dim); bool check_same_packed_dim(const api::vTensor& t1, const api::vTensor& t2); +bool check_same_packed_dim( + ComputeGraph& graph, + const ValueRef in, + const ValueRef out); + bool check_same_packed_dim( const api::vTensor& t1, const api::vTensor& t2,