@@ -25,10 +25,12 @@ using utils::uvec4;
2525namespace {
2626
2727void check_args (
28- const api::vTensor& in,
29- const std::vector<int64_t >& permute_dims,
30- const api::vTensor& out) {
31- VK_CHECK_COND (check_same_packed_dim (in, out));
28+ ComputeGraph& graph,
29+ const ValueRef in,
30+ const ValueRef permute_dims,
31+ const ValueRef out) {
32+ (void )permute_dims;
33+ VK_CHECK_COND (check_same_packed_dim (graph, in, out));
3234
3335 // This implementation doesn't require the input tensor to have the same
3436 // dim size as the argument. The code will work as long as the input tensor's
@@ -38,40 +40,94 @@ void check_args(
3840
3941} // namespace
4042
43+ void resize_permute_node (
44+ ComputeGraph* graph,
45+ const std::vector<ArgGroup>& args,
46+ const std::vector<ValueRef>& resize_args) {
47+ const ValueRef out = args[0 ].refs [0 ];
48+ const ValueRef in = args[1 ].refs [0 ];
49+
50+ const std::vector<int64_t > in_sizes = graph->sizes_of (in);
51+ const std::vector<int64_t > out_sizes = graph->sizes_of (out);
52+
53+ const std::vector<int64_t > permute_dims =
54+ graph->extract_int_or_symint_list (resize_args[0 ]);
55+
56+ if (in_sizes.size () == out_sizes.size () &&
57+ in_sizes.size () == permute_dims.size ()) {
58+ std::vector<int64_t > new_out_sizes (out_sizes.size (), 1 );
59+ const int64_t out_ndim = std::max (in_sizes.size (), out_sizes.size ());
60+ for (int i = 0 ; i < out_ndim; i++) {
61+ const int64_t permute_dim = permute_dims.at (i);
62+ new_out_sizes.at (i) = in_sizes.at (permute_dim);
63+ }
64+ graph->virtual_resize (out, new_out_sizes);
65+ }
66+ // Case where permute is being used to implement squeeze
67+ else if (
68+ in_sizes.size () > out_sizes.size () &&
69+ in_sizes.size () == permute_dims.size ()) {
70+ std::vector<int64_t > new_out_sizes (out_sizes.size (), 1 );
71+ const size_t offset = in_sizes.size () - out_sizes.size ();
72+ for (int i = 0 ; i < out_sizes.size (); i++) {
73+ const int64_t permute_dim = permute_dims.at (i + offset);
74+ new_out_sizes.at (i) = in_sizes.at (permute_dim);
75+ }
76+ graph->virtual_resize (out, new_out_sizes);
77+ }
78+ // Case where Permute is being used to implement unsqueeze
79+ else if (
80+ in_sizes.size () < out_sizes.size () &&
81+ out_sizes.size () == permute_dims.size ()) {
82+ std::vector<int64_t > new_out_sizes (out_sizes.size (), 1 );
83+ const size_t offset = out_sizes.size () - in_sizes.size ();
84+ for (int i = 0 ; i < out_sizes.size (); i++) {
85+ int64_t permute_dim = permute_dims.at (i) - offset;
86+ if (permute_dim >= 0 ) {
87+ new_out_sizes.at (i) = in_sizes.at (permute_dim);
88+ }
89+ }
90+ graph->virtual_resize (out, new_out_sizes);
91+ } else {
92+ VK_THROW (" Invalid permute dims" );
93+ }
94+ }
95+
4196void add_permute_node (
4297 ComputeGraph& graph,
43- ValueRef in,
44- const std::vector<int64_t >& permute_dims,
45- ValueRef out) {
46- vTensorPtr t_in = graph.get_tensor (in);
47- vTensorPtr t_out = graph.get_tensor (out);
48-
49- check_args (*t_in, permute_dims, *t_out);
98+ const ValueRef in,
99+ const ValueRef permute_dims,
100+ const ValueRef out) {
101+ check_args (graph, in, permute_dims, out);
50102
51103 ivec4 out_dims{0 , 1 , 2 , 3 };
52104
53105 // Special cases of squeeze/unsqueeze. Because the input dim size can be
54- // different with output dim size. So pick t_in->dim( ) if squeeze, and
55- // t_out->dim( ) if unsqueeze to create parameter for permute.
56- int64_t out_ndim = std::max (t_in-> dim ( ), t_out-> dim ( ));
106+ // different with output dim size. So pick graph.dim_of(in ) if squeeze, and
107+ // graph.dim_of(out ) if unsqueeze to create parameter for permute.
108+ const int64_t out_ndim = std::max (graph. dim_of (in ), graph. dim_of (out ));
57109 std::vector<bool > seen (out_ndim);
58- for (int i = 0 ; i < out_ndim; i++) {
59- int64_t permute_dim = permute_dims[i];
60- VK_CHECK_COND (
61- !seen[permute_dim], " Argument dim " , permute_dim, " is repeated" );
62- seen[permute_dim] = true ;
63-
64- out_dims[(4u - out_ndim) + i] = permute_dim + (4 - out_ndim);
110+ {
111+ IntListPtr permute_dims_ptr = graph.get_int_list (permute_dims);
112+ for (int i = 0 ; i < out_ndim; i++) {
113+ int64_t permute_dim = permute_dims_ptr->at (i);
114+ VK_CHECK_COND (
115+ !seen[permute_dim], " Argument dim " , permute_dim, " is repeated" );
116+ seen[permute_dim] = true ;
117+
118+ out_dims[(4u - out_ndim) + i] =
119+ utils::safe_downcast<int32_t >(permute_dim + (4 - out_ndim));
120+ }
65121 }
66122
67123 std::string kernel_name = " permute" ;
68124 kernel_name.reserve (kShaderNameReserve );
69- add_dtype_suffix (kernel_name, *t_out );
125+ add_dtype_suffix (kernel_name, graph. dtype_of (out) );
70126
71- int32_t out_channels = dim_at<kChannel4D >(t_out-> sizes ( ));
72- int32_t in_channels = dim_at<kChannel4D >(t_in-> sizes ( ));
127+ const int32_t out_channels = dim_at<kChannel4D >(graph. sizes_of (out ));
128+ const int32_t in_channels = dim_at<kChannel4D >(graph. sizes_of (in ));
73129
74- const auto packed_dim = graph.packed_dim_of (in);
130+ const int32_t packed_dim = graph.packed_dim_of (in);
75131 ivec2 channel_info = {out_channels, in_channels};
76132 if (packed_dim == WHCN::kChannelsDim ) {
77133 channel_info[0 ] = utils::align_up_4 (channel_info[0 ]);
@@ -95,19 +151,9 @@ void add_permute_node(
95151 // Specialization Constants
96152 spec_vars,
97153 // Resize Args
98- {},
154+ {permute_dims },
99155 // Resizing Logic
100- nullptr ));
101- }
102-
103- void add_permute_node (
104- ComputeGraph& graph,
105- ValueRef in,
106- ValueRef permute_dims_ref,
107- ValueRef out) {
108- IntListPtr permute_dims = graph.get_int_list (permute_dims_ref);
109-
110- add_permute_node (graph, in, *permute_dims, out);
156+ resize_permute_node));
111157}
112158
113159void permute (ComputeGraph& graph, const std::vector<ValueRef>& args) {
0 commit comments