diff --git a/backends/vulkan/_passes/tag_memory_meta_pass.py b/backends/vulkan/_passes/tag_memory_meta_pass.py
index 8983ca67752..836a0c6ef7d 100644
--- a/backends/vulkan/_passes/tag_memory_meta_pass.py
+++ b/backends/vulkan/_passes/tag_memory_meta_pass.py
@@ -5,7 +5,6 @@
 # LICENSE file in the root directory of this source tree.
 
 import logging
-from copy import deepcopy
 from typing import Any, Optional, Set
 
 import executorch.backends.vulkan.utils as utils
@@ -22,6 +21,7 @@
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
+from executorch.exir.tensor import TensorSpec
 
 logger: logging.Logger = logging.getLogger("")
 logger.setLevel(logging.INFO)
 
@@ -52,7 +52,7 @@ def insert_transition_node(
         (arg,),
     )
     clone_node.meta["val"] = arg.meta["val"]
-    clone_node.meta["spec"] = deepcopy(arg.meta["spec"])
+    clone_node.meta["spec"] = TensorSpec.from_tensor(clone_node.meta["val"])
     clone_node.meta["spec"].const = False
     set_memory_metadata(clone_node, storage, layout)
     arg.replace_all_uses_with(clone_node, lambda x, y=node: x == y)
diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py
index 2aa940dcc4b..3c1f1eb40dc 100644
--- a/backends/vulkan/op_registry.py
+++ b/backends/vulkan/op_registry.py
@@ -230,6 +230,7 @@ def update_features_impl(op: OpKey):
         exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
         # Symbolic integer ops
         torch.ops.aten.sym_size.int,
+        operator.add,
     ]
 )
 def register_ephemeral_op(features: OpFeatures):
diff --git a/backends/vulkan/runtime/graph/ops/impl/SymIntOps.cpp b/backends/vulkan/runtime/graph/ops/impl/SymIntOps.cpp
index 0705b53d394..f07522d2578 100644
--- a/backends/vulkan/runtime/graph/ops/impl/SymIntOps.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/SymIntOps.cpp
@@ -11,19 +11,27 @@
 
 namespace vkcompute {
 
+//
+// sym_size
+//
+
+void sym_size_impl(ComputeGraph* graph, const std::vector<ValueRef>& args) {
+  const ValueRef in_tensor = args.at(0);
+  const ValueRef dim = args.at(1);
+  const ValueRef out_symint = args.at(2);
+
+  const int64_t dim_val = graph->extract_scalar<int64_t>(dim);
+  const int64_t size_at_dim = graph->size_at<int64_t>(dim_val, in_tensor);
+
+  graph->set_symint(out_symint, static_cast<int32_t>(size_at_dim));
+}
+
 void resize_sym_size_node(
     ComputeGraph* graph,
     const std::vector<ArgGroup>& args,
-    const std::vector<ValueRef>& extra_args) {
+    const std::vector<ValueRef>& resize_args) {
   (void)args; // Unused parameter
-
-  ValueRef out_symint_ref = extra_args[0];
-  ValueRef in_tensor_ref = extra_args[1];
-
-  int64_t dim = graph->extract_scalar<int64_t>(extra_args[2]);
-  int64_t size_at_dim = graph->size_at<int64_t>(dim, in_tensor_ref);
-
-  graph->set_symint(out_symint_ref, static_cast<int32_t>(size_at_dim));
+  sym_size_impl(graph, resize_args);
 }
 
 /*
@@ -32,21 +40,50 @@ void resize_sym_size_node(
  * specified dimension.
  */
 void sym_size_int(ComputeGraph& graph, const std::vector<ValueRef>& args) {
-  ValueRef in_tensor = args[0];
-  ValueRef dim = args[1];
-  ValueRef out_symint = args[2];
+  sym_size_impl(&graph, args);
+
+  graph.execute_nodes().emplace_back(
+      new ExecuteNode(resize_sym_size_node, args));
+}
 
-  int64_t dim_val = graph.extract_scalar<int64_t>(dim);
+//
+// binary operators
+//
 
-  int64_t size_at_dim = graph.size_at<int64_t>(dim_val, in_tensor);
-  graph.set_symint(out_symint, static_cast<int32_t>(size_at_dim));
+void sym_add_impl(ComputeGraph* graph, const std::vector<ValueRef>& args) {
+  const ValueRef a = args.at(0);
+  const ValueRef b = args.at(1);
+  const ValueRef out = args.at(2);
+
+  const int32_t a_val = graph->read_symint(a);
+  const int32_t b_val = graph->read_symint(b);
+  const int32_t result = a_val + b_val;
+
+  graph->set_symint(out, result);
+}
+
+void resize_sym_add_node(
+    ComputeGraph* graph,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& resize_args) {
+  (void)args; // Unused parameter
+  sym_add_impl(graph, resize_args);
+}
+
+/*
+ * This operator takes two symints as inputs and produces a symint as output.
+ * The output symint's value is the sum of the two input symints.
+ */
+void sym_add(ComputeGraph& graph, const std::vector<ValueRef>& args) {
+  sym_add_impl(&graph, args);
 
   graph.execute_nodes().emplace_back(
-      new ExecuteNode(resize_sym_size_node, {out_symint, in_tensor, dim}));
+      new ExecuteNode(resize_sym_add_node, args));
 }
 
 REGISTER_OPERATORS {
   VK_REGISTER_OP(sym_size.int, sym_size_int);
+  VK_REGISTER_OP(add, sym_add);
 }
 
 } // namespace vkcompute
diff --git a/backends/vulkan/serialization/vulkan_graph_builder.py b/backends/vulkan/serialization/vulkan_graph_builder.py
index 76b9240463b..d21d33b75da 100644
--- a/backends/vulkan/serialization/vulkan_graph_builder.py
+++ b/backends/vulkan/serialization/vulkan_graph_builder.py
@@ -340,17 +340,21 @@ def process_call_function_node(self, node) -> None:
 
         self.seen_ops.add(node.target)
 
-        for i, schema_arg in enumerate(node.target._schema.arguments):
-            if not schema_arg.kwarg_only and i < len(node.args):
-                function_arg = node.args[i]
-            elif schema_arg.name in node.kwargs:
-                function_arg = node.kwargs[schema_arg.name]
-            else:
-                function_arg = schema_arg.default_value
-
-            # Create a Value for each function argument. If the argument has been
-            # previously encountered, then use the existing Value id.
-            operator_call_args.append(self.get_or_create_value_for(function_arg))
+        if hasattr(node.target, "_schema"):
+            for i, schema_arg in enumerate(node.target._schema.arguments):
+                if not schema_arg.kwarg_only and i < len(node.args):
+                    function_arg = node.args[i]
+                elif schema_arg.name in node.kwargs:
+                    function_arg = node.kwargs[schema_arg.name]
+                else:
+                    function_arg = schema_arg.default_value
+
+                # Create a Value for each function argument. If the argument has been
+                # previously encountered, then use the existing Value id.
+                operator_call_args.append(self.get_or_create_value_for(function_arg))
+        else:
+            for _, arg_node in enumerate(node.args):
+                operator_call_args.append(self.get_or_create_value_for(arg_node))
 
         # Add output node
         operator_call_args.append(self.create_node_value(node))