@@ -129,37 +129,22 @@ void add_permute_node(
129
129
std::vector<PushConstantDataInfo> push_constants;
130
130
vkapi::SpecVarList spec_vars;
131
131
132
- if (graph.is_buffer_storage (out)) {
133
- param_buffers.append (graph.sizes_ubo (in));
134
- param_buffers.append (graph.strides_ubo (out));
135
- param_buffers.append (graph.numel_ubo (out));
136
-
137
- // Buffer storage - use permute_buffer shader
138
- push_constants = {
139
- graph.strides_pc_of (in),
140
- PushConstantDataInfo (&whcn_permute_dims, sizeof (whcn_permute_dims)),
141
- };
142
-
143
- spec_vars = {graph.hashed_layout_of (out), graph.hashed_layout_of (in)};
144
- } else {
145
- // Texture storage - use permute_texture shader
146
- const int32_t out_channels = dim_at<kChannel4D >(graph.sizes_of (out));
147
- const int32_t in_channels = dim_at<kChannel4D >(graph.sizes_of (in));
148
-
149
- const int32_t packed_dim = graph.packed_dim_of (in);
150
- ivec2 channel_info = {out_channels, in_channels};
151
- if (packed_dim == WHCN::kChannelsDim ) {
152
- channel_info[0 ] = utils::align_up_4 (channel_info[0 ]);
153
- channel_info[1 ] = utils::align_up_4 (channel_info[1 ]);
154
- }
132
+ const int32_t out_channels = dim_at<kChannel4D >(graph.sizes_of (out));
133
+ const int32_t in_channels = dim_at<kChannel4D >(graph.sizes_of (in));
134
+
135
+ const int32_t packed_dim = graph.packed_dim_of (in);
136
+ ivec2 channel_info = {out_channels, in_channels};
137
+ if (packed_dim == WHCN::kChannelsDim ) {
138
+ channel_info[0 ] = utils::align_up_4 (channel_info[0 ]);
139
+ channel_info[1 ] = utils::align_up_4 (channel_info[1 ]);
140
+ }
155
141
156
- push_constants = {
157
- graph.sizes_pc_of (out),
158
- graph.sizes_pc_of (in),
159
- PushConstantDataInfo (&whcn_permute_dims, sizeof (whcn_permute_dims))};
142
+ push_constants = {
143
+ graph.sizes_pc_of (out),
144
+ graph.sizes_pc_of (in),
145
+ PushConstantDataInfo (&whcn_permute_dims, sizeof (whcn_permute_dims))};
160
146
161
- spec_vars = {graph.hashed_layout_of (out), graph.hashed_layout_of (in)};
162
- }
147
+ spec_vars = {graph.hashed_layout_of (out), graph.hashed_layout_of (in)};
163
148
164
149
graph.execute_nodes ().emplace_back (new DynamicDispatchNode (
165
150
graph,
@@ -179,8 +164,83 @@ void add_permute_node(
179
164
resize_permute_node));
180
165
}
181
166
167
+ struct WHCNPermuteDims {
168
+ int32_t whcn_permute_dims[api::kTensorDimLimit ];
169
+
170
+ void initialize (const std::vector<int64_t >& permute_dims) {
171
+ const int32_t permute_ndim = permute_dims.size ();
172
+ for (int32_t whcn_i = 0 ; whcn_i < permute_ndim; whcn_i++) {
173
+ const int32_t nchw_i = permute_ndim - 1 - whcn_i;
174
+ int64_t index_val = permute_dims.at (nchw_i);
175
+ if (index_val < 0 ) {
176
+ index_val += permute_ndim;
177
+ }
178
+ const int32_t permute_dim_whcn = permute_ndim - 1 - index_val;
179
+ whcn_permute_dims[whcn_i] = permute_dim_whcn;
180
+ }
181
+ for (int32_t whcn_i = permute_ndim; whcn_i < api::kTensorDimLimit ;
182
+ whcn_i++) {
183
+ whcn_permute_dims[whcn_i] = whcn_i;
184
+ }
185
+ }
186
+ };
187
+
188
+ void add_permute_buffer_node (
189
+ ComputeGraph& graph,
190
+ const ValueRef in,
191
+ const ValueRef permute_dims,
192
+ const ValueRef out) {
193
+ check_args (graph, in, permute_dims, out);
194
+
195
+ WHCNPermuteDims whcn_permute_dims;
196
+ // Convert the permute dims to WHCN dimension order, which is the standard in
197
+ // our compute shaders. The following transformations are applied.
198
+ // 1. Change dimension index values from NCHW order valueto WHCN order value
199
+ // 2. Extend the permute array to kTensorDimLimit
200
+ {
201
+ IntListPtr permute_dims_ptr = graph.get_int_list (permute_dims);
202
+ whcn_permute_dims.initialize (*permute_dims_ptr);
203
+ }
204
+
205
+ std::string kernel_name = " permute" ;
206
+ kernel_name.reserve (kShaderNameReserve );
207
+ add_storage_type_suffix (kernel_name, graph.storage_type_of (out));
208
+ add_dtype_suffix (kernel_name, graph.dtype_of (out));
209
+
210
+ vkapi::ParamsBindList param_buffers = {
211
+ graph.buffer_meta_ubo (out),
212
+ graph.buffer_meta_ubo (in),
213
+ graph.create_params_buffer (whcn_permute_dims),
214
+ };
215
+
216
+ graph.execute_nodes ().emplace_back (new DynamicDispatchNode (
217
+ graph,
218
+ VK_KERNEL_FROM_STR (kernel_name),
219
+ default_pick_global_wg_size,
220
+ default_pick_local_wg_size,
221
+ {{out, vkapi::kWrite }, {in, vkapi::kRead }},
222
+ // Parameter buffers
223
+ param_buffers,
224
+ // Push Constants
225
+ {},
226
+ // Specialization Constants
227
+ {},
228
+ // Resize Args
229
+ {permute_dims},
230
+ // Resizing Logic
231
+ resize_permute_node));
232
+ }
233
+
182
234
// Operator entry point for permute.
//
// Expects `args` to hold, in order: the input tensor, the permute dims list,
// and the output tensor. Access goes through std::vector::at, so a short
// argument list raises std::out_of_range instead of reading out of bounds.
// Dispatches to the buffer variant when the output uses buffer storage, and
// to add_permute_node otherwise.
void permute(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  int idx = 0;
  const ValueRef in = args.at(idx++);
  const ValueRef permute_dims = args.at(idx++);
  const ValueRef out = args.at(idx++);

  // Use the already-bound `out` rather than re-indexing args with the
  // unchecked operator[].
  if (graph.is_buffer_storage(out)) {
    return add_permute_buffer_node(graph, in, permute_dims, out);
  }
  return add_permute_node(graph, in, permute_dims, out);
}
185
245
186
246
REGISTER_OPERATORS {
0 commit comments