Skip to content

Commit 344bc14

Browse files
committed
[ET-VK] Add support for operator.add for symint ops
Pull Request resolved: #11257 ## Changes * Add an implementation for operator.add which adds symbolic integers. ## Motivation Support executing llama models with dynamic shapes. This operator shows up when exporting with dynamic shapes. ghstack-source-id: 287878312 @exported-using-ghexport Differential Revision: [D75238029](https://our.internmc.facebook.com/intern/diff/D75238029/)
1 parent 9cfbe8d commit 344bc14

File tree

4 files changed

+78
-27
lines changed

4 files changed

+78
-27
lines changed

backends/vulkan/_passes/tag_memory_meta_pass.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from executorch.exir.dialects._ops import ops as exir_ops
2323

2424
from executorch.exir.pass_base import ExportPass, PassResult
25+
from executorch.exir.tensor import TensorSpec
2526

2627
logger: logging.Logger = logging.getLogger("")
2728
logger.setLevel(logging.INFO)
@@ -51,8 +52,12 @@ def insert_transition_node(
5152
exir_ops.edge.aten.clone.default,
5253
(arg,),
5354
)
55+
print(arg)
56+
print(arg.meta["val"])
57+
print(arg.meta["spec"])
5458
clone_node.meta["val"] = arg.meta["val"]
55-
clone_node.meta["spec"] = deepcopy(arg.meta["spec"])
59+
# clone_node.meta["spec"] = deepcopy(arg.meta["spec"])
60+
clone_node.meta["spec"] = TensorSpec.from_tensor(clone_node.meta["val"])
5661
clone_node.meta["spec"].const = False
5762
set_memory_metadata(clone_node, storage, layout)
5863
arg.replace_all_uses_with(clone_node, lambda x, y=node: x == y)

backends/vulkan/op_registry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,7 @@ def update_features_impl(op: OpKey):
230230
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
231231
# Symbolic integer ops
232232
torch.ops.aten.sym_size.int,
233+
operator.add,
233234
]
234235
)
235236
def register_ephemeral_op(features: OpFeatures):

backends/vulkan/runtime/graph/ops/impl/SymIntOps.cpp

Lines changed: 56 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,19 +11,29 @@
1111

1212
namespace vkcompute {
1313

14+
//
15+
// sym_size
16+
//
17+
18+
// Reads the size of the input tensor (args[0]) along the dimension given by
// the scalar dim value (args[1]) and stores it into the output symint
// (args[2]).
void sym_size_impl(
    ComputeGraph* graph,
    const std::vector<ValueRef>& args) {
  const ValueRef tensor_ref = args.at(0);
  const ValueRef dim_ref = args.at(1);
  const ValueRef symint_ref = args.at(2);

  const int64_t dim_idx = graph->extract_scalar<int64_t>(dim_ref);
  // Symbolic integers are stored as 32-bit values in the graph.
  graph->set_symint(
      symint_ref,
      static_cast<int32_t>(graph->size_at<int64_t>(dim_idx, tensor_ref)));
}
30+
1431
void resize_sym_size_node(
1532
ComputeGraph* graph,
1633
const std::vector<ArgGroup>& args,
1734
const std::vector<ValueRef>& extra_args) {
1835
(void)args; // Unused parameter
19-
20-
ValueRef out_symint_ref = extra_args[0];
21-
ValueRef in_tensor_ref = extra_args[1];
22-
23-
int64_t dim = graph->extract_scalar<int64_t>(extra_args[2]);
24-
int64_t size_at_dim = graph->size_at<int64_t>(dim, in_tensor_ref);
25-
26-
graph->set_symint(out_symint_ref, static_cast<int32_t>(size_at_dim));
36+
sym_size_impl(graph, extra_args);
2737
}
2838

2939
/*
@@ -32,21 +42,52 @@ void resize_sym_size_node(
3242
* specified dimension.
3343
*/
3444
void sym_size_int(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  // Evaluate once at graph-build time, then register a resize node so the
  // symint is refreshed whenever tensor shapes change at execution time.
  sym_size_impl(&graph, args);

  graph.execute_nodes().emplace_back(
      new ExecuteNode(resize_sym_size_node, args));
}
3850

39-
int64_t dim_val = graph.extract_scalar<int64_t>(dim);
51+
//
52+
// binary operators
53+
//
4054

41-
int64_t size_at_dim = graph.size_at<int64_t>(dim_val, in_tensor);
42-
graph.set_symint(out_symint, static_cast<int32_t>(size_at_dim));
55+
// Sums the two input symints (args[0], args[1]) and writes the result into
// the output symint (args[2]).
void sym_add_impl(
    ComputeGraph* graph,
    const std::vector<ValueRef>& args) {
  const ValueRef lhs = args.at(0);
  const ValueRef rhs = args.at(1);
  const ValueRef sum_out = args.at(2);

  const int32_t sum = graph->read_symint(lhs) + graph->read_symint(rhs);
  graph->set_symint(sum_out, sum);
}
68+
69+
// Resize callback for sym_add: re-runs the addition when the input symints
// may have changed. The ArgGroup list is unused; inputs arrive via
// extra_args.
void resize_sym_add_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& extra_args) {
  (void)args;
  sym_add_impl(graph, extra_args);
}
76+
77+
/*
78+
* This operator takes two symints as inputs and produces a symint as output.
79+
* The output symint's value is the sum of the two input symints.
80+
*/
81+
void sym_add(ComputeGraph& graph, const std::vector<ValueRef>& args) {
82+
sym_add_impl(&graph, args);
4383

4484
graph.execute_nodes().emplace_back(
45-
new ExecuteNode(resize_sym_size_node, {out_symint, in_tensor, dim}));
85+
new ExecuteNode(resize_sym_add_node, args));
4686
}
4787

4888
REGISTER_OPERATORS {
  // sym_size.int: size of a tensor dimension, exposed as a symint.
  VK_REGISTER_OP(sym_size.int, sym_size_int);
  // operator.add on two symints (shows up when exporting with dynamic
  // shapes).
  VK_REGISTER_OP(add, sym_add);
}
5192

5293
} // namespace vkcompute

backends/vulkan/serialization/vulkan_graph_builder.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -340,17 +340,21 @@ def process_call_function_node(self, node) -> None:
340340

341341
self.seen_ops.add(node.target)
342342

343-
for i, schema_arg in enumerate(node.target._schema.arguments):
344-
if not schema_arg.kwarg_only and i < len(node.args):
345-
function_arg = node.args[i]
346-
elif schema_arg.name in node.kwargs:
347-
function_arg = node.kwargs[schema_arg.name]
348-
else:
349-
function_arg = schema_arg.default_value
350-
351-
# Create a Value for each function argument. If the argument has been
352-
# previously encountered, then use the existing Value id.
353-
operator_call_args.append(self.get_or_create_value_for(function_arg))
343+
if hasattr(node.target, "_schema"):
344+
for i, schema_arg in enumerate(node.target._schema.arguments):
345+
if not schema_arg.kwarg_only and i < len(node.args):
346+
function_arg = node.args[i]
347+
elif schema_arg.name in node.kwargs:
348+
function_arg = node.kwargs[schema_arg.name]
349+
else:
350+
function_arg = schema_arg.default_value
351+
352+
# Create a Value for each function argument. If the argument has been
353+
# previously encountered, then use the existing Value id.
354+
operator_call_args.append(self.get_or_create_value_for(function_arg))
355+
else:
356+
for i, arg_node in enumerate(node.args):
357+
operator_call_args.append(self.get_or_create_value_for(arg_node))
354358

355359
# Add output node
356360
operator_call_args.append(self.create_node_value(node))

0 commit comments

Comments
 (0)