1111
1212namespace  vkcompute  {
1313
14+ // 
15+ //  sym_size
16+ // 
17+ 
18+ void  sym_size_impl (ComputeGraph* graph, const  std::vector<ValueRef>& args) {
19+   ValueRef in_tensor = args[0 ];
20+   ValueRef dim = args[1 ];
21+   ValueRef out_symint = args[2 ];
22+ 
23+   int64_t  dim_val = graph->extract_scalar <int64_t >(dim);
24+   int64_t  size_at_dim = graph->size_at <int64_t >(dim_val, in_tensor);
25+ 
26+   graph->set_symint (out_symint, static_cast <int32_t >(size_at_dim));
27+ }
28+ 
1429void  resize_sym_size_node (
1530    ComputeGraph* graph,
1631    const  std::vector<ArgGroup>& args,
1732    const  std::vector<ValueRef>& extra_args) {
1833  (void )args; //  Unused parameter
19- 
20-   ValueRef out_symint_ref = extra_args[0 ];
21-   ValueRef in_tensor_ref = extra_args[1 ];
22- 
23-   int64_t  dim = graph->extract_scalar <int64_t >(extra_args[2 ]);
24-   int64_t  size_at_dim = graph->size_at <int64_t >(dim, in_tensor_ref);
25- 
26-   graph->set_symint (out_symint_ref, static_cast <int32_t >(size_at_dim));
34+   sym_size_impl (graph, extra_args);
2735}
2836
2937/* 
@@ -32,21 +40,50 @@ void resize_sym_size_node(
3240 * specified dimension. 
3341 */  
3442void  sym_size_int (ComputeGraph& graph, const  std::vector<ValueRef>& args) {
35-   ValueRef in_tensor = args[0 ];
36-   ValueRef dim = args[1 ];
37-   ValueRef out_symint = args[2 ];
43+   sym_size_impl (&graph, args);
3844
39-   int64_t  dim_val = graph.extract_scalar <int64_t >(dim);
45+   graph.execute_nodes ().emplace_back (
46+       new  ExecuteNode (resize_sym_size_node, args));
47+ }
48+ 
49+ // 
50+ //  binary operators
51+ // 
52+ 
53+ void  sym_add_impl (ComputeGraph* graph, const  std::vector<ValueRef>& args) {
54+   ValueRef a = args[0 ];
55+   ValueRef b = args[1 ];
56+   ValueRef out = args[2 ];
57+ 
58+   int32_t  a_val = graph->read_symint (a);
59+   int32_t  b_val = graph->read_symint (b);
60+   int32_t  result = a_val + b_val;
61+ 
62+   graph->set_symint (out, result);
63+ }
64+ 
65+ void  resize_sym_add_node (
66+     ComputeGraph* graph,
67+     const  std::vector<ArgGroup>& args,
68+     const  std::vector<ValueRef>& extra_args) {
69+   (void )args; //  Unused parameter
70+   sym_add_impl (graph, extra_args);
71+ }
4072
41-   int64_t  size_at_dim = graph.size_at <int64_t >(dim_val, in_tensor);
42-   graph.set_symint (out_symint, static_cast <int32_t >(size_at_dim));
73+ /* 
74+  * This operator takes two symints as inputs and produces a symint as output. 
75+  * The output symint's value is the sum of the two input symints. 
76+  */  
77+ void  sym_add (ComputeGraph& graph, const  std::vector<ValueRef>& args) {
78+   sym_add_impl (&graph, args);
4379
4480  graph.execute_nodes ().emplace_back (
45-       new  ExecuteNode (resize_sym_size_node, {out_symint, in_tensor, dim} ));
81+       new  ExecuteNode (resize_sym_add_node, args ));
4682}
4783
4884REGISTER_OPERATORS {
4985  VK_REGISTER_OP (sym_size.int , sym_size_int);
86+   VK_REGISTER_OP (add, sym_add);
5087}
5188
5289} //  namespace vkcompute
0 commit comments