
Commit fbd7fb5

[ET-VK] Add high dim support for permute (#13691)
Title says it all! Adding high-dimension tensor support to permute by using the new `BufferMetadata` struct in the permute shader. Differential Revision: [D80962719](https://our.internmc.facebook.com/intern/diff/D80962719/)
1 parent f15982d commit fbd7fb5

4 files changed (+114 −69 lines)

backends/vulkan/op_registry.py

Lines changed: 1 addition & 1 deletion
@@ -490,7 +490,6 @@ def register_rotary_emb_op():
 @update_features(
     [
         exir_ops.edge.aten.permute.default,
-        exir_ops.edge.aten.permute_copy.default,
     ]
 )
 def register_view_ops():
@@ -506,6 +505,7 @@ def register_view_ops():
         exir_ops.edge.aten.squeeze_copy.dims,
         exir_ops.edge.aten.unsqueeze_copy.default,
         exir_ops.edge.aten.clone.default,
+        exir_ops.edge.aten.permute_copy.default,
     ]
 )
 def register_view_ops_with_buffer_meta():

backends/vulkan/runtime/graph/ops/glsl/indexing.glslh

Lines changed: 9 additions & 0 deletions
@@ -98,6 +98,15 @@ uint idx_at(const TensorIndex tidx, const int dim) {
   return tidx.data[div_4(dim)][mod_4(dim)];
 }
 
+void permute(inout TensorIndex tidx, const ivec4 permute_order[DIMLIMIT_DIV4]) {
+  TensorIndex new_tidx = tidx;
+  for (int d = 0; d < DIMLIMIT; ++d) {
+    int src_dim = permute_order[div_4(d)][mod_4(d)];
+    new_tidx.data[div_4(d)][mod_4(d)] = idx_at(tidx, src_dim);
+  }
+  tidx = new_tidx;
+}
+
 //
 // Index Conversions
 //
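The new helper remaps a packed tensor index in place: dim d of the permuted index takes its value from dim permute_order[d] of the original, with the copy into new_tidx keeping the in-place update safe. A minimal C++ mirror of the logic, assuming DIMLIMIT = 8 so a TensorIndex packs its components into two 4-vectors (the struct here is a simplified stand-in for the shader's types):

```cpp
#include <cstdio>

constexpr int kDimLimit = 8; // stands in for DIMLIMIT

// Simplified stand-in for the shader's TensorIndex: kDimLimit ints
// packed into kDimLimit / 4 four-component vectors.
struct TensorIndex {
  int data[kDimLimit / 4][4];
};

// Mirrors idx_at(): component mod_4(dim) of vector div_4(dim).
int idx_at(const TensorIndex& tidx, int dim) {
  return tidx.data[dim / 4][dim % 4];
}

// Mirrors permute(): dim d of the result reads dim permute_order[d]
// of the original index.
void permute(TensorIndex& tidx, const int permute_order[kDimLimit]) {
  TensorIndex new_tidx = tidx; // read from a copy; write in place
  for (int d = 0; d < kDimLimit; ++d) {
    new_tidx.data[d / 4][d % 4] = idx_at(tidx, permute_order[d]);
  }
  tidx = new_tidx;
}

int main() {
  TensorIndex t = {{{5, 2, 7, 1}, {0, 0, 0, 0}}};
  const int order[kDimLimit] = {1, 0, 2, 3, 4, 5, 6, 7}; // swap dims 0 and 1
  permute(t, order);
  std::printf("%d %d\n", idx_at(t, 0), idx_at(t, 1)); // prints: 2 5
  return 0;
}
```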

backends/vulkan/runtime/graph/ops/glsl/permute_buffer.glsl

Lines changed: 14 additions & 38 deletions
@@ -18,55 +18,31 @@ ${define_required_extensions(DTYPE)}
 
 layout(std430) buffer;
 
-#include "indexing_utils.h"
+#include "indexing.glslh"
 
-${layout_declare_tensor(B, "w", "t_out", DTYPE, "buffer")}
-${layout_declare_tensor(B, "r", "t_in", DTYPE, "buffer")}
+${layout_declare_tensor(B, "w", "t_outp", DTYPE, "buffer")}
+${layout_declare_tensor(B, "r", "t_inp", DTYPE, "buffer")}
 
-${layout_declare_ubo(B, "ivec4", "in_sizes")}
-${layout_declare_ubo(B, "ivec4", "out_strides")}
-${layout_declare_ubo(B, "int", "out_numel")}
+${layout_declare_ubo(B, "BufferMetadata", "outp")}
+${layout_declare_ubo(B, "BufferMetadata", "inp")}
 
-layout(push_constant) uniform restrict Block {
-  ivec4 in_strides;
-  ivec4 permute_dims; // Permutation mapping: permute_dims[i] = j means output dim i comes from input dim j
-};
-
-${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")}
-${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")}
-
-const lowp ivec4 out_dim_order = unhash_dim_order(out_layout);
+${layout_declare_ubo(B, "ivec4[DIMLIMIT_DIV4]", "permute_order")}
 
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 
-// Convert output tensor index to input tensor index based on permutation
-ivec4 out_tidx_to_in_tidx(const ivec4 out_tidx) {
-  ivec4 in_tidx;
-
-  // Apply the permutation mapping: in_tidx[permute_dims[i]] = out_tidx[i]
-  in_tidx[permute_dims.x] = out_tidx.x;
-  in_tidx[permute_dims.y] = out_tidx.y;
-  in_tidx[permute_dims.z] = out_tidx.z;
-  in_tidx[permute_dims.w] = out_tidx.w;
-
-  return in_tidx;
-}
-
 void main() {
-  const int out_bufi = ivec3(gl_GlobalInvocationID).x;
-  if (out_bufi >= out_numel) {
+  const uint inp_bufi = gl_GlobalInvocationID.x;
+  if (inp_bufi >= numel(inp)) {
     return;
   }
 
-  // Convert buffer index to tensor index for output
-  const ivec4 out_tidx = bufi_to_tidx(out_bufi, out_strides, out_dim_order);
-
-  // Convert output tensor index to input tensor index using permutation
-  const ivec4 in_tidx = out_tidx_to_in_tidx(out_tidx);
+  TensorIndex inp_tidx;
+  linear_idx_to_tensor_idx(inp, inp_bufi, inp_tidx);
 
-  // Convert input tensor index back to buffer index
-  const int in_bufi = tidx_to_bufi(in_tidx, in_strides);
+  TensorIndex outp_tidx = inp_tidx;
+  permute(outp_tidx, permute_order);
 
+  const uint outp_bufi = tensor_idx_to_linear_idx(outp, outp_tidx);
   // Copy data from input to output
-  t_out[out_bufi] = t_in[in_bufi];
+  t_outp[outp_bufi] = t_inp[inp_bufi];
 }
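Note the direction flip relative to the old shader: it used to launch one invocation per output element and gather from the input, while the rewrite walks input elements and scatters to the output. A permutation is a bijection over elements, so either direction touches every element exactly once. Below is a CPU emulation of the new per-invocation logic, as a sketch assuming contiguous row-major tensors in place of the real BufferMetadata UBOs (permute_copy and its helpers are illustrative, not the runtime's API):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// One loop iteration corresponds to one shader invocation:
// unravel inp_bufi -> permute the tensor index -> re-ravel -> copy.
// out_idx[d] = in_idx[perm[d]]: output dim d reads input dim perm[d].
void permute_copy(
    const std::vector<float>& in_buf,
    const std::vector<int64_t>& in_sizes,
    const std::vector<int64_t>& perm,
    std::vector<float>& out_buf) {
  const size_t ndim = in_sizes.size();
  std::vector<int64_t> out_sizes(ndim);
  for (size_t d = 0; d < ndim; ++d) {
    out_sizes[d] = in_sizes[perm[d]];
  }
  const int64_t numel = static_cast<int64_t>(in_buf.size());
  for (int64_t inp_bufi = 0; inp_bufi < numel; ++inp_bufi) {
    // linear_idx_to_tensor_idx: unravel assuming row-major layout.
    std::vector<int64_t> inp_tidx(ndim);
    int64_t rem = inp_bufi;
    for (size_t d = ndim; d-- > 0;) {
      inp_tidx[d] = rem % in_sizes[d];
      rem /= in_sizes[d];
    }
    // permute(): dim d of the output index reads dim perm[d] of the input's.
    std::vector<int64_t> outp_tidx(ndim);
    for (size_t d = 0; d < ndim; ++d) {
      outp_tidx[d] = inp_tidx[perm[d]];
    }
    // tensor_idx_to_linear_idx: re-ravel against the output sizes.
    int64_t outp_bufi = 0;
    for (size_t d = 0; d < ndim; ++d) {
      outp_bufi = outp_bufi * out_sizes[d] + outp_tidx[d];
    }
    out_buf[outp_bufi] = in_buf[inp_bufi];
  }
}

int main() {
  // 2x3 input, permute {1, 0} (a transpose): output is 3x2.
  std::vector<float> in = {0, 1, 2, 3, 4, 5};
  std::vector<float> out(in.size());
  permute_copy(in, {2, 3}, {1, 0}, out);
  for (float v : out) {
    std::printf("%g ", v); // prints: 0 3 1 4 2 5
  }
  std::printf("\n");
  return 0;
}
```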

backends/vulkan/runtime/graph/ops/impl/Permute.cpp

Lines changed: 90 additions & 30 deletions
@@ -129,37 +129,22 @@ void add_permute_node(
   std::vector<PushConstantDataInfo> push_constants;
   vkapi::SpecVarList spec_vars;
 
-  if (graph.is_buffer_storage(out)) {
-    param_buffers.append(graph.sizes_ubo(in));
-    param_buffers.append(graph.strides_ubo(out));
-    param_buffers.append(graph.numel_ubo(out));
-
-    // Buffer storage - use permute_buffer shader
-    push_constants = {
-        graph.strides_pc_of(in),
-        PushConstantDataInfo(&whcn_permute_dims, sizeof(whcn_permute_dims)),
-    };
-
-    spec_vars = {graph.hashed_layout_of(out), graph.hashed_layout_of(in)};
-  } else {
-    // Texture storage - use permute_texture shader
-    const int32_t out_channels = dim_at<kChannel4D>(graph.sizes_of(out));
-    const int32_t in_channels = dim_at<kChannel4D>(graph.sizes_of(in));
-
-    const int32_t packed_dim = graph.packed_dim_of(in);
-    ivec2 channel_info = {out_channels, in_channels};
-    if (packed_dim == WHCN::kChannelsDim) {
-      channel_info[0] = utils::align_up_4(channel_info[0]);
-      channel_info[1] = utils::align_up_4(channel_info[1]);
-    }
+  const int32_t out_channels = dim_at<kChannel4D>(graph.sizes_of(out));
+  const int32_t in_channels = dim_at<kChannel4D>(graph.sizes_of(in));
+
+  const int32_t packed_dim = graph.packed_dim_of(in);
+  ivec2 channel_info = {out_channels, in_channels};
+  if (packed_dim == WHCN::kChannelsDim) {
+    channel_info[0] = utils::align_up_4(channel_info[0]);
+    channel_info[1] = utils::align_up_4(channel_info[1]);
+  }
 
-    push_constants = {
-        graph.sizes_pc_of(out),
-        graph.sizes_pc_of(in),
-        PushConstantDataInfo(&whcn_permute_dims, sizeof(whcn_permute_dims))};
+  push_constants = {
+      graph.sizes_pc_of(out),
+      graph.sizes_pc_of(in),
+      PushConstantDataInfo(&whcn_permute_dims, sizeof(whcn_permute_dims))};
 
-    spec_vars = {graph.hashed_layout_of(out), graph.hashed_layout_of(in)};
-  }
+  spec_vars = {graph.hashed_layout_of(out), graph.hashed_layout_of(in)};
 
   graph.execute_nodes().emplace_back(new DynamicDispatchNode(
       graph,
@@ -179,8 +164,83 @@ void add_permute_node(
       resize_permute_node));
 }
 
+struct WHCNPermuteDims {
+  int32_t whcn_permute_dims[api::kTensorDimLimit];
+
+  void initialize(const std::vector<int64_t>& permute_dims) {
+    const int32_t permute_ndim = permute_dims.size();
+    for (int32_t whcn_i = 0; whcn_i < permute_ndim; whcn_i++) {
+      const int32_t nchw_i = permute_ndim - 1 - whcn_i;
+      int64_t index_val = permute_dims.at(nchw_i);
+      if (index_val < 0) {
+        index_val += permute_ndim;
+      }
+      const int32_t permute_dim_whcn = permute_ndim - 1 - index_val;
+      whcn_permute_dims[whcn_i] = permute_dim_whcn;
+    }
+    for (int32_t whcn_i = permute_ndim; whcn_i < api::kTensorDimLimit;
+         whcn_i++) {
+      whcn_permute_dims[whcn_i] = whcn_i;
+    }
+  }
+};
+
+void add_permute_buffer_node(
+    ComputeGraph& graph,
+    const ValueRef in,
+    const ValueRef permute_dims,
+    const ValueRef out) {
+  check_args(graph, in, permute_dims, out);
+
+  WHCNPermuteDims whcn_permute_dims;
+  // Convert the permute dims to WHCN dimension order, which is the standard in
+  // our compute shaders. The following transformations are applied:
+  // 1. Change dimension index values from NCHW order to WHCN order
+  // 2. Extend the permute array to kTensorDimLimit
+  {
+    IntListPtr permute_dims_ptr = graph.get_int_list(permute_dims);
+    whcn_permute_dims.initialize(*permute_dims_ptr);
+  }
+
+  std::string kernel_name = "permute";
+  kernel_name.reserve(kShaderNameReserve);
+  add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
+  add_dtype_suffix(kernel_name, graph.dtype_of(out));
+
+  vkapi::ParamsBindList param_buffers = {
+      graph.buffer_meta_ubo(out),
+      graph.buffer_meta_ubo(in),
+      graph.create_params_buffer(whcn_permute_dims),
+  };
+
+  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
+      graph,
+      VK_KERNEL_FROM_STR(kernel_name),
+      default_pick_global_wg_size,
+      default_pick_local_wg_size,
+      {{out, vkapi::kWrite}, {in, vkapi::kRead}},
+      // Parameter buffers
+      param_buffers,
+      // Push Constants
+      {},
+      // Specialization Constants
+      {},
+      // Resize Args
+      {permute_dims},
+      // Resizing Logic
+      resize_permute_node));
+}
+
 void permute(ComputeGraph& graph, const std::vector<ValueRef>& args) {
-  return add_permute_node(graph, args[0], args[1], args[2]);
+  int idx = 0;
+  const ValueRef in = args.at(idx++);
+  const ValueRef permute_dims = args.at(idx++);
+  const ValueRef out = args.at(idx++);
+
+  if (graph.is_buffer_storage(out)) {
+    return add_permute_buffer_node(graph, in, permute_dims, out);
+  }
+  return add_permute_node(graph, in, permute_dims, out);
 }
 
 REGISTER_OPERATORS {
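The NCHW-to-WHCN conversion in WHCNPermuteDims::initialize is worth a worked example: reversing the array converts between NCHW and WHCN dim positions, and since the stored values are themselves dim indices, each value v also becomes ndim - 1 - v. A standalone C++ sketch of the same transformation, assuming a dim limit of 8 for api::kTensorDimLimit:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

constexpr int32_t kTensorDimLimit = 8; // assumed value of api::kTensorDimLimit

// Mirrors WHCNPermuteDims::initialize(): reverse the array to go from
// NCHW to WHCN positions, remap each stored dim index the same way,
// then pad the unused dims with the identity mapping.
void to_whcn_permute_dims(
    const std::vector<int64_t>& permute_dims, // NCHW order, as ATen gives it
    int32_t out[kTensorDimLimit]) {
  const int32_t ndim = static_cast<int32_t>(permute_dims.size());
  for (int32_t whcn_i = 0; whcn_i < ndim; whcn_i++) {
    const int32_t nchw_i = ndim - 1 - whcn_i;
    int64_t v = permute_dims.at(nchw_i);
    if (v < 0) {
      v += ndim; // normalize negative dims
    }
    out[whcn_i] = static_cast<int32_t>(ndim - 1 - v);
  }
  for (int32_t whcn_i = ndim; whcn_i < kTensorDimLimit; whcn_i++) {
    out[whcn_i] = whcn_i; // identity for dims beyond the tensor's rank
  }
}

int main() {
  int32_t whcn[kTensorDimLimit];
  to_whcn_permute_dims({0, 2, 3, 1}, whcn); // NCHW -> NHWC
  for (int32_t d : whcn) {
    std::printf("%d ", d); // prints: 2 0 1 3 4 5 6 7
  }
  std::printf("\n");
  return 0;
}
```

For instance, {0, 2, 3, 1} sends output WHCN dim 0 (the output's channels, since NHWC puts C innermost) to input WHCN dim 2, which is exactly the value the permute shader reads out of permute_order.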
