From 72a67e2dc402c0edd4a0181c72109201ca4f87be Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Tue, 3 Jun 2025 12:24:43 -0700 Subject: [PATCH 1/3] [ET-VK] Add reshape functions for transformers related operators Pull Request resolved: https://github.com/pytorch/executorch/pull/11256 ## Changes * Implement resize functions for several operators used in Transformers models ## Motivation Be able to support batched prefill for llama models. ghstack-source-id: 287935585 @exported-using-ghexport Differential Revision: [D75686049](https://our.internmc.facebook.com/intern/diff/D75686049/) --- backends/vulkan/op_registry.py | 25 ++-- .../vulkan/runtime/graph/ops/impl/Permute.cpp | 120 ++++++++++++------ .../vulkan/runtime/graph/ops/impl/Permute.h | 6 +- .../graph/ops/impl/RotaryEmbedding.cpp | 22 ++-- .../vulkan/runtime/graph/ops/impl/Squeeze.cpp | 27 ++-- .../runtime/graph/ops/impl/Unsqueeze.cpp | 17 +-- .../graph/ops/impl/utils/TensorUtils.cpp | 7 + .../graph/ops/impl/utils/TensorUtils.h | 6 + 8 files changed, 154 insertions(+), 76 deletions(-) diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 0486110ced6..2aa940dcc4b 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -500,7 +500,12 @@ def register_sdpa_with_kv_cache_op(features: OpFeatures): return features -@update_features(["llama::update_cache", "llama::custom_sdpa"]) +@update_features( + [ + "llama::update_cache", + "llama::custom_sdpa", + ] +) def register_sdpa_ops(features: OpFeatures): features.resize_fn = False features.buffer_impl = False @@ -520,8 +525,17 @@ def register_rotary_emb_op(features: OpFeatures): return features -@update_features(exir_ops.edge.aten.view_copy.default) -def register_view_op(features: OpFeatures): +@update_features( + [ + exir_ops.edge.aten.clone.default, + exir_ops.edge.aten.permute.default, + exir_ops.edge.aten.permute_copy.default, + exir_ops.edge.aten.select_copy.int, + exir_ops.edge.aten.slice_copy.Tensor, + 
exir_ops.edge.aten.view_copy.default, + ] +) +def register_view_ops(features: OpFeatures): features.texture_impl = TextureImplFeatures( valid_packed_dims=all_packed_dims, ) @@ -538,10 +552,8 @@ def register_view_op(features: OpFeatures): # Indexing and lookup exir_ops.edge.aten.flip.default, exir_ops.edge.aten.index_select.default, - exir_ops.edge.aten.select_copy.int, # Tensor creation exir_ops.edge.aten.arange.start_step, - exir_ops.edge.aten.clone.default, exir_ops.edge.aten.constant_pad_nd.default, exir_ops.edge.aten.full.default, exir_ops.edge.aten.full_like.default, @@ -564,12 +576,9 @@ def register_ported_op(features: OpFeatures): # Ops ported from PyTorch Vulkan backend. These ops are in a separate registry becasue they support all packed dimensions @update_features( [ - # Indexing and lookup - exir_ops.edge.aten.slice_copy.Tensor, # Shape Manipulation exir_ops.edge.aten.squeeze_copy.dims, exir_ops.edge.aten.unsqueeze_copy.default, - exir_ops.edge.aten.permute_copy.default, # Tensor combination exir_ops.edge.aten.cat.default, exir_ops.edge.aten.repeat.default, diff --git a/backends/vulkan/runtime/graph/ops/impl/Permute.cpp b/backends/vulkan/runtime/graph/ops/impl/Permute.cpp index 8e2c72d7627..fba3f03467b 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Permute.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Permute.cpp @@ -25,10 +25,12 @@ using utils::uvec4; namespace { void check_args( - const api::vTensor& in, - const std::vector& permute_dims, - const api::vTensor& out) { - VK_CHECK_COND(check_same_packed_dim(in, out)); + ComputeGraph& graph, + const ValueRef in, + const ValueRef permute_dims, + const ValueRef out) { + (void)permute_dims; + VK_CHECK_COND(check_same_packed_dim(graph, in, out)); // This implementation doesn't not requires the input tensor to have the same // dim size as the argument. 
The code will work as long as the input tensor's @@ -38,40 +40,94 @@ void check_args( } // namespace +void resize_permute_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef out = args[0].refs[0]; + const ValueRef in = args[1].refs[0]; + + const std::vector in_sizes = graph->sizes_of(in); + const std::vector out_sizes = graph->sizes_of(out); + + const std::vector permute_dims = + graph->extract_int_or_symint_list(resize_args[0]); + + if (in_sizes.size() == out_sizes.size() && + in_sizes.size() == permute_dims.size()) { + std::vector new_out_sizes(out_sizes.size(), 1); + const int64_t out_ndim = std::max(in_sizes.size(), out_sizes.size()); + for (int i = 0; i < out_ndim; i++) { + const int64_t permute_dim = permute_dims.at(i); + new_out_sizes.at(i) = in_sizes.at(permute_dim); + } + graph->virtual_resize(out, new_out_sizes); + } + // Case where permute is being used to implement squeeze + else if ( + in_sizes.size() > out_sizes.size() && + in_sizes.size() == permute_dims.size()) { + std::vector new_out_sizes(out_sizes.size(), 1); + const size_t offset = in_sizes.size() - out_sizes.size(); + for (int i = 0; i < out_sizes.size(); i++) { + const int64_t permute_dim = permute_dims.at(i + offset); + new_out_sizes.at(i) = in_sizes.at(permute_dim); + } + graph->virtual_resize(out, new_out_sizes); + } + // Case where Permute is being used to implement unsqueeze + else if ( + in_sizes.size() < out_sizes.size() && + out_sizes.size() == permute_dims.size()) { + std::vector new_out_sizes(out_sizes.size(), 1); + const size_t offset = out_sizes.size() - in_sizes.size(); + for (int i = 0; i < out_sizes.size(); i++) { + int64_t permute_dim = permute_dims.at(i) - offset; + if (permute_dim >= 0) { + new_out_sizes.at(i) = in_sizes.at(permute_dim); + } + } + graph->virtual_resize(out, new_out_sizes); + } else { + VK_THROW("Invalid permute dims"); + } +} + void add_permute_node( ComputeGraph& graph, - ValueRef in, - const 
std::vector& permute_dims, - ValueRef out) { - vTensorPtr t_in = graph.get_tensor(in); - vTensorPtr t_out = graph.get_tensor(out); - - check_args(*t_in, permute_dims, *t_out); + const ValueRef in, + const ValueRef permute_dims, + const ValueRef out) { + check_args(graph, in, permute_dims, out); ivec4 out_dims{0, 1, 2, 3}; // Special cases of squeeze/unsqueeze. Because the input dim size can be - // different with output dim size. So pick t_in->dim() if squeeze, and - // t_out->dim() if unsqueeze to create parameter for permute. - int64_t out_ndim = std::max(t_in->dim(), t_out->dim()); + // different with output dim size. So pick graph.dim_of(in) if squeeze, and + // graph.dim_of(out) if unsqueeze to create parameter for permute. + const int64_t out_ndim = std::max(graph.dim_of(in), graph.dim_of(out)); std::vector seen(out_ndim); - for (int i = 0; i < out_ndim; i++) { - int64_t permute_dim = permute_dims[i]; - VK_CHECK_COND( - !seen[permute_dim], "Argument dim ", permute_dim, " is repeated"); - seen[permute_dim] = true; - - out_dims[(4u - out_ndim) + i] = permute_dim + (4 - out_ndim); + { + IntListPtr permute_dims_ptr = graph.get_int_list(permute_dims); + for (int i = 0; i < out_ndim; i++) { + int64_t permute_dim = permute_dims_ptr->at(i); + VK_CHECK_COND( + !seen[permute_dim], "Argument dim ", permute_dim, " is repeated"); + seen[permute_dim] = true; + + out_dims[(4u - out_ndim) + i] = + utils::safe_downcast(permute_dim + (4 - out_ndim)); + } } std::string kernel_name = "permute"; kernel_name.reserve(kShaderNameReserve); - add_dtype_suffix(kernel_name, *t_out); + add_dtype_suffix(kernel_name, graph.dtype_of(out)); - int32_t out_channels = dim_at(t_out->sizes()); - int32_t in_channels = dim_at(t_in->sizes()); + const int32_t out_channels = dim_at(graph.sizes_of(out)); + const int32_t in_channels = dim_at(graph.sizes_of(in)); - const auto packed_dim = graph.packed_dim_of(in); + const int32_t packed_dim = graph.packed_dim_of(in); ivec2 channel_info = {out_channels, 
in_channels}; if (packed_dim == WHCN::kChannelsDim) { channel_info[0] = utils::align_up_4(channel_info[0]); @@ -95,19 +151,9 @@ void add_permute_node( // Specialization Constants spec_vars, // Resize Args - {}, + {permute_dims}, // Resizing Logic - nullptr)); -} - -void add_permute_node( - ComputeGraph& graph, - ValueRef in, - ValueRef permute_dims_ref, - ValueRef out) { - IntListPtr permute_dims = graph.get_int_list(permute_dims_ref); - - add_permute_node(graph, in, *permute_dims, out); + resize_permute_node)); } void permute(ComputeGraph& graph, const std::vector& args) { diff --git a/backends/vulkan/runtime/graph/ops/impl/Permute.h b/backends/vulkan/runtime/graph/ops/impl/Permute.h index 941a8896fe2..0f17a4a26b0 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Permute.h +++ b/backends/vulkan/runtime/graph/ops/impl/Permute.h @@ -18,8 +18,8 @@ namespace vkcompute { void add_permute_node( ComputeGraph& graph, - ValueRef in, - const std::vector& permute_dims, - ValueRef out); + const ValueRef in, + const ValueRef permute_dims, + const ValueRef out); } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp b/backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp index ee40a043ee5..31bab144d8a 100644 --- a/backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp @@ -15,14 +15,20 @@ namespace vkcompute { void resize_rotary_embedding_node( ComputeGraph* graph, const std::vector& args, - const std::vector& extra_args) { - (void)extra_args; - vTensorPtr out = graph->get_tensor(args[0].refs[0]); - vTensorPtr in = graph->get_tensor(args[1].refs[0]); - - std::vector in_sizes = in->sizes(); - // UNCOMMENT BELOW IF NEEDED - // out->virtual_resize(in_sizes); + const std::vector& resize_args) { + (void)resize_args; + + const ValueRef xq_out = args.at(0).refs.at(0); + const ValueRef xk_out = args.at(0).refs.at(1); + + const ValueRef xq = args.at(1).refs.at(0); + const 
ValueRef xk = args.at(1).refs.at(1); + + const std::vector xq_sizes = graph->sizes_of(xq); + const std::vector xk_sizes = graph->sizes_of(xk); + + graph->virtual_resize(xq_out, xq_sizes); + graph->virtual_resize(xk_out, xk_sizes); } void add_rotary_embedding_node( diff --git a/backends/vulkan/runtime/graph/ops/impl/Squeeze.cpp b/backends/vulkan/runtime/graph/ops/impl/Squeeze.cpp index b212d24f06b..249f5e7fa6b 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Squeeze.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Squeeze.cpp @@ -17,28 +17,29 @@ namespace vkcompute { void add_squeeze_copy_dims_node( ComputeGraph& graph, - ValueRef in, - ValueRef dims_ref, - ValueRef out) { - vTensorPtr t_in = graph.get_tensor(in); - vTensorPtr t_out = graph.get_tensor(out); + const ValueRef in, + const ValueRef dims_ref, + const ValueRef out) { + const int64_t in_dim = graph.dim_of(in); + const std::vector in_sizes = graph.sizes_of(in); + const std::vector out_sizes = graph.sizes_of(in); - IntListPtr dims = graph.get_int_list(dims_ref); + const std::vector dims = graph.extract_int_or_symint_list(dims_ref); std::vector squeeze_dims; // Filter out edge cases that we don't need squeeze: // 1. The size of squeeze dim is larger than 1. // 2. Squeeze outter most dim // For these cases, just pass input to output via clone. 
- for (int i = 0; i < dims->size(); ++i) { - if (dims->at(i) != 0 && t_in->sizes().at(dims->at(i)) == 1) { - squeeze_dims.push_back(dims->at(i)); + for (int i = 0; i < dims.size(); ++i) { + if (dims.at(i) != 0 && in_sizes.at(dims.at(i)) == 1) { + squeeze_dims.push_back(dims.at(i)); } } if (squeeze_dims.size() == 0) { add_clone_node(graph, in, out); } else { - std::vector permute_dims(t_in->dim()); - for (int i = 0; i < t_in->dim(); ++i) { + std::vector permute_dims(in_dim); + for (int i = 0; i < in_dim; ++i) { permute_dims.at(i) = i; } for (auto& elem : squeeze_dims) { @@ -48,7 +49,9 @@ void add_squeeze_copy_dims_node( std::rotate(permute_dims.begin(), it, it + 1); } - add_permute_node(graph, in, permute_dims, out); + const ValueRef permute_dims_ref = + graph.add_scalar_list(std::vector(permute_dims)); + add_permute_node(graph, in, permute_dims_ref, out); } } diff --git a/backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp b/backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp index c8ada796e8e..306a79fb8b8 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp @@ -16,17 +16,16 @@ namespace vkcompute { void add_unsqueeze_node( ComputeGraph& graph, - ValueRef in, - ValueRef dim_ref, - ValueRef out) { - vTensorPtr t_in = graph.get_tensor(in); - vTensorPtr t_out = graph.get_tensor(out); + const ValueRef in, + const ValueRef dim_ref, + const ValueRef out) { + const int64_t in_dim = graph.dim_of(in); + const int64_t out_dim = graph.dim_of(out); VK_CHECK_COND( - t_in->dim() < 4, "Cannot unsqueeze a tensor with more than 3 dimensions"); + in_dim < 4, "Cannot unsqueeze a tensor with more than 3 dimensions"); int64_t dim = graph.extract_scalar(dim_ref); - int64_t out_dim = t_out->dim(); std::vector permute_dims(out_dim); for (int i = 1; i <= dim; i++) { @@ -38,7 +37,9 @@ void add_unsqueeze_node( permute_dims[i] = i; } - add_permute_node(graph, in, permute_dims, out); + const ValueRef permute_dims_ref = + 
graph.add_scalar_list(std::vector(permute_dims)); + add_permute_node(graph, in, permute_dims_ref, out); } void unsqueeze(ComputeGraph& graph, const std::vector& args) { diff --git a/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.cpp b/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.cpp index 9d010c794ec..2bcf2a3842f 100644 --- a/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.cpp @@ -57,6 +57,13 @@ bool check_same_packed_dim(const api::vTensor& t1, const api::vTensor& t2) { return t1.packed_dim() == t2.packed_dim(); } +bool check_same_packed_dim( + ComputeGraph& graph, + const ValueRef in, + const ValueRef out) { + return graph.packed_dim_of(in) == graph.packed_dim_of(out); +} + bool check_same_packed_dim( const api::vTensor& t1, const api::vTensor& t2, diff --git a/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h b/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h index c9eeb0efe08..3b61083069e 100644 --- a/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h +++ b/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h @@ -9,6 +9,7 @@ #pragma once #include +#include namespace vkcompute { @@ -38,6 +39,11 @@ bool check_packed_dim_is(const api::vTensor& t, const int32_t packed_dim); bool check_same_packed_dim(const api::vTensor& t1, const api::vTensor& t2); +bool check_same_packed_dim( + ComputeGraph& graph, + const ValueRef in, + const ValueRef out); + bool check_same_packed_dim( const api::vTensor& t1, const api::vTensor& t2, From 462806b2f7a8e9ab09b64d6fa0d0485232fdcfd0 Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Tue, 3 Jun 2025 12:48:54 -0700 Subject: [PATCH 2/3] [ET-VK] Add support for operator.add for symint ops Pull Request resolved: https://github.com/pytorch/executorch/pull/11257 ## Changes * Add an implementation for operator.add which adds symbolic integers. ## Motivation Support executing llama models with dynamic shapes. 
This operator shows up when exporting with dynamic shapes. ghstack-source-id: 287945686 @exported-using-ghexport Differential Revision: [D75238029](https://our.internmc.facebook.com/intern/diff/D75238029/) --- .../vulkan/_passes/tag_memory_meta_pass.py | 4 +- backends/vulkan/op_registry.py | 1 + .../runtime/graph/ops/impl/SymIntOps.cpp | 69 ++++++++++++++----- .../serialization/vulkan_graph_builder.py | 26 ++++--- 4 files changed, 71 insertions(+), 29 deletions(-) diff --git a/backends/vulkan/_passes/tag_memory_meta_pass.py b/backends/vulkan/_passes/tag_memory_meta_pass.py index 8983ca67752..836a0c6ef7d 100644 --- a/backends/vulkan/_passes/tag_memory_meta_pass.py +++ b/backends/vulkan/_passes/tag_memory_meta_pass.py @@ -5,7 +5,6 @@ # LICENSE file in the root directory of this source tree. import logging -from copy import deepcopy from typing import Any, Optional, Set import executorch.backends.vulkan.utils as utils @@ -22,6 +21,7 @@ from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult +from executorch.exir.tensor import TensorSpec logger: logging.Logger = logging.getLogger("") logger.setLevel(logging.INFO) @@ -52,7 +52,7 @@ def insert_transition_node( (arg,), ) clone_node.meta["val"] = arg.meta["val"] - clone_node.meta["spec"] = deepcopy(arg.meta["spec"]) + clone_node.meta["spec"] = TensorSpec.from_tensor(clone_node.meta["val"]) clone_node.meta["spec"].const = False set_memory_metadata(clone_node, storage, layout) arg.replace_all_uses_with(clone_node, lambda x, y=node: x == y) diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 2aa940dcc4b..3c1f1eb40dc 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -230,6 +230,7 @@ def update_features_impl(op: OpKey): exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, # Symbolic integer ops torch.ops.aten.sym_size.int, + operator.add, ] ) def register_ephemeral_op(features: 
OpFeatures): diff --git a/backends/vulkan/runtime/graph/ops/impl/SymIntOps.cpp b/backends/vulkan/runtime/graph/ops/impl/SymIntOps.cpp index 0705b53d394..f07522d2578 100644 --- a/backends/vulkan/runtime/graph/ops/impl/SymIntOps.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/SymIntOps.cpp @@ -11,19 +11,27 @@ namespace vkcompute { +// +// sym_size +// + +void sym_size_impl(ComputeGraph* graph, const std::vector& args) { + const ValueRef in_tensor = args.at(0); + const ValueRef dim = args.at(1); + const ValueRef out_symint = args.at(2); + + const int64_t dim_val = graph->extract_scalar(dim); + const int64_t size_at_dim = graph->size_at(dim_val, in_tensor); + + graph->set_symint(out_symint, static_cast(size_at_dim)); +} + void resize_sym_size_node( ComputeGraph* graph, const std::vector& args, - const std::vector& extra_args) { + const std::vector& resize_args) { (void)args; // Unused parameter - - ValueRef out_symint_ref = extra_args[0]; - ValueRef in_tensor_ref = extra_args[1]; - - int64_t dim = graph->extract_scalar(extra_args[2]); - int64_t size_at_dim = graph->size_at(dim, in_tensor_ref); - - graph->set_symint(out_symint_ref, static_cast(size_at_dim)); + sym_size_impl(graph, resize_args); } /* @@ -32,21 +40,50 @@ void resize_sym_size_node( * specified dimension. 
*/ void sym_size_int(ComputeGraph& graph, const std::vector& args) { - ValueRef in_tensor = args[0]; - ValueRef dim = args[1]; - ValueRef out_symint = args[2]; + sym_size_impl(&graph, args); + + graph.execute_nodes().emplace_back( + new ExecuteNode(resize_sym_size_node, args)); +} - int64_t dim_val = graph.extract_scalar(dim); +// +// binary operators +// - int64_t size_at_dim = graph.size_at(dim_val, in_tensor); - graph.set_symint(out_symint, static_cast(size_at_dim)); +void sym_add_impl(ComputeGraph* graph, const std::vector& args) { + const ValueRef a = args.at(0); + const ValueRef b = args.at(1); + const ValueRef out = args.at(2); + + const int32_t a_val = graph->read_symint(a); + const int32_t b_val = graph->read_symint(b); + const int32_t result = a_val + b_val; + + graph->set_symint(out, result); +} + +void resize_sym_add_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& resize_args) { + (void)args; // Unused parameter + sym_add_impl(graph, resize_args); +} + +/* + * This operator takes two symints as inputs and produces a symint as output. + * The output symint's value is the sum of the two input symints. 
+ */ +void sym_add(ComputeGraph& graph, const std::vector& args) { + sym_add_impl(&graph, args); graph.execute_nodes().emplace_back( - new ExecuteNode(resize_sym_size_node, {out_symint, in_tensor, dim})); + new ExecuteNode(resize_sym_add_node, args)); } REGISTER_OPERATORS { VK_REGISTER_OP(sym_size.int, sym_size_int); + VK_REGISTER_OP(add, sym_add); } } // namespace vkcompute diff --git a/backends/vulkan/serialization/vulkan_graph_builder.py b/backends/vulkan/serialization/vulkan_graph_builder.py index 76b9240463b..d21d33b75da 100644 --- a/backends/vulkan/serialization/vulkan_graph_builder.py +++ b/backends/vulkan/serialization/vulkan_graph_builder.py @@ -340,17 +340,21 @@ def process_call_function_node(self, node) -> None: self.seen_ops.add(node.target) - for i, schema_arg in enumerate(node.target._schema.arguments): - if not schema_arg.kwarg_only and i < len(node.args): - function_arg = node.args[i] - elif schema_arg.name in node.kwargs: - function_arg = node.kwargs[schema_arg.name] - else: - function_arg = schema_arg.default_value - - # Create a Value for each function argument. If the argument has been - # previously encountered, then use the existing Value id. - operator_call_args.append(self.get_or_create_value_for(function_arg)) + if hasattr(node.target, "_schema"): + for i, schema_arg in enumerate(node.target._schema.arguments): + if not schema_arg.kwarg_only and i < len(node.args): + function_arg = node.args[i] + elif schema_arg.name in node.kwargs: + function_arg = node.kwargs[schema_arg.name] + else: + function_arg = schema_arg.default_value + + # Create a Value for each function argument. If the argument has been + # previously encountered, then use the existing Value id. 
+ operator_call_args.append(self.get_or_create_value_for(function_arg)) + else: + for _, arg_node in enumerate(node.args): + operator_call_args.append(self.get_or_create_value_for(arg_node)) # Add output node operator_call_args.append(self.create_node_value(node)) From 7bc04079e6d9c89ed417e78f0711592f052fe8ef Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Tue, 3 Jun 2025 12:48:55 -0700 Subject: [PATCH 3/3] [ET-VK][ez] Fix handling of assert ops Pull Request resolved: https://github.com/pytorch/executorch/pull/11258 ## Changes * Apply `RemoveAssertsTransform` as part of `vulkan_preprocess` * Do not call `RemoveAssertsTransform` before lowering the graph * Register ops related to asserts to the operator registry as ephemeral ops ## Motivation assert ops are not implemented in Vulkan, so previously `RemoveAssertsTransform()` was called on the graph before the lowering process. However, it turns out that the assertion ops are required to properly handle dynamic shapes, because they place constraints on the possible range of symbolic integers. If they are not present, then re-tracing the graph during a recompile (which may occur during a graph transform pass) may fail. Therefore, instead of calling the transform before lowering, call it inside vulkan_preprocess after a point where subsequent passes will not attempt to trace the graph. 
ghstack-source-id: 287945687 @exported-using-ghexport Differential Revision: [D75686048](https://our.internmc.facebook.com/intern/diff/D75686048/) --- backends/vulkan/_passes/fuse_quantized_ops.py | 5 ++++- backends/vulkan/op_registry.py | 7 +++++++ backends/vulkan/partitioner/vulkan_partitioner.py | 7 ++++--- backends/vulkan/vulkan_preprocess.py | 2 ++ examples/models/llama/TARGETS | 1 - examples/models/llama/export_llama_lib.py | 4 ---- 6 files changed, 17 insertions(+), 9 deletions(-) diff --git a/backends/vulkan/_passes/fuse_quantized_ops.py b/backends/vulkan/_passes/fuse_quantized_ops.py index d510e1d4342..805a5c1f744 100644 --- a/backends/vulkan/_passes/fuse_quantized_ops.py +++ b/backends/vulkan/_passes/fuse_quantized_ops.py @@ -17,6 +17,7 @@ from executorch.exir import ExportedProgram from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult +from executorch.exir.passes import dead_code_elimination_pass ################# ## linear_qcnw ## @@ -224,6 +225,8 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: ) graph_module.recompile() - graph_module = super().call(graph_module).graph_module + dead_code_elimination_pass(graph_module) + # Re-trace the graph since new nodes were (potentially) inserted + graph_module = super().call(graph_module).graph_module return PassResult(graph_module, True) diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 3c1f1eb40dc..90fea61318c 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -231,6 +231,13 @@ def update_features_impl(op: OpKey): # Symbolic integer ops torch.ops.aten.sym_size.int, operator.add, + operator.lt, + operator.gt, + operator.ge, + operator.le, + # Guard and assert ops + torch.ops.aten._assert_scalar.default, + torch.ops.aten.sym_constrain_range_for_size.default, ] ) def register_ephemeral_op(features: OpFeatures): diff --git 
a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py index d690e886d40..cbf30f84196 100644 --- a/backends/vulkan/partitioner/vulkan_partitioner.py +++ b/backends/vulkan/partitioner/vulkan_partitioner.py @@ -146,10 +146,11 @@ def op_node_is_compatible( # noqa: C901: Function is too complex def node_is_compatible( self, node: torch.fx.Node, features: Optional[OpFeatures] = None ) -> Tuple[bool, str]: - if utils.is_symint_node(node): - return node.target in vulkan_supported_ops, "Op is compatible" - elif utils.is_tensor_node(node): + if utils.is_tensor_node(node): return self.op_node_is_compatible(node, features=features) + # For non-tensor nodes, just check if the op is registered + elif hasattr(node, "target"): + return node.target in vulkan_supported_ops, "Op is compatible" return False, f"Unsupported node type: {node.format_node()}" diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py index 4200df3e131..a22afc3f42e 100644 --- a/backends/vulkan/vulkan_preprocess.py +++ b/backends/vulkan/vulkan_preprocess.py @@ -29,6 +29,7 @@ SqueezeUnsqueezeInputs, TagMemoryMetaPass, ) +from executorch.backends.vulkan._passes.remove_asserts import RemoveAssertsTransform from executorch.backends.vulkan.serialization.vulkan_graph_builder import VkGraphBuilder from executorch.backends.vulkan.serialization.vulkan_graph_schema import ( @@ -172,6 +173,7 @@ def preprocess( # noqa: C901 program = apply_passes( program, [ + RemoveAssertsTransform(), # Since this pass may replace a scalar argument with a tensor argument, # this pass may result in a non ATen compliant graph structure. 
RemoveLocalScalarDenseOpsTransform(), diff --git a/examples/models/llama/TARGETS b/examples/models/llama/TARGETS index f2aa396f7a1..872eccce872 100644 --- a/examples/models/llama/TARGETS +++ b/examples/models/llama/TARGETS @@ -148,7 +148,6 @@ runtime.python_library( ":source_transformation", "//ai_codesign/gen_ai/fast_hadamard_transform:fast_hadamard_transform", "//caffe2:torch", - "//executorch/backends/vulkan/_passes:vulkan_passes", "//executorch/exir/passes:init_mutable_pass", "//executorch/examples/models:model_base", "//executorch/examples/models:models", diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index 3a3102886f8..96faf64475e 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -24,7 +24,6 @@ import pkg_resources import torch -from executorch.backends.vulkan._passes.remove_asserts import remove_asserts from executorch.devtools.backend_debug import print_delegation_info from executorch.devtools.etrecord import generate_etrecord as generate_etrecord_func @@ -880,9 +879,6 @@ def _to_edge_and_lower_llama( # noqa: C901 ) modelname = f"vulkan_{modelname}" - # Need to remove asserts from the graph to prevent graph breaks - remove_asserts(builder_exported_to_edge.edge_manager.exported_program()) - if mps: partitioners.append(get_mps_partitioner(use_kv_cache)) modelname = f"mps_{modelname}"