@@ -25,10 +25,11 @@ using utils::uvec4;
2525namespace  {
2626
2727void  check_args (
28-     const  api::vTensor& in,
29-     const  std::vector<int64_t >& permute_dims,
30-     const  api::vTensor& out) {
31-   VK_CHECK_COND (check_same_packed_dim (in, out));
28+     ComputeGraph& graph,
29+     const  ValueRef in,
30+     const  ValueRef permute_dims,
31+     const  ValueRef out) {
32+   VK_CHECK_COND (check_same_packed_dim (graph, in, out));
3233
3334  //  This implementation does not require the input tensor to have the same
3435  //  dim size as the argument. The code will work as long as the input tensor's
@@ -38,40 +39,93 @@ void check_args(
3839
3940} //  namespace
4041
42+ void  resize_permute_node (
43+     ComputeGraph* graph,
44+     const  std::vector<ArgGroup>& args,
45+     const  std::vector<ValueRef>& resize_args) {
46+   const  ValueRef out = args[0 ].refs [0 ];
47+   const  ValueRef in = args[1 ].refs [0 ];
48+ 
49+   const  std::vector<int64_t > in_sizes = graph->sizes_of (in);
50+   const  std::vector<int64_t > out_sizes = graph->sizes_of (out);
51+ 
52+   const  std::vector<int64_t > permute_dims =
53+       graph->extract_int_or_symint_list (resize_args[0 ]);
54+ 
55+   if  (in_sizes.size () == out_sizes.size () &&
56+       in_sizes.size () == permute_dims.size ()) {
57+     std::vector<int64_t > new_out_sizes (out_sizes.size (), 1 );
58+     const  int64_t  out_ndim = std::max (in_sizes.size (), out_sizes.size ());
59+     for  (int  i = 0 ; i < out_ndim; i++) {
60+       const  int64_t  permute_dim = permute_dims.at (i);
61+       new_out_sizes.at (i) = in_sizes.at (permute_dim);
62+     }
63+     graph->virtual_resize (out, new_out_sizes);
64+   }
65+   //  Case where permute is being used to implement squeeze
66+   else  if  (
67+       in_sizes.size () > out_sizes.size () &&
68+       in_sizes.size () == permute_dims.size ()) {
69+     std::vector<int64_t > new_out_sizes (out_sizes.size (), 1 );
70+     const  int  offset = in_sizes.size () - out_sizes.size ();
71+     for  (int  i = 0 ; i < out_sizes.size (); i++) {
72+       const  int64_t  permute_dim = permute_dims.at (i + offset);
73+       new_out_sizes.at (i) = in_sizes.at (permute_dim);
74+     }
75+     graph->virtual_resize (out, new_out_sizes);
76+   }
77+   //  Case where Permute is being used to implement unsqueeze
78+   else  if  (
79+       in_sizes.size () < out_sizes.size () &&
80+       out_sizes.size () == permute_dims.size ()) {
81+     std::vector<int64_t > new_out_sizes (out_sizes.size (), 1 );
82+     const  int  offset = out_sizes.size () - in_sizes.size ();
83+     for  (int  i = 0 ; i < out_sizes.size (); i++) {
84+       int64_t  permute_dim = permute_dims.at (i) - offset;
85+       if  (permute_dim >= 0 ) {
86+         new_out_sizes.at (i) = in_sizes.at (permute_dim);
87+       }
88+     }
89+     graph->virtual_resize (out, new_out_sizes);
90+   } else  {
91+     VK_THROW (" Invalid permute dims" 
92+   }
93+ }
94+ 
4195void  add_permute_node (
4296    ComputeGraph& graph,
43-     ValueRef in,
44-     const  std::vector<int64_t >& permute_dims,
45-     ValueRef out) {
46-   vTensorPtr t_in = graph.get_tensor (in);
47-   vTensorPtr t_out = graph.get_tensor (out);
48- 
49-   check_args (*t_in, permute_dims, *t_out);
97+     const  ValueRef in,
98+     const  ValueRef permute_dims,
99+     const  ValueRef out) {
100+   check_args (graph, in, permute_dims, out);
50101
51102  ivec4 out_dims{0 , 1 , 2 , 3 };
52103
53104  //  Special cases of squeeze/unsqueeze. Because the input dim size can be
54-   //  different with output dim size. So pick t_in->dim( ) if squeeze, and
55-   //  t_out->dim( ) if unsqueeze to create parameter for permute.
56-   int64_t  out_ndim = std::max (t_in-> dim ( ), t_out-> dim ( ));
105+   //  different with output dim size. So pick graph.dim_of(in ) if squeeze, and
106+   //  graph.dim_of(out ) if unsqueeze to create parameter for permute.
107+   const   int64_t  out_ndim = std::max (graph. dim_of (in ), graph. dim_of (out ));
57108  std::vector<bool > seen (out_ndim);
58-   for  (int  i = 0 ; i < out_ndim; i++) {
59-     int64_t  permute_dim = permute_dims[i];
60-     VK_CHECK_COND (
61-         !seen[permute_dim], " Argument dim " "   is repeated" 
62-     seen[permute_dim] = true ;
63- 
64-     out_dims[(4u  - out_ndim) + i] = permute_dim + (4  - out_ndim);
109+   {
110+     IntListPtr permute_dims_ptr = graph.get_int_list (permute_dims);
111+     for  (int  i = 0 ; i < out_ndim; i++) {
112+       int64_t  permute_dim = permute_dims_ptr->at (i);
113+       VK_CHECK_COND (
114+           !seen[permute_dim], " Argument dim " "   is repeated" 
115+       seen[permute_dim] = true ;
116+ 
117+       out_dims[(4u  - out_ndim) + i] = permute_dim + (4  - out_ndim);
118+     }
65119  }
66120
67121  std::string kernel_name = " permute" 
68122  kernel_name.reserve (kShaderNameReserve );
69-   add_dtype_suffix (kernel_name, *t_out );
123+   add_dtype_suffix (kernel_name, graph. dtype_of (out) );
70124
71-   int32_t  out_channels = dim_at<kChannel4D >(t_out-> sizes ( ));
72-   int32_t  in_channels = dim_at<kChannel4D >(t_in-> sizes ( ));
125+   const   int32_t  out_channels = dim_at<kChannel4D >(graph. sizes_of (out ));
126+   const   int32_t  in_channels = dim_at<kChannel4D >(graph. sizes_of (in ));
73127
74-   const  auto  packed_dim = graph.packed_dim_of (in);
128+   const  int32_t  packed_dim = graph.packed_dim_of (in);
75129  ivec2 channel_info = {out_channels, in_channels};
76130  if  (packed_dim == WHCN::kChannelsDim ) {
77131    channel_info[0 ] = utils::align_up_4 (channel_info[0 ]);
@@ -95,19 +149,9 @@ void add_permute_node(
95149      //  Specialization Constants
96150      spec_vars,
97151      //  Resize Args
98-       {},
152+       {permute_dims },
99153      //  Resizing Logic
100-       nullptr ));
101- }
102- 
103- void  add_permute_node (
104-     ComputeGraph& graph,
105-     ValueRef in,
106-     ValueRef permute_dims_ref,
107-     ValueRef out) {
108-   IntListPtr permute_dims = graph.get_int_list (permute_dims_ref);
109- 
110-   add_permute_node (graph, in, *permute_dims, out);
154+       resize_permute_node));
111155}
112156
113157void  permute (ComputeGraph& graph, const  std::vector<ValueRef>& args) {
0 commit comments