From 05d3956d26ef22dc7c3bc53e60cd9dccf10321d3 Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Fri, 30 May 2025 08:26:13 -0700 Subject: [PATCH] [ET-VK] Add support for binary symint ops ## Changes * Add an implementation for binary operators which add symbolic integers. ## Motivation Support executing llama models with dynamic shapes. This operator shows up when exporting with dynamic shapes. Differential Revision: [D75238029](https://our.internmc.facebook.com/intern/diff/D75238029/) [ghstack-poisoned] --- .../vulkan/_passes/tag_memory_meta_pass.py | 7 +- backends/vulkan/op_registry.py | 1 + .../runtime/graph/ops/impl/SymIntOps.cpp | 67 ++++++++++++++----- .../serialization/vulkan_graph_builder.py | 29 +++++--- 4 files changed, 77 insertions(+), 27 deletions(-) diff --git a/backends/vulkan/_passes/tag_memory_meta_pass.py b/backends/vulkan/_passes/tag_memory_meta_pass.py index 8983ca67752..667de2ae45f 100644 --- a/backends/vulkan/_passes/tag_memory_meta_pass.py +++ b/backends/vulkan/_passes/tag_memory_meta_pass.py @@ -22,6 +22,7 @@ from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult +from executorch.exir.tensor import TensorSpec logger: logging.Logger = logging.getLogger("") logger.setLevel(logging.INFO) @@ -51,8 +52,12 @@ def insert_transition_node( exir_ops.edge.aten.clone.default, (arg,), ) + print(arg) + print(arg.meta["val"]) + print(arg.meta["spec"]) clone_node.meta["val"] = arg.meta["val"] - clone_node.meta["spec"] = deepcopy(arg.meta["spec"]) + # clone_node.meta["spec"] = deepcopy(arg.meta["spec"]) + clone_node.meta["spec"] = TensorSpec.from_tensor(clone_node.meta["val"]) clone_node.meta["spec"].const = False set_memory_metadata(clone_node, storage, layout) arg.replace_all_uses_with(clone_node, lambda x, y=node: x == y) diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 2aa940dcc4b..3c1f1eb40dc 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -230,6 +230,7 @@ def update_features_impl(op: OpKey): exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, # Symbolic integer ops torch.ops.aten.sym_size.int, + operator.add, ] ) def register_ephemeral_op(features: OpFeatures): diff --git a/backends/vulkan/runtime/graph/ops/impl/SymIntOps.cpp b/backends/vulkan/runtime/graph/ops/impl/SymIntOps.cpp index 0705b53d394..8a1869a25f1 100644 --- a/backends/vulkan/runtime/graph/ops/impl/SymIntOps.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/SymIntOps.cpp @@ -11,19 +11,27 @@ namespace vkcompute { +// +// sym_size +// + +void sym_size_impl(ComputeGraph* graph, const std::vector& args) { + ValueRef in_tensor = args[0]; + ValueRef dim = args[1]; + ValueRef out_symint = args[2]; + + int64_t dim_val = graph->extract_scalar(dim); + int64_t size_at_dim = graph->size_at(dim_val, in_tensor); + + graph->set_symint(out_symint, static_cast(size_at_dim)); +} + void resize_sym_size_node( ComputeGraph* graph, const std::vector& args, const std::vector& extra_args) { (void)args; // Unused parameter - - ValueRef out_symint_ref = extra_args[0]; - ValueRef in_tensor_ref = extra_args[1]; - - int64_t dim = graph->extract_scalar(extra_args[2]); - int64_t size_at_dim = graph->size_at(dim, in_tensor_ref); - - graph->set_symint(out_symint_ref, static_cast(size_at_dim)); + sym_size_impl(graph, extra_args); } /* @@ -32,21 +40,50 @@ void resize_sym_size_node( * specified dimension. */ void sym_size_int(ComputeGraph& graph, const std::vector& args) { - ValueRef in_tensor = args[0]; - ValueRef dim = args[1]; - ValueRef out_symint = args[2]; + sym_size_impl(&graph, args); - int64_t dim_val = graph.extract_scalar(dim); + graph.execute_nodes().emplace_back( + new ExecuteNode(resize_sym_size_node, args)); +} + +// +// binary operators +// + +void sym_add_impl(ComputeGraph* graph, const std::vector& args) { + ValueRef a = args[0]; + ValueRef b = args[1]; + ValueRef out = args[2]; + + int32_t a_val = graph->read_symint(a); + int32_t b_val = graph->read_symint(b); + int32_t result = a_val + b_val; + + graph->set_symint(out, result); +} + +void resize_sym_add_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& extra_args) { + (void)args; // Unused parameter + sym_add_impl(graph, extra_args); +} - int64_t size_at_dim = graph.size_at(dim_val, in_tensor); - graph.set_symint(out_symint, static_cast(size_at_dim)); +/* + * This operator takes two symints as inputs and produces a symint as output. + * The output symint's value is the sum of the two input symints. + */ +void sym_add(ComputeGraph& graph, const std::vector& args) { + sym_add_impl(&graph, args); graph.execute_nodes().emplace_back( - new ExecuteNode(resize_sym_size_node, {out_symint, in_tensor, dim})); + new ExecuteNode(resize_sym_add_node, args)); } REGISTER_OPERATORS { VK_REGISTER_OP(sym_size.int, sym_size_int); + VK_REGISTER_OP(add, sym_add); } } // namespace vkcompute diff --git a/backends/vulkan/serialization/vulkan_graph_builder.py b/backends/vulkan/serialization/vulkan_graph_builder.py index 76b9240463b..62a15a22ede 100644 --- a/backends/vulkan/serialization/vulkan_graph_builder.py +++ b/backends/vulkan/serialization/vulkan_graph_builder.py @@ -340,17 +340,21 @@ def process_call_function_node(self, node) -> None: self.seen_ops.add(node.target) - for i, schema_arg in enumerate(node.target._schema.arguments): - if not schema_arg.kwarg_only and i < len(node.args): - function_arg = node.args[i] - elif schema_arg.name in node.kwargs: - function_arg = node.kwargs[schema_arg.name] - else: - function_arg = schema_arg.default_value - - # Create a Value for each function argument. If the argument has been - # previously encountered, then use the existing Value id. - operator_call_args.append(self.get_or_create_value_for(function_arg)) + if hasattr(node.target, "_schema"): + for i, schema_arg in enumerate(node.target._schema.arguments): + if not schema_arg.kwarg_only and i < len(node.args): + function_arg = node.args[i] + elif schema_arg.name in node.kwargs: + function_arg = node.kwargs[schema_arg.name] + else: + function_arg = schema_arg.default_value + + # Create a Value for each function argument. If the argument has been + # previously encountered, then use the existing Value id. + operator_call_args.append(self.get_or_create_value_for(function_arg)) + else: + for i, arg_node in enumerate(node.args): + operator_call_args.append(self.get_or_create_value_for(arg_node)) # Add output node operator_call_args.append(self.create_node_value(node)) @@ -397,6 +401,9 @@ def process_node(self, node: Node, call_node_debug_hdl: int) -> None: raise AssertionError(f"Unsupported node op: {node.op}") def build_graph(self) -> vk_graph_schema.VkGraph: + print("Building graph...") + print(self.program.graph_module.graph) + call_node_debug_hdl = 0 for node in self.program.graph_module.graph.nodes: self.process_node(node, call_node_debug_hdl)