1111
1212namespace vkcompute {
1313
14+ //
15+ // sym_size
16+ //
17+
18+ void sym_size_impl (ComputeGraph* graph, const std::vector<ValueRef>& args) {
19+ const ValueRef in_tensor = args.at (0 );
20+ const ValueRef dim = args.at (1 );
21+ const ValueRef out_symint = args.at (2 );
22+
23+ const int64_t dim_val = graph->extract_scalar <int64_t >(dim);
24+ const int64_t size_at_dim = graph->size_at <int64_t >(dim_val, in_tensor);
25+
26+ graph->set_symint (out_symint, static_cast <int32_t >(size_at_dim));
27+ }
28+
1429void resize_sym_size_node (
1530 ComputeGraph* graph,
1631 const std::vector<ArgGroup>& args,
17- const std::vector<ValueRef>& extra_args ) {
32+ const std::vector<ValueRef>& resize_args ) {
1833 (void )args; // Unused parameter
19-
20- ValueRef out_symint_ref = extra_args[0 ];
21- ValueRef in_tensor_ref = extra_args[1 ];
22-
23- int64_t dim = graph->extract_scalar <int64_t >(extra_args[2 ]);
24- int64_t size_at_dim = graph->size_at <int64_t >(dim, in_tensor_ref);
25-
26- graph->set_symint (out_symint_ref, static_cast <int32_t >(size_at_dim));
34+ sym_size_impl (graph, resize_args);
2735}
2836
2937/*
@@ -32,21 +40,50 @@ void resize_sym_size_node(
3240 * specified dimension.
3341 */
3442void sym_size_int (ComputeGraph& graph, const std::vector<ValueRef>& args) {
35- ValueRef in_tensor = args[0 ];
36- ValueRef dim = args[1 ];
37- ValueRef out_symint = args[2 ];
43+ sym_size_impl (&graph, args);
44+
45+ graph.execute_nodes ().emplace_back (
46+ new ExecuteNode (resize_sym_size_node, args));
47+ }
3848
39- int64_t dim_val = graph.extract_scalar <int64_t >(dim);
49+ //
50+ // binary operators
51+ //
4052
41- int64_t size_at_dim = graph.size_at <int64_t >(dim_val, in_tensor);
42- graph.set_symint (out_symint, static_cast <int32_t >(size_at_dim));
53+ void sym_add_impl (ComputeGraph* graph, const std::vector<ValueRef>& args) {
54+ const ValueRef a = args.at (0 );
55+ const ValueRef b = args.at (1 );
56+ const ValueRef out = args.at (2 );
57+
58+ const int32_t a_val = graph->read_symint (a);
59+ const int32_t b_val = graph->read_symint (b);
60+ const int32_t result = a_val + b_val;
61+
62+ graph->set_symint (out, result);
63+ }
64+
65+ void resize_sym_add_node (
66+ ComputeGraph* graph,
67+ const std::vector<ArgGroup>& args,
68+ const std::vector<ValueRef>& resize_args) {
69+ (void )args; // Unused parameter
70+ sym_add_impl (graph, resize_args);
71+ }
72+
73+ /*
74+ * This operator takes two symints as inputs and produces a symint as output.
75+ * The output symint's value is the sum of the two input symints.
76+ */
77+ void sym_add (ComputeGraph& graph, const std::vector<ValueRef>& args) {
78+ sym_add_impl (&graph, args);
4379
4480 graph.execute_nodes ().emplace_back (
45- new ExecuteNode (resize_sym_size_node, {out_symint, in_tensor, dim} ));
81+ new ExecuteNode (resize_sym_add_node, args ));
4682}
4783
4884REGISTER_OPERATORS {
4985 VK_REGISTER_OP (sym_size.int , sym_size_int);
86+ VK_REGISTER_OP (add, sym_add);
5087}
5188
5289} // namespace vkcompute
0 commit comments