Skip to content

Commit 8e064cd

Browse files
committed
[ET-VK] Add support for binary symint ops
Pull Request resolved: #11257

## Changes
* Add an implementation for binary operators which add symbolic integers.

## Motivation
Support executing llama models with dynamic shapes. This operator shows up when exporting with dynamic shapes.

ghstack-source-id: 287234982
@exported-using-ghexport

Differential Revision: [D75238029](https://our.internmc.facebook.com/intern/diff/D75238029/)
1 parent 408970c commit 8e064cd

File tree

4 files changed

+74
-27
lines changed

4 files changed

+74
-27
lines changed

backends/vulkan/_passes/tag_memory_meta_pass.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from executorch.exir.dialects._ops import ops as exir_ops
2323

2424
from executorch.exir.pass_base import ExportPass, PassResult
25+
from executorch.exir.tensor import TensorSpec
2526

2627
logger: logging.Logger = logging.getLogger("")
2728
logger.setLevel(logging.INFO)
@@ -51,8 +52,12 @@ def insert_transition_node(
5152
exir_ops.edge.aten.clone.default,
5253
(arg,),
5354
)
55+
print(arg)
56+
print(arg.meta["val"])
57+
print(arg.meta["spec"])
5458
clone_node.meta["val"] = arg.meta["val"]
55-
clone_node.meta["spec"] = deepcopy(arg.meta["spec"])
59+
# clone_node.meta["spec"] = deepcopy(arg.meta["spec"])
60+
clone_node.meta["spec"] = TensorSpec.from_tensor(clone_node.meta["val"])
5661
clone_node.meta["spec"].const = False
5762
set_memory_metadata(clone_node, storage, layout)
5863
arg.replace_all_uses_with(clone_node, lambda x, y=node: x == y)

backends/vulkan/op_registry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,7 @@ def update_features_impl(op: OpKey):
230230
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
231231
# Symbolic integer ops
232232
torch.ops.aten.sym_size.int,
233+
operator.add,
233234
]
234235
)
235236
def register_ephemeral_op(features: OpFeatures):

backends/vulkan/runtime/graph/ops/impl/SymIntOps.cpp

Lines changed: 52 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,19 +11,27 @@
1111

1212
namespace vkcompute {
1313

14+
//
15+
// sym_size
16+
//
17+
18+
void sym_size_impl(ComputeGraph* graph, const std::vector<ValueRef>& args) {
19+
ValueRef in_tensor = args[0];
20+
ValueRef dim = args[1];
21+
ValueRef out_symint = args[2];
22+
23+
int64_t dim_val = graph->extract_scalar<int64_t>(dim);
24+
int64_t size_at_dim = graph->size_at<int64_t>(dim_val, in_tensor);
25+
26+
graph->set_symint(out_symint, static_cast<int32_t>(size_at_dim));
27+
}
28+
1429
void resize_sym_size_node(
1530
ComputeGraph* graph,
1631
const std::vector<ArgGroup>& args,
1732
const std::vector<ValueRef>& extra_args) {
1833
(void)args; // Unused parameter
19-
20-
ValueRef out_symint_ref = extra_args[0];
21-
ValueRef in_tensor_ref = extra_args[1];
22-
23-
int64_t dim = graph->extract_scalar<int64_t>(extra_args[2]);
24-
int64_t size_at_dim = graph->size_at<int64_t>(dim, in_tensor_ref);
25-
26-
graph->set_symint(out_symint_ref, static_cast<int32_t>(size_at_dim));
34+
sym_size_impl(graph, extra_args);
2735
}
2836

2937
/*
@@ -32,21 +40,50 @@ void resize_sym_size_node(
3240
* specified dimension.
3341
*/
3442
void sym_size_int(ComputeGraph& graph, const std::vector<ValueRef>& args) {
35-
ValueRef in_tensor = args[0];
36-
ValueRef dim = args[1];
37-
ValueRef out_symint = args[2];
43+
sym_size_impl(&graph, args);
3844

39-
int64_t dim_val = graph.extract_scalar<int64_t>(dim);
45+
graph.execute_nodes().emplace_back(
46+
new ExecuteNode(resize_sym_size_node, args));
47+
}
48+
49+
//
50+
// binary operators
51+
//
52+
53+
void sym_add_impl(ComputeGraph* graph, const std::vector<ValueRef>& args) {
54+
ValueRef a = args[0];
55+
ValueRef b = args[1];
56+
ValueRef out = args[2];
57+
58+
int32_t a_val = graph->read_symint(a);
59+
int32_t b_val = graph->read_symint(b);
60+
int32_t result = a_val + b_val;
61+
62+
graph->set_symint(out, result);
63+
}
64+
65+
void resize_sym_add_node(
66+
ComputeGraph* graph,
67+
const std::vector<ArgGroup>& args,
68+
const std::vector<ValueRef>& extra_args) {
69+
(void)args; // Unused parameter
70+
sym_add_impl(graph, extra_args);
71+
}
4072

41-
int64_t size_at_dim = graph.size_at<int64_t>(dim_val, in_tensor);
42-
graph.set_symint(out_symint, static_cast<int32_t>(size_at_dim));
73+
/*
74+
* This operator takes two symints as inputs and produces a symint as output.
75+
* The output symint's value is the sum of the two input symints.
76+
*/
77+
void sym_add(ComputeGraph& graph, const std::vector<ValueRef>& args) {
78+
sym_add_impl(&graph, args);
4379

4480
graph.execute_nodes().emplace_back(
45-
new ExecuteNode(resize_sym_size_node, {out_symint, in_tensor, dim}));
81+
new ExecuteNode(resize_sym_add_node, args));
4682
}
4783

4884
REGISTER_OPERATORS {
4985
VK_REGISTER_OP(sym_size.int, sym_size_int);
86+
VK_REGISTER_OP(add, sym_add);
5087
}
5188

5289
} // namespace vkcompute

backends/vulkan/serialization/vulkan_graph_builder.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -340,17 +340,21 @@ def process_call_function_node(self, node) -> None:
340340

341341
self.seen_ops.add(node.target)
342342

343-
for i, schema_arg in enumerate(node.target._schema.arguments):
344-
if not schema_arg.kwarg_only and i < len(node.args):
345-
function_arg = node.args[i]
346-
elif schema_arg.name in node.kwargs:
347-
function_arg = node.kwargs[schema_arg.name]
348-
else:
349-
function_arg = schema_arg.default_value
350-
351-
# Create a Value for each function argument. If the argument has been
352-
# previously encountered, then use the existing Value id.
353-
operator_call_args.append(self.get_or_create_value_for(function_arg))
343+
if hasattr(node.target, "_schema"):
344+
for i, schema_arg in enumerate(node.target._schema.arguments):
345+
if not schema_arg.kwarg_only and i < len(node.args):
346+
function_arg = node.args[i]
347+
elif schema_arg.name in node.kwargs:
348+
function_arg = node.kwargs[schema_arg.name]
349+
else:
350+
function_arg = schema_arg.default_value
351+
352+
# Create a Value for each function argument. If the argument has been
353+
# previously encountered, then use the existing Value id.
354+
operator_call_args.append(self.get_or_create_value_for(function_arg))
355+
else:
356+
for i, arg_node in enumerate(node.args):
357+
operator_call_args.append(self.get_or_create_value_for(arg_node))
354358

355359
# Add output node
356360
operator_call_args.append(self.create_node_value(node))

0 commit comments

Comments
 (0)