Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions backends/vulkan/_passes/tag_memory_meta_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
# LICENSE file in the root directory of this source tree.

import logging
from copy import deepcopy
from typing import Any, Optional, Set

import executorch.backends.vulkan.utils as utils
Expand All @@ -22,6 +21,7 @@
from executorch.exir.dialects._ops import ops as exir_ops

from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.tensor import TensorSpec

logger: logging.Logger = logging.getLogger("")
logger.setLevel(logging.INFO)
Expand Down Expand Up @@ -52,7 +52,7 @@ def insert_transition_node(
(arg,),
)
clone_node.meta["val"] = arg.meta["val"]
clone_node.meta["spec"] = deepcopy(arg.meta["spec"])
clone_node.meta["spec"] = TensorSpec.from_tensor(clone_node.meta["val"])
clone_node.meta["spec"].const = False
set_memory_metadata(clone_node, storage, layout)
arg.replace_all_uses_with(clone_node, lambda x, y=node: x == y)
Expand Down
1 change: 1 addition & 0 deletions backends/vulkan/op_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ def update_features_impl(op: OpKey):
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
# Symbolic integer ops
torch.ops.aten.sym_size.int,
operator.add,
]
)
def register_ephemeral_op(features: OpFeatures):
Expand Down
69 changes: 53 additions & 16 deletions backends/vulkan/runtime/graph/ops/impl/SymIntOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,27 @@

namespace vkcompute {

//
// sym_size
//

void sym_size_impl(ComputeGraph* graph, const std::vector<ValueRef>& args) {
const ValueRef in_tensor = args.at(0);
const ValueRef dim = args.at(1);
const ValueRef out_symint = args.at(2);

const int64_t dim_val = graph->extract_scalar<int64_t>(dim);
const int64_t size_at_dim = graph->size_at<int64_t>(dim_val, in_tensor);

graph->set_symint(out_symint, static_cast<int32_t>(size_at_dim));
}

void resize_sym_size_node(
ComputeGraph* graph,
const std::vector<ArgGroup>& args,
const std::vector<ValueRef>& extra_args) {
const std::vector<ValueRef>& resize_args) {
(void)args; // Unused parameter

ValueRef out_symint_ref = extra_args[0];
ValueRef in_tensor_ref = extra_args[1];

int64_t dim = graph->extract_scalar<int64_t>(extra_args[2]);
int64_t size_at_dim = graph->size_at<int64_t>(dim, in_tensor_ref);

graph->set_symint(out_symint_ref, static_cast<int32_t>(size_at_dim));
sym_size_impl(graph, resize_args);
}

/*
Expand All @@ -32,21 +40,50 @@ void resize_sym_size_node(
* specified dimension.
*/
void sym_size_int(ComputeGraph& graph, const std::vector<ValueRef>& args) {
ValueRef in_tensor = args[0];
ValueRef dim = args[1];
ValueRef out_symint = args[2];
sym_size_impl(&graph, args);

graph.execute_nodes().emplace_back(
new ExecuteNode(resize_sym_size_node, args));
}

int64_t dim_val = graph.extract_scalar<int64_t>(dim);
//
// binary operators
//

int64_t size_at_dim = graph.size_at<int64_t>(dim_val, in_tensor);
graph.set_symint(out_symint, static_cast<int32_t>(size_at_dim));
void sym_add_impl(ComputeGraph* graph, const std::vector<ValueRef>& args) {
const ValueRef a = args.at(0);
const ValueRef b = args.at(1);
const ValueRef out = args.at(2);

const int32_t a_val = graph->read_symint(a);
const int32_t b_val = graph->read_symint(b);
const int32_t result = a_val + b_val;

graph->set_symint(out, result);
}

void resize_sym_add_node(
ComputeGraph* graph,
const std::vector<ArgGroup>& args,
const std::vector<ValueRef>& resize_args) {
(void)args; // Unused parameter
sym_add_impl(graph, resize_args);
}

/*
* This operator takes two symints as inputs and produces a symint as output.
* The output symint's value is the sum of the two input symints.
*/
void sym_add(ComputeGraph& graph, const std::vector<ValueRef>& args) {
sym_add_impl(&graph, args);

graph.execute_nodes().emplace_back(
new ExecuteNode(resize_sym_size_node, {out_symint, in_tensor, dim}));
new ExecuteNode(resize_sym_add_node, args));
}

REGISTER_OPERATORS {
VK_REGISTER_OP(sym_size.int, sym_size_int);
VK_REGISTER_OP(add, sym_add);
}

} // namespace vkcompute
26 changes: 15 additions & 11 deletions backends/vulkan/serialization/vulkan_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,17 +340,21 @@ def process_call_function_node(self, node) -> None:

self.seen_ops.add(node.target)

for i, schema_arg in enumerate(node.target._schema.arguments):
if not schema_arg.kwarg_only and i < len(node.args):
function_arg = node.args[i]
elif schema_arg.name in node.kwargs:
function_arg = node.kwargs[schema_arg.name]
else:
function_arg = schema_arg.default_value

# Create a Value for each function argument. If the argument has been
# previously encountered, then use the existing Value id.
operator_call_args.append(self.get_or_create_value_for(function_arg))
if hasattr(node.target, "_schema"):
for i, schema_arg in enumerate(node.target._schema.arguments):
if not schema_arg.kwarg_only and i < len(node.args):
function_arg = node.args[i]
elif schema_arg.name in node.kwargs:
function_arg = node.kwargs[schema_arg.name]
else:
function_arg = schema_arg.default_value

# Create a Value for each function argument. If the argument has been
# previously encountered, then use the existing Value id.
operator_call_args.append(self.get_or_create_value_for(function_arg))
else:
for _, arg_node in enumerate(node.args):
operator_call_args.append(self.get_or_create_value_for(arg_node))

# Add output node
operator_call_args.append(self.create_node_value(node))
Expand Down
Loading