7 changes: 6 additions & 1 deletion backends/vulkan/_passes/tag_memory_meta_pass.py
@@ -22,6 +22,7 @@
from executorch.exir.dialects._ops import ops as exir_ops

from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.tensor import TensorSpec

logger: logging.Logger = logging.getLogger("")
logger.setLevel(logging.INFO)
@@ -51,8 +52,12 @@ def insert_transition_node(
exir_ops.edge.aten.clone.default,
(arg,),
)
clone_node.meta["val"] = arg.meta["val"]
clone_node.meta["spec"] = deepcopy(arg.meta["spec"])
clone_node.meta["spec"] = TensorSpec.from_tensor(clone_node.meta["val"])
clone_node.meta["spec"].const = False
set_memory_metadata(clone_node, storage, layout)
arg.replace_all_uses_with(clone_node, lambda x, y=node: x == y)
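A note on why the spec is rebuilt rather than deep-copied: the inserted clone produces a new runtime tensor, so its `TensorSpec` should be derived from the `FakeTensor` stored in `meta["val"]`, and it must not inherit `const` from the source argument (which may be a constant weight). A minimal sketch of that idea, where `make_clone_spec` is a hypothetical helper rather than part of the pass:

```python
# Minimal sketch, not the actual pass: build a fresh, non-constant
# TensorSpec for an inserted clone node from its FakeTensor metadata.
# `make_clone_spec` is a hypothetical name used for illustration.
from executorch.exir.tensor import TensorSpec


def make_clone_spec(clone_node):
    # meta["val"] carries the FakeTensor that records shape and dtype.
    spec = TensorSpec.from_tensor(clone_node.meta["val"])
    # The clone is produced at runtime, so it is never constant data,
    # even when its source argument is a constant (e.g. a weight).
    spec.const = False
    return spec
```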
1 change: 1 addition & 0 deletions backends/vulkan/op_registry.py
@@ -230,6 +230,7 @@ def update_features_impl(op: OpKey):
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
# Symbolic integer ops
torch.ops.aten.sym_size.int,
operator.add,
]
)
def register_ephemeral_op(features: OpFeatures):
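For context on why a plain Python builtin lands in the registry: under dynamic shapes, `torch.export` traces SymInt arithmetic as `operator.add` calls between `sym_size.int` nodes. A sketch of a graph that produces one (illustrative only; exact node layout depends on the torch version):

```python
# Sketch: SymInt arithmetic in a model traced with dynamic shapes
# shows up as operator.add nodes, which is why the op needs to be
# registered as ephemeral. Exact graph contents may vary by version.
import operator

import torch
from torch.export import Dim, export


class CatAndFlatten(torch.nn.Module):
    def forward(self, x, y):
        n = x.shape[0] + y.shape[0]  # SymInt + SymInt -> operator.add
        return torch.cat([x, y], dim=0).reshape(n, -1)


ep = export(
    CatAndFlatten(),
    (torch.randn(4, 8), torch.randn(3, 8)),
    dynamic_shapes={"x": {0: Dim("n")}, "y": {0: Dim("m")}},
)
print([n for n in ep.graph.nodes if n.target is operator.add])
```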
71 changes: 56 additions & 15 deletions backends/vulkan/runtime/graph/ops/impl/SymIntOps.cpp
@@ -11,19 +11,29 @@

namespace vkcompute {

//
// sym_size
//

void sym_size_impl(
ComputeGraph* graph,
const std::vector<ValueRef>& args) {
const ValueRef in_tensor = args.at(0);
const ValueRef dim = args.at(1);
const ValueRef out_symint = args.at(2);

const int64_t dim_val = graph->extract_scalar<int64_t>(dim);
const int64_t size_at_dim = graph->size_at<int64_t>(dim_val, in_tensor);

graph->set_symint(out_symint, static_cast<int32_t>(size_at_dim));
}

void resize_sym_size_node(
ComputeGraph* graph,
const std::vector<ArgGroup>& args,
const std::vector<ValueRef>& extra_args) {
(void)args; // Unused parameter

ValueRef out_symint_ref = extra_args[0];
ValueRef in_tensor_ref = extra_args[1];

int64_t dim = graph->extract_scalar<int64_t>(extra_args[2]);
int64_t size_at_dim = graph->size_at<int64_t>(dim, in_tensor_ref);

graph->set_symint(out_symint_ref, static_cast<int32_t>(size_at_dim));
sym_size_impl(graph, extra_args);
}

/*
@@ -32,21 +42,52 @@ void resize_sym_size_node(
* specified dimension.
*/
void sym_size_int(ComputeGraph& graph, const std::vector<ValueRef>& args) {
ValueRef in_tensor = args[0];
ValueRef dim = args[1];
ValueRef out_symint = args[2];
sym_size_impl(&graph, args);

graph.execute_nodes().emplace_back(
new ExecuteNode(resize_sym_size_node, args));
}

int64_t dim_val = graph.extract_scalar<int64_t>(dim);
//
// binary operators
//

int64_t size_at_dim = graph.size_at<int64_t>(dim_val, in_tensor);
graph.set_symint(out_symint, static_cast<int32_t>(size_at_dim));
void sym_add_impl(
ComputeGraph* graph,
const std::vector<ValueRef>& args) {
const ValueRef a = args.at(0);
const ValueRef b = args.at(1);
const ValueRef out = args.at(2);

const int32_t a_val = graph->read_symint(a);
const int32_t b_val = graph->read_symint(b);
const int32_t result = a_val + b_val;

graph->set_symint(out, result);
}

void resize_sym_add_node(
ComputeGraph* graph,
const std::vector<ArgGroup>& args,
const std::vector<ValueRef>& extra_args) {
(void)args; // Unused parameter
sym_add_impl(graph, extra_args);
}

/*
* This operator takes two symints as inputs and produces a symint as output.
* The output symint's value is the sum of the two input symints.
*/
void sym_add(ComputeGraph& graph, const std::vector<ValueRef>& args) {
sym_add_impl(&graph, args);

graph.execute_nodes().emplace_back(
new ExecuteNode(resize_sym_size_node, {out_symint, in_tensor, dim}));
new ExecuteNode(resize_sym_add_node, args));
}

REGISTER_OPERATORS {
VK_REGISTER_OP(sym_size.int, sym_size_int);
VK_REGISTER_OP(add, sym_add);
}

} // namespace vkcompute
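The payoff of this refactor is that the build-time path and the resize callback share a single `*_impl` function, so symint values are recomputed consistently whenever input shapes change at runtime. A toy Python model of that pattern (hypothetical names; not the ExecuTorch C++ API):

```python
# Toy model of the shared-impl pattern above; the names here are
# hypothetical stand-ins, not the ExecuTorch API. The same impl runs
# once at graph-build time and again from the resize callback.
from typing import Callable, Dict, List

Graph = Dict[str, int]  # stand-in: symint name -> current value


def sym_add_impl(graph: Graph, args: List[str]) -> None:
    a, b, out = args
    graph[out] = graph[a] + graph[b]  # mirrors read_symint / set_symint


class ExecuteNode:
    """Re-runs its impl whenever inputs are resized."""

    def __init__(self, impl: Callable[[Graph, List[str]], None], args: List[str]):
        self.impl, self.args = impl, args

    def resize(self, graph: Graph) -> None:
        self.impl(graph, self.args)


graph: Graph = {"a": 2, "b": 3, "out": 0}
node = ExecuteNode(sym_add_impl, ["a", "b", "out"])
sym_add_impl(graph, ["a", "b", "out"])  # build-time evaluation: out == 5
graph["a"] = 7  # a shape changed at runtime
node.resize(graph)  # same impl recomputes: out == 10
```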
26 changes: 15 additions & 11 deletions backends/vulkan/serialization/vulkan_graph_builder.py
@@ -340,17 +340,21 @@ def process_call_function_node(self, node) -> None:

self.seen_ops.add(node.target)

for i, schema_arg in enumerate(node.target._schema.arguments):
if not schema_arg.kwarg_only and i < len(node.args):
function_arg = node.args[i]
elif schema_arg.name in node.kwargs:
function_arg = node.kwargs[schema_arg.name]
else:
function_arg = schema_arg.default_value

# Create a Value for each function argument. If the argument has been
# previously encountered, then use the existing Value id.
operator_call_args.append(self.get_or_create_value_for(function_arg))
if hasattr(node.target, "_schema"):
for i, schema_arg in enumerate(node.target._schema.arguments):
if not schema_arg.kwarg_only and i < len(node.args):
function_arg = node.args[i]
elif schema_arg.name in node.kwargs:
function_arg = node.kwargs[schema_arg.name]
else:
function_arg = schema_arg.default_value

# Create a Value for each function argument. If the argument has been
# previously encountered, then use the existing Value id.
operator_call_args.append(self.get_or_create_value_for(function_arg))
else:
# Fallback for targets without a schema (e.g. operator.add):
# serialize the positional args only.
for arg_node in node.args:
operator_call_args.append(self.get_or_create_value_for(arg_node))

# Add output node
operator_call_args.append(self.create_node_value(node))
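The new `hasattr` guard exists because only ATen `OpOverload` targets carry a torchgen schema; Python builtins such as `operator.add` (used for SymInt arithmetic) do not, so there are no kwarg-only defaults to resolve. A quick check of that distinction:

```python
# The distinction the hasattr guard relies on: ATen OpOverload targets
# expose a torchgen schema, while Python builtins do not.
import operator

import torch

print(hasattr(torch.ops.aten.sym_size.int, "_schema"))  # True
print(hasattr(operator.add, "_schema"))  # False
# Without a schema there are no kwarg-only defaults to fill in, so
# the builder serializes node.args positionally.
```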