Commit e62a4ef

[ET-VK] Use dim order when converting buffer index to tensor index (#11622)
## Changes

* Update callsites to `bufi_to_tidx` to account for the tensor dim order.
* Remove existing functions which do not accept the dim order as an argument.

## Motivation

As mentioned in the below diff, the dim order is required to properly convert a linear buffer index into an N-dimensional tensor index using a tensor's strides. Technically the dim order can be inferred from the strides array by performing an index sort, but for the sake of efficiency it is better to pass the dim order directly into the compute shader.

Currently, the `bufi_to_tidx` function that performs the conversion between buffer index and tensor index assumes that the dim order follows a specific pattern derived from the packed dim it receives as an input. However, the actual dim order is not guaranteed to match this assumption. Furthermore, there is an existing bug when calling `bufi_to_tidx` without providing `packed_dim`: in that case the function infers the packed dim by finding the first dim with a stride of 1. This breaks when multiple dims have a stride of 1, which can happen when some dims have a size of 1; the wrong packed dim may then be inferred, and the assumed dim order is completely wrong.

To address these issues, make it standard to either account for the tensor's dim order when converting a buffer index to a tensor index, or to explicitly call out an assumption about the tensor's dim order (e.g. that the tensor is contiguous).

## Performance Impact

* None expected.

Differential Revision: [D76393428](https://our.internmc.facebook.com/intern/diff/D76393428/)
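
To make the failure mode concrete, here is a hypothetical C++ model of the indexing logic described above. It is not the shader code or the ExecuTorch API; `ivec4`, the WHCN dim numbering (0 = W, 1 = H, 2 = C, 3 = N), and the example sizes are assumptions for illustration. `find_packed_dim` mirrors the removed stride-based inference, and `bufi_to_tidx` mirrors the new dim-order-based conversion.

```cpp
#include <array>
#include <cassert>

using ivec4 = std::array<int, 4>;

// Old helper: infer the packed dim as the first dim with stride == 1.
int find_packed_dim(const ivec4& strides) {
  for (int i = 0; i < 4; ++i) {
    if (strides[i] == 1) return i;
  }
  return 0;
}

// New conversion: walk dims from the slowest-moving (dim_order[3]) to the
// fastest (dim_order[0]), peeling the buffer index apart with per-dim strides.
ivec4 bufi_to_tidx(int bufi, const ivec4& strides, const ivec4& dim_order) {
  ivec4 idx{};
  for (int i = 3; i >= 0; --i) {
    const int dim = dim_order[i];
    idx[dim] = bufi / strides[dim];
    bufi %= strides[dim];
  }
  return idx;
}

int main() {
  // H-packed tensor with sizes (W, H, C, N) = (3, 1, 4, 2). Because H has
  // size 1, both W and H end up with a stride of 1, so stride-based inference
  // picks W (dim 0) even though the true packed dim is H (dim 1). The old
  // code path would then put the residual index into H (size 1) instead of W.
  const ivec4 strides = {1, 1, 3, 12};
  const ivec4 dim_order = {1, 0, 2, 3};  // fastest -> slowest: H, W, C, N

  assert(find_packed_dim(strides) == 0);  // wrong guess: W instead of H

  // With the dim order passed in explicitly, buffer index 7 decodes to the
  // in-bounds tensor index (W=1, H=0, C=2, N=0).
  const ivec4 tidx = bufi_to_tidx(7, strides, dim_order);
  assert((tidx == ivec4{1, 0, 2, 0}));
  return 0;
}
```
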
1 parent cef2094 commit e62a4ef

File tree

14 files changed: +82 / -92 lines


backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl

Lines changed: 6 additions & 7 deletions
@@ -48,19 +48,18 @@ $else:

 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

+${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")}
+${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")}
+${layout_declare_spec_const(C, "int", "other_layout", "DEFAULT_LAYOUT")}
+
 $if STORAGE == "buffer":
-  ${layout_declare_spec_const(C, "int", "out_packed_dim", "DEFAULT_LAYOUT")}
-  ${layout_declare_spec_const(C, "int", "in_packed_dim", "DEFAULT_LAYOUT")}
-  ${layout_declare_spec_const(C, "int", "other_packed_dim", "DEFAULT_LAYOUT")}
+  const lowp ivec4 out_dim_order = unhash_dim_order(out_layout);
 $else:
-  ${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")}
   const lowp ivec4 out_axis_map = unhash_axis_map(out_layout);
   const lowp int packed_dim = unhash_packed_dim(out_layout);

-  ${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")}
   const lowp ivec4 in_axis_map = unhash_axis_map(in_layout);

-  ${layout_declare_spec_const(C, "int", "other_layout", "DEFAULT_LAYOUT")}
   const lowp ivec4 other_axis_map = unhash_axis_map(other_layout);

 #ifdef USING_BUFFER
@@ -77,7 +76,7 @@ void main() {
     return;
   }

-  const ivec4 out_tidx = bufi_to_tidx(out_bufi, out_strides, out_packed_dim);
+  const ivec4 out_tidx = bufi_to_tidx(out_bufi, out_strides, out_dim_order);
   const ivec4 in_tidx = min(out_tidx, in_sizes - 1);
   const ivec4 other_tidx = min(out_tidx, other_sizes - 1);

backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h

Lines changed: 25 additions & 31 deletions
@@ -68,21 +68,6 @@
  */
 #define mod4(x) ((x) & 3)

-/*
- * Find the packed dimension of a tensor given its strides. The packed dimension
- * is the "fastest moving" dimension which will have a stride of 1.
- */
-int find_packed_dim(const ivec4 strides) {
-  int packed_dim = 0;
-  for (int i = 0; i <= 3; i++) {
-    if (strides[i] == 1) {
-      packed_dim = i;
-      break;
-    }
-  }
-  return packed_dim;
-}
-
 /*
  * Get the staging buffer indices that contain the data of the texel that
  * corresponds to the provided tensor index. Since the texel have 4 elements,
@@ -129,27 +114,26 @@ int tidx_to_nchwi(const ivec4 tidx, const ivec4 sizes) {
       tidx.x;
 }

-// TODO(ssjia): make this function use dim order so that it can work with any
-// dim order. Currently it assumes that the dim order is contiguous, except for
-// the packed dim.
-ivec4 bufi_to_tidx(int bufi, const ivec4 strides, const int packed_dim) {
+ivec4 bufi_to_tidx(int bufi, const ivec4 strides, const ivec4 dim_order) {
   ivec4 idx;
   for (int i = 3; i >= 0; i--) {
-    if (i != packed_dim) {
-      idx[i] = bufi / strides[i];
-      bufi %= strides[i];
-    }
+    int dim = dim_order[i];
+    idx[dim] = bufi / strides[dim];
+    bufi %= strides[dim];
   }
-  idx[packed_dim] = bufi;
   return idx;
 }

-// Convenience overload of the above function, which will determine the packed
-// dim from the strides automatically so it doesn't have to be passed in as a
-// function argument.
-ivec4 bufi_to_tidx(const int bufi, const ivec4 strides) {
-  int packed_dim = find_packed_dim(strides);
-  return bufi_to_tidx(bufi, strides, packed_dim);
+/*
+ * bufi_to_tidx but assumes that the tensor is contiguous
+ */
+ivec4 contiguous_bufi_to_tidx(int bufi, const ivec4 strides) {
+  ivec4 idx;
+  for (int i = 3; i >= 0; i--) {
+    idx[i] = bufi / strides[i];
+    bufi %= strides[i];
+  }
+  return idx;
 }

 int tidx_to_bufi(const ivec4 tidx, ivec4 strides) {
@@ -269,12 +253,22 @@ ivec3 lpos_to_pos(const ivec3 lpos, const ivec4 axis_map) {
  * e.g. 0x11021, 1 -> ivec4(1, 2, 0, 1)
  */
 #define unhash_axis_map(hash) \
-  ivec4(hash & 0xf, (hash >> 4) & 0xf, (hash >> 8 & 0xf), (hash >> 12 & 0xf))
+  (ivec4(hash & 0xf, (hash >> 4) & 0xf, (hash >> 8 & 0xf), (hash >> 12 & 0xf)))
+
+/*
+ *
+ */
+#define unhash_dim_order(hash) \
+  (ivec4(hash & 0xf, (hash >> 4) & 0xf, (hash >> 8 & 0xf), (hash >> 12 & 0xf)))

 #define unhash_packed_dim(hash) int(hash >> 16 & 0xf)

 #define DEFAULT_LAYOUT 0x02210

+#define DEFAULT_DIM_ORDER 0x03210
+
+#define DEFAULT_DIM_ORDER_IVEC4 ivec4(0, 1, 2, 3)
+
 /************************
  * Deprecated Functions *
  ************************/
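
For reference, the new `unhash_dim_order` macro and `DEFAULT_DIM_ORDER` constant pack one dim index into each 4-bit nibble of the hashed layout. Below is a minimal C++ sketch of that unpacking; the function is an illustrative stand-in, not the backend's C++ API.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

using ivec4 = std::array<int, 4>;

// Mirror of the GLSL unhash_dim_order macro: nibble i of the hash holds
// dim_order[i].
ivec4 unhash_dim_order(uint32_t hash) {
  return {static_cast<int>(hash & 0xf),
          static_cast<int>((hash >> 4) & 0xf),
          static_cast<int>((hash >> 8) & 0xf),
          static_cast<int>((hash >> 12) & 0xf)};
}

int main() {
  // DEFAULT_DIM_ORDER (0x03210) decodes to the same value as
  // DEFAULT_DIM_ORDER_IVEC4, i.e. ivec4(0, 1, 2, 3).
  assert((unhash_dim_order(0x03210) == ivec4{0, 1, 2, 3}));
  return 0;
}
```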

backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw.glsl

Lines changed: 1 addition & 1 deletion
@@ -62,7 +62,7 @@ void main() {
     return;
   }

-  const ivec4 out_tidx = bufi_to_tidx(out_bufi, out_strides, 0);
+  const ivec4 out_tidx = contiguous_bufi_to_tidx(out_bufi, out_strides);

   const FLOAT_T scale = t_scales[out_tidx.x];

backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.glsl

Lines changed: 8 additions & 9 deletions
@@ -10,8 +10,8 @@ ${define_required_extensions(DTYPE)}

 layout(std430) buffer;

-${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)}
-${layout_declare_tensor(1, "r", "nchw_in", DTYPE, STORAGE)}
+${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
+${layout_declare_tensor(B, "r", "nchw_in", DTYPE, STORAGE)}

 $if USE_PUSH_CONST:
   layout(push_constant) uniform restrict Block {
@@ -20,15 +20,14 @@ $if USE_PUSH_CONST:
     int numel;
   };
 $else:
-  ${layout_declare_ubo(2, "ivec4", "out_sizes")}
-  ${layout_declare_ubo(3, "ivec4", "out_strides")}
-  ${layout_declare_ubo(4, "int", "numel")}
+  ${layout_declare_ubo(B, "ivec4", "out_sizes")}
+  ${layout_declare_ubo(B, "ivec4", "out_strides")}
+  ${layout_declare_ubo(B, "int", "numel")}

 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

-// This constant is unused in this shader but is kept so that the signature is
-// consistent with nchw_to_image.
-${layout_declare_spec_const(C, "int", "UNUSED_layout", "0")}
+${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_DIM_ORDER")}
+const lowp ivec4 out_dim_order = unhash_dim_order(out_layout);
 ${layout_declare_spec_const(C, "int", "transpose_hw", "0")}

 void main() {
@@ -37,7 +36,7 @@ void main() {
     return;
   }

-  ivec4 out_tidx = bufi_to_tidx(out_bufi, out_strides);
+  ivec4 out_tidx = bufi_to_tidx(out_bufi, out_strides, out_dim_order);

   ivec4 sizes = out_sizes;
   if (transpose_hw == 1) {

backends/vulkan/runtime/graph/ops/glsl/select.glslh

Lines changed: 4 additions & 0 deletions
@@ -9,6 +9,8 @@
 #ifndef SELECT_GLSLH
 #define SELECT_GLSLH

+#ifndef USING_BUFFER
+
 /*
  * Enable the fast path if a texel loaded from the input texture can be used as
  * is to store to the output texture. The following conditions must be met:
@@ -29,6 +31,8 @@ bool can_use_fast_path() {
   return true;
 }

+#endif // USING_BUFFER
+
 /*
  * Given an output tensor index, return the corresponding input tensor index for
  * the select operator. This is done by "inserting" the select index at the

backends/vulkan/runtime/graph/ops/glsl/slice.glslh

Lines changed: 4 additions & 0 deletions
@@ -9,6 +9,8 @@
 #ifndef SLICE_GLSLH
 #define SLICE_GLSLH

+#ifndef USING_BUFFER
+
 /**
  * Enable the fast path if a texel loaded from the input texture can be used as
  * is to store to the output texture. The following conditions must be met:
@@ -26,6 +28,8 @@ bool can_use_fast_path() {
   return true;
 }

+#endif // USING_BUFFER
+
 /*
  * Converts output tensor indices to input tensor indices for the slice operation.
  * This function maps the output indices to the corresponding input indices based on

backends/vulkan/runtime/graph/ops/glsl/transfer_buffer.glsl

Lines changed: 5 additions & 3 deletions
@@ -37,8 +37,10 @@ layout(push_constant) uniform restrict Block {
   int selected_dim;
 };

-${layout_declare_spec_const(C, "int", "out_packed_dim", "DEFAULT_LAYOUT")}
-${layout_declare_spec_const(C, "int", "in_packed_dim", "DEFAULT_LAYOUT")}
+${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")}
+${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")}
+
+const lowp ivec4 out_dim_order = unhash_dim_order(out_layout);

 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

@@ -50,7 +52,7 @@ void main() {
     return;
   }

-  const ivec4 out_tidx = bufi_to_tidx(out_bufi, out_strides, out_packed_dim);
+  const ivec4 out_tidx = bufi_to_tidx(out_bufi, out_strides, out_dim_order);
   ivec4 in_tidx = out_tidx_to_in_tidx(out_tidx);

   const int in_bufi = tidx_to_bufi(in_tidx, in_strides);

backends/vulkan/runtime/graph/ops/glsl/where.glsl

Lines changed: 8 additions & 20 deletions
@@ -37,40 +37,28 @@ $if STORAGE == "buffer":
   ${layout_declare_ubo(B, "ivec4", "cond_strides")}
   ${layout_declare_ubo(B, "ivec4", "self_strides")}
   ${layout_declare_ubo(B, "ivec4", "other_strides")}
-
-  ${layout_declare_spec_const(C, "int", "out_packed_dim", "DEFAULT_LAYOUT")}
-  ${layout_declare_spec_const(C, "int", "cond_packed_dim", "DEFAULT_LAYOUT")}
-  ${layout_declare_spec_const(C, "int", "self_packed_dim", "DEFAULT_LAYOUT")}
-  ${layout_declare_spec_const(C, "int", "other_packed_dim", "DEFAULT_LAYOUT")}
 $else:
   ${layout_declare_ubo(B, "ivec3", "out_limits")}

+${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_DIM_ORDER")}
+
+const lowp ivec4 out_dim_order = unhash_dim_order(out_layout);
+
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

 #ifdef USING_BUFFER

 void main() {
   int out_bufi = int(gl_GlobalInvocationID.x);
-  // ivec4 tidx = ivec4(gl_GlobalInvocationID, 0);
-  // int out_bufi = tidx_to_bufi(tidx, out_strides);
-  // int cond_bufi = tidx_to_bufi(tidx, cond_strides);
-  // int self_bufi = tidx_to_bufi(tidx, self_strides);
-  // int other_bufi = tidx_to_bufi(tidx, other_strides);
   if (out_bufi >= out_numl) {
     return;
   }

-  const ivec4 out_tidx = bufi_to_tidx(out_bufi, out_strides, out_packed_dim);
-  out_bufi = tidx_to_bufi(out_tidx, out_strides);
-
-  const ivec4 cond_tidx = bufi_to_tidx(out_bufi, out_strides, out_packed_dim);
-  const int cond_bufi = tidx_to_bufi(cond_tidx, cond_strides);
-
-  const ivec4 self_tidx = bufi_to_tidx(out_bufi, out_strides, out_packed_dim);
-  const int self_bufi = tidx_to_bufi(self_tidx, self_strides);
+  const ivec4 out_tidx = bufi_to_tidx(out_bufi, out_strides, out_dim_order);

-  const ivec4 other_tidx = bufi_to_tidx(out_bufi, out_strides, out_packed_dim);
-  const int other_bufi = tidx_to_bufi(other_tidx, other_strides);
+  const int cond_bufi = tidx_to_bufi(out_tidx, cond_strides);
+  const int self_bufi = tidx_to_bufi(out_tidx, self_strides);
+  const int other_bufi = tidx_to_bufi(out_tidx, other_strides);

   COND_T cond = t_condition[cond_bufi] ;
   T v_self = t_self[self_bufi];
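
The rewritten buffer path computes a single `out_tidx` from the output's dim order and then maps it into each operand with that operand's own strides. The sketch below is a C++ stand-in (not the GLSL header) for how `tidx_to_bufi` enables that reuse; the struct and function names are illustrative.

```cpp
#include <array>

using ivec4 = std::array<int, 4>;

// Inner product of a tensor index with a strides vector, mirroring the role
// of tidx_to_bufi in indexing_utils.h.
int tidx_to_bufi(const ivec4& tidx, const ivec4& strides) {
  return tidx[0] * strides[0] + tidx[1] * strides[1] +
         tidx[2] * strides[2] + tidx[3] * strides[3];
}

// Usage mirroring the shader body: one out_tidx, three independent buffer
// indices, each computed with that tensor's own strides.
struct WhereIndices {
  int cond_bufi;
  int self_bufi;
  int other_bufi;
};

WhereIndices gather_indices(
    const ivec4& out_tidx,
    const ivec4& cond_strides,
    const ivec4& self_strides,
    const ivec4& other_strides) {
  return {
      tidx_to_bufi(out_tidx, cond_strides),
      tidx_to_bufi(out_tidx, self_strides),
      tidx_to_bufi(out_tidx, other_strides)};
}
```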

backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp

Lines changed: 3 additions & 3 deletions
@@ -143,9 +143,9 @@ void add_binary_op_buffer_node(
           PushConstantDataInfo(&alpha_val, sizeof(float)),
       }},
       // Specialization Constants
-      {graph.packed_dim_of(out),
-       graph.packed_dim_of(in1),
-       graph.packed_dim_of(in2)},
+      {graph.hashed_layout_of(out),
+       graph.hashed_layout_of(in1),
+       graph.hashed_layout_of(in2)},
       // Resize Args
       {},
       // Resizing Logic

backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQCSNW.cpp

Lines changed: 4 additions & 0 deletions
@@ -43,6 +43,10 @@ void check_linear_qcsnw_args(
     VK_CHECK_COND(
         utils::val_at(-1, scales_sizes) == utils::val_at(-2, qmat2_sizes));
   }
+
+  if (graph.is_buffer_storage(out)) {
+    VK_CHECK_COND(graph.is_contiguous(out));
+  }
 }

 void resize_linear_qcsnw_node(
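
This check pairs with the switch to `contiguous_bufi_to_tidx` in linear_qcsnw.glsl: that helper is only valid when the output's strides are the contiguous ones for its sizes. The sketch below illustrates the condition under the WHCN convention used above; the helper names and the stride formula are my own illustration, not the `ComputeGraph` implementation.

```cpp
#include <array>

using ivec4 = std::array<int, 4>;

// Strides of a contiguous WHCN buffer tensor: the running product of the
// sizes of the faster-moving dims, with dim 0 (W) fastest.
ivec4 contiguous_strides(const ivec4& sizes) {
  ivec4 strides{};
  int running = 1;
  for (int dim = 0; dim < 4; ++dim) {
    strides[dim] = running;
    running *= sizes[dim];
  }
  return strides;
}

// A tensor is "contiguous" in this sense exactly when its strides match the
// ones derived from its sizes, which is what lets the shader assume the
// implicit dim order (0, 1, 2, 3).
bool is_contiguous(const ivec4& sizes, const ivec4& strides) {
  return strides == contiguous_strides(sizes);
}
```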
