1111
1212namespace  vkcompute  {
1313
14+ // 
15+ //  sym_size
16+ // 
17+ 
18+ void  sym_size_impl (
19+     ComputeGraph* graph,
20+     const  std::vector<ValueRef>& args) {
21+   const  ValueRef in_tensor = args.at (0 );
22+   const  ValueRef dim = args.at (1 );
23+   const  ValueRef out_symint = args.at (2 );
24+ 
25+   const  int64_t  dim_val = graph->extract_scalar <int64_t >(dim);
26+   const  int64_t  size_at_dim = graph->size_at <int64_t >(dim_val, in_tensor);
27+ 
28+   graph->set_symint (out_symint, static_cast <int32_t >(size_at_dim));
29+ }
30+ 
1431void  resize_sym_size_node (
1532    ComputeGraph* graph,
1633    const  std::vector<ArgGroup>& args,
1734    const  std::vector<ValueRef>& extra_args) {
1835  (void )args; //  Unused parameter
19- 
20-   ValueRef out_symint_ref = extra_args[0 ];
21-   ValueRef in_tensor_ref = extra_args[1 ];
22- 
23-   int64_t  dim = graph->extract_scalar <int64_t >(extra_args[2 ]);
24-   int64_t  size_at_dim = graph->size_at <int64_t >(dim, in_tensor_ref);
25- 
26-   graph->set_symint (out_symint_ref, static_cast <int32_t >(size_at_dim));
36+   sym_size_impl (graph, extra_args);
2737}
2838
2939/* 
@@ -32,21 +42,52 @@ void resize_sym_size_node(
3242 * specified dimension. 
3343 */  
3444void  sym_size_int (ComputeGraph& graph, const  std::vector<ValueRef>& args) {
35-   ValueRef in_tensor = args[0 ];
36-   ValueRef dim = args[1 ];
37-   ValueRef out_symint = args[2 ];
45+   sym_size_impl (&graph, args);
46+ 
47+   graph.execute_nodes ().emplace_back (
48+       new  ExecuteNode (resize_sym_size_node, args));
49+ }
3850
39-   int64_t  dim_val = graph.extract_scalar <int64_t >(dim);
51+ // 
52+ //  binary operators
53+ // 
4054
41-   int64_t  size_at_dim = graph.size_at <int64_t >(dim_val, in_tensor);
42-   graph.set_symint (out_symint, static_cast <int32_t >(size_at_dim));
55+ void  sym_add_impl (
56+     ComputeGraph* graph,
57+     const  std::vector<ValueRef>& args) {
58+   const  ValueRef a = args.at (0 );
59+   const  ValueRef b = args.at (1 );
60+   const  ValueRef out = args.at (2 );
61+ 
62+   const  int32_t  a_val = graph->read_symint (a);
63+   const  int32_t  b_val = graph->read_symint (b);
64+   const  int32_t  result = a_val + b_val;
65+ 
66+   graph->set_symint (out, result);
67+ }
68+ 
69+ void  resize_sym_add_node (
70+     ComputeGraph* graph,
71+     const  std::vector<ArgGroup>& args,
72+     const  std::vector<ValueRef>& extra_args) {
73+   (void )args; //  Unused parameter
74+   sym_add_impl (graph, extra_args);
75+ }
76+ 
77+ /* 
78+  * This operator takes two symints as inputs and produces a symint as output. 
79+  * The output symint's value is the sum of the two input symints. 
80+  */  
81+ void  sym_add (ComputeGraph& graph, const  std::vector<ValueRef>& args) {
82+   sym_add_impl (&graph, args);
4383
4484  graph.execute_nodes ().emplace_back (
45-       new  ExecuteNode (resize_sym_size_node, {out_symint, in_tensor, dim} ));
85+       new  ExecuteNode (resize_sym_add_node, args ));
4686}
4787
4888REGISTER_OPERATORS {
4989  VK_REGISTER_OP (sym_size.int , sym_size_int);
90+   VK_REGISTER_OP (add, sym_add);
5091}
5192
5293} //  namespace vkcompute
0 commit comments