Skip to content

Commit d056881

Browse files
committed
[ET-VK] New implementation of cat operator
## Changes * Introduce `concat_texture.glsl` and `concat_buffer.glsl` to implement the `torch.cat` operator * Introduce `Concat.cpp` to replace `Cat.cpp` * Fix a bug with channels-packed buffer tensors where input data would be copied incorrectly when multiple dims have a stride of 1 ## Motivation > * Introduce `concat_texture.glsl` and `concat_buffer.glsl` to implement the `torch.cat` operator > * Introduce `Concat.cpp` to replace `Cat.cpp` The existing implementation of `torch.cat` uses the `copy_channel_offset` shaders. However, these shaders have a critical bug where the output tensor is passed in separately with different access types, i.e. ``` graph.execute_nodes().emplace_back(new DispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), global_size, local_size, // Inputs and Outputs { {out, vkapi::kWrite}, {out, vkapi::kRead}, {in, vkapi::kRead}, }, ``` This creates many validation layer errors because the memory barriers for the resource cannot be formed properly. The shader essentially relies on undefined behaviour to work correctly. The result is that the `cat` operator produces incorrect results on many platforms. Rather than fix the `copy_offset` shaders, I decided to just introduce new shaders to perform the concat operation. The new implementation handles both buffer and texture inputs and is agnostic to memory layout. Differential Revision: [D76305343](https://our.internmc.facebook.com/intern/diff/D76305343/) [ghstack-poisoned]
1 parent c2aa614 commit d056881

File tree

10 files changed

+362
-125
lines changed

10 files changed

+362
-125
lines changed
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#version 450 core

#define PRECISION ${PRECISION}

#define VEC4_T ${texel_type(DTYPE)}
#define T ${buffer_scalar_type(DTYPE)}

${define_active_storage_type("buffer")}
${define_required_extensions(DTYPE)}

layout(std430) buffer;

#include "indexing_utils.h"

${layout_declare_tensor(B, "w", "t_out", DTYPE, "buffer")}

$for i in range(NUM_INPUTS):
  ${layout_declare_tensor(B, "r", "t_in" + str(i + 1), DTYPE, "buffer")}

${layout_declare_ubo(B, "int", "concat_dim")}

${layout_declare_ubo(B, "ivec4", "out_sizes")}
${layout_declare_ubo(B, "ivec4", "out_strides")}

$for i in range(NUM_INPUTS):
  ${layout_declare_ubo(B, "ivec4", "in" + str(i+1) + "_sizes")}
  ${layout_declare_ubo(B, "ivec4", "in" + str(i+1) + "_strides")}

${layout_declare_ubo(B, "int", "out_numel")}

${layout_declare_spec_const(C, "int", "out_packed_dim", "DEFAULT_LAYOUT")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

/*
 * Concatenates NUM_INPUTS buffer-backed tensors along concat_dim. Each
 * invocation handles exactly one element of the output buffer: it converts
 * the linear output index to a 4-D tensor index, then walks the inputs in
 * order until it finds the one that owns that index along concat_dim.
 */
void main() {
  const int out_bufi = ivec3(gl_GlobalInvocationID).x;
  if (out_bufi >= out_numel) {
    return;
  }

  // Recover the 4-D tensor index corresponding to this output element.
  const ivec4 out_tidx = bufi_to_tidx(out_bufi, out_strides, out_packed_dim);

  // Working index; the concat_dim component is shifted down as each input
  // tensor is ruled out.
  ivec4 in_tidx = out_tidx;

  $for i in range(NUM_INPUTS):
    // If the (shifted) index falls inside this input along concat_dim, this
    // input owns the element: copy it over and finish.
    if (in_tidx[concat_dim] < in${i+1}_sizes[concat_dim]) {
      int in_bufi = tidx_to_bufi(in_tidx, in${i+1}_strides);
      t_out[out_bufi] = t_in${i+1}[in_bufi];
      return;
    }
    // Otherwise, subtract this input's extent and try the next input.
    else {
      in_tidx[concat_dim] -= in${i+1}_sizes[concat_dim];
    }
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
# Shader variant definitions for the buffer-storage concat shader.
# A dedicated variant is generated for each supported input count (1-3)
# and each supported dtype (half, float).
concat_buffer:
  parameter_names_with_default_values:
    DTYPE: float
    NUM_INPUTS: 2
  generate_variant_forall:
    DTYPE:
      - VALUE: half
      - VALUE: float
  shader_variants:
    - NAME: concat_1_buffer
      NUM_INPUTS: 1
    - NAME: concat_2_buffer
    - NAME: concat_3_buffer
      NUM_INPUTS: 3
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#version 450 core

#define PRECISION ${PRECISION}

#define VEC4_T ${texel_type(DTYPE)}
#define T ${buffer_scalar_type(DTYPE)}

#define USING_TEXTURE3D

layout(std430) buffer;

#include "indexing_utils.h"

${layout_declare_tensor(B, "w", "t_out", DTYPE, "texture3d")}

$for i in range(NUM_INPUTS):
  ${layout_declare_tensor(B, "r", "t_in" + str(i + 1), DTYPE, "texture3d")}

${layout_declare_ubo(B, "int", "concat_dim")}

$in_metadata = ""
$for i in range(NUM_INPUTS):
  $in_metadata += "ivec4 in" + str(i + 1) + "_sizes;\n"

layout(push_constant) uniform restrict Block {
  ivec4 out_sizes;
  ${in_metadata}
};

${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")}
const lowp ivec4 out_axis_map = unhash_axis_map(out_layout);
const lowp int out_packed_dim = unhash_packed_dim(out_layout);

$for i in range(NUM_INPUTS):
  ${layout_declare_spec_const(C, "int", "in" + str(i+1) + "_layout", "DEFAULT_LAYOUT")}
  const lowp ivec4 in${i+1}_axis_map = unhash_axis_map(in${i+1}_layout);
  const lowp int in${i+1}_packed_dim = unhash_packed_dim(in${i+1}_layout);

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

// Returns true when whole texels can be copied from input to output without
// re-assembling them element by element. This holds when either:
// 1. concat_dim is not the output's packed dimension, or
// 2. concat_dim is the packed dimension, but every input's extent along it is
//    a multiple of 4, so input texel boundaries line up with output texels.
bool can_use_fast_path() {
  if (concat_dim != out_packed_dim) {
    return true;
  }

  // Verify every input's size along concat_dim is divisible by 4.
  bool all_concat_dim_size_multiple_of_4 = true;
  $for i in range(NUM_INPUTS):
    all_concat_dim_size_multiple_of_4 =
        all_concat_dim_size_multiple_of_4 &&
        (in${i+1}_sizes[concat_dim] % 4 == 0);

  return all_concat_dim_size_multiple_of_4;
}

/*
 * Concatenates NUM_INPUTS texture-backed tensors along concat_dim. Each
 * invocation produces one output texel. The fast path copies whole texels;
 * the slow path gathers the 4 elements of the output texel one at a time,
 * since an output texel may straddle an input-tensor boundary.
 */
void main() {
  const ivec3 lpos = ivec3(gl_GlobalInvocationID);
  ivec4 out_tidx = lpos_to_tidx(lpos, out_sizes, out_axis_map.w, out_packed_dim);

  if (any(greaterThanEqual(out_tidx, out_sizes))) {
    return;
  }

  if (can_use_fast_path()) {
    // Fast path: the source texel maps 1:1 onto the output texel.
    ivec4 in_tidx = out_tidx;

    $for i in range(NUM_INPUTS):
      // If the (shifted) index falls inside this input along concat_dim,
      // copy its texel to the output and finish.
      if (in_tidx[concat_dim] < in${i+1}_sizes[concat_dim]) {
        const ivec3 in_pos = tidx_to_pos(in_tidx, in${i+1}_sizes, in${i+1}_axis_map, in${i+1}_packed_dim);
        const VEC4_T in_texel = load_texel(t_in${i+1}, in_pos);
        write_texel_lpos(t_out, lpos, in_texel, out_axis_map);
        return;
      }
      // Otherwise, subtract this input's extent and try the next input.
      else {
        in_tidx[concat_dim] -= in${i+1}_sizes[concat_dim];
      }
  }
  else {
    // Slow path: assemble the output texel element by element, since its
    // elements may come from different inputs or be misaligned within them.
    VEC4_T out_texel = VEC4_T(0);

    for (int texel_i = 0; texel_i < 4; ++texel_i) {
      ivec4 curr_out_tidx = out_tidx;
      curr_out_tidx[out_packed_dim] += texel_i;

      // Elements past the end of the tensor stay zero-padded.
      if (curr_out_tidx[out_packed_dim] >= out_sizes[out_packed_dim]) {
        continue;
      }

      ivec4 in_tidx = curr_out_tidx;
      $for i in range(NUM_INPUTS):
        // If this element belongs to input ${i+1}, pull the matching scalar
        // out of the input texel (in_posi.w selects the component).
        if (in_tidx[concat_dim] < in${i+1}_sizes[concat_dim]) {
          const ivec4 in_posi = tidx_to_posi(in_tidx, in${i+1}_sizes, in${i+1}_axis_map, in${i+1}_packed_dim);
          out_texel[texel_i] = load_texel(t_in${i+1}, in_posi.xyz)[in_posi.w];
          continue;
        }
        // Otherwise, subtract this input's extent and try the next input.
        else {
          in_tidx[concat_dim] -= in${i+1}_sizes[concat_dim];
        }
    }

    write_texel_lpos(t_out, lpos, out_texel, out_axis_map);
  }
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
# Shader variant definitions for the texture-storage concat shader.
# A dedicated variant is generated for each supported input count (1-3)
# and each supported dtype (half, float).
concat_texture:
  parameter_names_with_default_values:
    DTYPE: float
    NUM_INPUTS: 2
  generate_variant_forall:
    DTYPE:
      - VALUE: half
      - VALUE: float
  shader_variants:
    - NAME: concat_1_texture3d
      NUM_INPUTS: 1
    - NAME: concat_2_texture3d
    - NAME: concat_3_texture3d
      NUM_INPUTS: 3

backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -68,21 +68,6 @@
6868
*/
6969
#define mod4(x) ((x) & 3)
7070

71-
/*
72-
* Find the packed dimension of a tensor given its strides. The packed dimension
73-
* is the "fastest moving" dimension which will have a stride of 1.
74-
*/
75-
int find_packed_dim(const ivec4 strides) {
76-
int packed_dim = 0;
77-
for (int i = 0; i <= 3; i++) {
78-
if (strides[i] == 1) {
79-
packed_dim = i;
80-
break;
81-
}
82-
}
83-
return packed_dim;
84-
}
85-
8671
/*
8772
* Get the staging buffer indices that contain the data of the texel that
8873
* corresponds to the provided tensor index. Since the texel have 4 elements,
@@ -144,14 +129,6 @@ ivec4 bufi_to_tidx(int bufi, const ivec4 strides, const int packed_dim) {
144129
return idx;
145130
}
146131

147-
// Convenience overload of the above function, which will determine the packed
148-
// dim from the strides automatically so it doesn't have to be passed in as a
149-
// function argument.
150-
ivec4 bufi_to_tidx(const int bufi, const ivec4 strides) {
151-
int packed_dim = find_packed_dim(strides);
152-
return bufi_to_tidx(bufi, strides, packed_dim);
153-
}
154-
155132
int tidx_to_bufi(const ivec4 tidx, ivec4 strides) {
156133
return tidx.x * strides.x + tidx.y * strides.y + tidx.z * strides.z +
157134
tidx.w * strides.w;

backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.glsl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
2828

2929
// This constant is unused in this shader but is kept so that the signature is
3030
// consistent with nchw_to_image.
31-
${layout_declare_spec_const(C, "int", "UNUSED_layout", "0")}
31+
${layout_declare_spec_const(C, "int", "packed_dim", "0")}
3232
${layout_declare_spec_const(C, "int", "transpose_hw", "0")}
3333

3434
void main() {
@@ -37,7 +37,7 @@ void main() {
3737
return;
3838
}
3939

40-
ivec4 out_tidx = bufi_to_tidx(out_bufi, out_strides);
40+
ivec4 out_tidx = bufi_to_tidx(out_bufi, out_strides, packed_dim);
4141

4242
ivec4 sizes = out_sizes;
4343
if (transpose_hw == 1) {

backends/vulkan/runtime/graph/ops/impl/Cat.cpp

Lines changed: 0 additions & 98 deletions
This file was deleted.

0 commit comments

Comments
 (0)