@@ -25,10 +25,11 @@ using utils::uvec4;
 namespace {

 void check_args(
-    const api::vTensor& in,
-    const std::vector<int64_t>& permute_dims,
-    const api::vTensor& out) {
-  VK_CHECK_COND(check_same_packed_dim(in, out));
+    ComputeGraph& graph,
+    const ValueRef in,
+    const ValueRef permute_dims,
+    const ValueRef out) {
+  VK_CHECK_COND(check_same_packed_dim(graph, in, out));

   // This implementation doesn't not requires the input tensor to have the same
   // dim size as the argument. The code will work as long as the input tensor's
@@ -38,40 +39,93 @@ void check_args(

 } // namespace

+void resize_permute_node(
+    ComputeGraph* graph,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& extra_args) {
+  ValueRef out = args[0].refs[0];
+  ValueRef in = args[1].refs[0];
+
+  std::vector<int64_t> in_sizes = graph->sizes_of(in);
+  std::vector<int64_t> out_sizes = graph->sizes_of(out);
+
+  std::vector<int64_t> permute_dims =
+      graph->extract_int_or_symint_list(extra_args[0]);
+
+  if (in_sizes.size() == out_sizes.size() &&
+      in_sizes.size() == permute_dims.size()) {
+    std::vector<int64_t> new_out_sizes(out_sizes.size(), 1);
+    int64_t out_ndim = std::max(in_sizes.size(), out_sizes.size());
+    for (int i = 0; i < out_ndim; i++) {
+      int64_t permute_dim = permute_dims.at(i);
+      new_out_sizes.at(i) = in_sizes.at(permute_dim);
+    }
+    graph->virtual_resize(out, new_out_sizes);
+  }
+  // Case where permute is being used to implement squeeze
+  else if (
+      in_sizes.size() > out_sizes.size() &&
+      in_sizes.size() == permute_dims.size()) {
+    std::vector<int64_t> new_out_sizes(out_sizes.size(), 1);
+    int offset = in_sizes.size() - out_sizes.size();
+    for (int i = 0; i < out_sizes.size(); i++) {
+      int64_t permute_dim = permute_dims.at(i + offset);
+      new_out_sizes.at(i) = in_sizes.at(permute_dim);
+    }
+    graph->virtual_resize(out, new_out_sizes);
+  }
+  // Case where permute is being used to implement unsqueeze
+  else if (
+      in_sizes.size() < out_sizes.size() &&
+      out_sizes.size() == permute_dims.size()) {
+    std::vector<int64_t> new_out_sizes(out_sizes.size(), 1);
+    int offset = out_sizes.size() - in_sizes.size();
+    for (int i = 0; i < out_sizes.size(); i++) {
+      int64_t permute_dim = permute_dims.at(i) - offset;
+      if (permute_dim >= 0) {
+        new_out_sizes.at(i) = in_sizes.at(permute_dim);
+      }
+    }
+    graph->virtual_resize(out, new_out_sizes);
+  } else {
+    VK_THROW("Invalid permute dims");
+  }
+}
+
 void add_permute_node(
     ComputeGraph& graph,
-    ValueRef in,
-    const std::vector<int64_t>& permute_dims,
-    ValueRef out) {
-  vTensorPtr t_in = graph.get_tensor(in);
-  vTensorPtr t_out = graph.get_tensor(out);
-
-  check_args(*t_in, permute_dims, *t_out);
+    const ValueRef in,
+    const ValueRef permute_dims,
+    const ValueRef out) {
+  check_args(graph, in, permute_dims, out);

   ivec4 out_dims{0, 1, 2, 3};

   // Special cases of squeeze/unsqueeze. Because the input dim size can be
   // different with output dim size. So pick t_in->dim() if squeeze, and
   // t_out->dim() if unsqueeze to create parameter for permute.
-  int64_t out_ndim = std::max(t_in->dim(), t_out->dim());
+  int64_t out_ndim = std::max(graph.dim_of(in), graph.dim_of(out));
   std::vector<bool> seen(out_ndim);
-  for (int i = 0; i < out_ndim; i++) {
-    int64_t permute_dim = permute_dims[i];
-    VK_CHECK_COND(
-        !seen[permute_dim], "Argument dim ", permute_dim, " is repeated");
-    seen[permute_dim] = true;
-
-    out_dims[(4u - out_ndim) + i] = permute_dim + (4 - out_ndim);
+  {
+    IntListPtr permute_dims_ptr = graph.get_int_list(permute_dims);
+    for (int i = 0; i < out_ndim; i++) {
+      int64_t permute_dim = permute_dims_ptr->at(i);
+      VK_CHECK_COND(
+          !seen[permute_dim], "Argument dim ", permute_dim, " is repeated");
+      seen[permute_dim] = true;
+
+      out_dims[(4u - out_ndim) + i] = permute_dim + (4 - out_ndim);
+    }
   }

   std::string kernel_name = "permute";
   kernel_name.reserve(kShaderNameReserve);
-  add_dtype_suffix(kernel_name, *t_out);
+  add_dtype_suffix(kernel_name, graph.dtype_of(out));

-  int32_t out_channels = dim_at<kChannel4D>(t_out->sizes());
-  int32_t in_channels = dim_at<kChannel4D>(t_in->sizes());
+  int32_t out_channels = dim_at<kChannel4D>(graph.sizes_of(out));
+  int32_t in_channels = dim_at<kChannel4D>(graph.sizes_of(in));

-  const auto packed_dim = graph.packed_dim_of(in);
+  const int32_t packed_dim = graph.packed_dim_of(in);
   ivec2 channel_info = {out_channels, in_channels};
   if (packed_dim == WHCN::kChannelsDim) {
     channel_info[0] = utils::align_up_4(channel_info[0]);
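
The subtlest line carried over from the old loop is the out_dims update: logical permute indices are shifted into a fixed 4-slot vector so that tensors with fewer than four dims behave as if they were padded with leading size-1 dims. A minimal standalone sketch of that index math, where pad_permute_to_4d is a made-up name and std::array stands in for utils::ivec4:

#include <array>
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical helper mirroring the shift in add_permute_node: a permute over
// out_ndim logical dims is embedded into a 4-slot vector whose unused leading
// slots keep their identity values {0, 1, 2, 3}.
std::array<int32_t, 4> pad_permute_to_4d(const std::vector<int64_t>& permute_dims) {
  std::array<int32_t, 4> out_dims{0, 1, 2, 3};
  const int64_t out_ndim = static_cast<int64_t>(permute_dims.size());
  for (int64_t i = 0; i < out_ndim; i++) {
    // Both the slot index and the stored dim are offset by (4 - out_ndim),
    // matching out_dims[(4u - out_ndim) + i] = permute_dim + (4 - out_ndim).
    out_dims[(4 - out_ndim) + i] =
        static_cast<int32_t>(permute_dims[i] + (4 - out_ndim));
  }
  return out_dims;
}

int main() {
  // A 3D permute {2, 0, 1} becomes {0, 3, 1, 2}: slot 0 stays identity and
  // every logical dim is shifted up by one to account for the implicit
  // leading size-1 dimension.
  for (int32_t d : pad_permute_to_4d({2, 0, 1})) {
    std::cout << d << ' ';
  }
  std::cout << '\n';
  return 0;
}

When out_ndim is already 4 the offset vanishes and the permute indices pass through unchanged.
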
@@ -95,19 +149,9 @@ void add_permute_node(
       // Specialization Constants
       spec_vars,
       // Resize Args
-      {},
+      {permute_dims},
       // Resizing Logic
-      nullptr));
-}
-
-void add_permute_node(
-    ComputeGraph& graph,
-    ValueRef in,
-    ValueRef permute_dims_ref,
-    ValueRef out) {
-  IntListPtr permute_dims = graph.get_int_list(permute_dims_ref);
-
-  add_permute_node(graph, in, *permute_dims, out);
+      resize_permute_node));
 }

 void permute(ComputeGraph& graph, const std::vector<ValueRef>& args) {
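
The resize callback registered in the last hunk only receives ValueRefs, so the squeeze/unsqueeze handling in resize_permute_node is easiest to see in isolation. Below is a self-contained sketch of just the size computation, under the same assumption the diff encodes (rank mismatches come from implicit leading size-1 dims); permuted_out_sizes is a hypothetical helper, not part of the ComputeGraph API:

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical mirror of the size logic in resize_permute_node: given input
// sizes, permute dims, and the desired output rank, compute the output sizes.
std::vector<int64_t> permuted_out_sizes(
    const std::vector<int64_t>& in_sizes,
    const std::vector<int64_t>& permute_dims,
    size_t out_rank) {
  std::vector<int64_t> out_sizes(out_rank, 1);
  if (in_sizes.size() == out_rank) {
    // Plain permute: out[i] = in[permute[i]].
    for (size_t i = 0; i < out_rank; i++) {
      out_sizes[i] = in_sizes[permute_dims[i]];
    }
  } else if (in_sizes.size() > out_rank) {
    // Squeeze via permute: the leading (in_rank - out_rank) permute entries
    // refer to squeezed size-1 dims, so skip them.
    const size_t offset = in_sizes.size() - out_rank;
    for (size_t i = 0; i < out_rank; i++) {
      out_sizes[i] = in_sizes[permute_dims[i + offset]];
    }
  } else {
    // Unsqueeze via permute: permute dims index the padded output rank, so
    // shift them down; negative results point at implicit size-1 dims.
    const int64_t offset = static_cast<int64_t>(out_rank - in_sizes.size());
    for (size_t i = 0; i < out_rank; i++) {
      const int64_t d = permute_dims[i] - offset;
      if (d >= 0) {
        out_sizes[i] = in_sizes[d];
      }
    }
  }
  return out_sizes;
}

int main() {
  // Unsqueeze at dim 1 expressed as a rank-3 permute over a padded {1, 3, 5}
  // view of a {3, 5} input: permute {1, 0, 2} yields {3, 1, 5}.
  for (int64_t s : permuted_out_sizes({3, 5}, {1, 0, 2}, 3)) {
    std::cout << s << ' ';
  }
  std::cout << '\n';
  return 0;
}

This is also why the resize args slot now carries {permute_dims}: the callback has to re-read the permutation at resize time via extract_int_or_symint_list before it can recompute the output sizes.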