Skip to content

Commit fabce25

Browse files
authored
[ET-VK] Allow aten.cat.default to handle any number of input tensors (#13252)
## Context Previously, I updated the implementation of `aten.cat.default` in D76305343 (#11508) since the original implementation had a bug. The new implementation only supported up to 3 input tensors, but several models require up to 6 input tensors. This diff updates the capabilities of the `concat` op so that an arbitrary number of input tensors may be accepted. ## Changes * Update implementation of the concat shader so that it can be called repeatedly, allowing support for any number of input tensors. Differential Revision: [D79893084](https://our.internmc.facebook.com/intern/diff/D79893084/)
1 parent 6fd97ab commit fabce25

File tree

14 files changed

+790
-235
lines changed

14 files changed

+790
-235
lines changed

backends/vulkan/op_registry.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -491,17 +491,9 @@ def register_view_ops():
491491
# for both texture and buffer storage types.
492492
@update_features(exir_ops.edge.aten.cat.default)
493493
def register_cat_op():
494-
def check_cat_node(node: torch.fx.Node) -> bool:
495-
inputs = node.args[0]
496-
if isinstance(inputs, (list, tuple)) and len(inputs) <= 3:
497-
return True
498-
499-
return False
500-
501494
return OpFeatures(
502495
inputs_storage=utils.ANY_STORAGE,
503496
supports_resize=True,
504-
are_node_inputs_supported_fn=check_cat_node,
505497
)
506498

507499

backends/vulkan/runtime/api/containers/Tensor.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -517,6 +517,7 @@ void vTensorStorage::transition(
517517
vkapi::MemoryAccessFlags prev_access = last_access_.access;
518518

519519
const bool prev_written = (prev_access & vkapi::MemoryAccessType::WRITE) != 0;
520+
const bool cur_written = (cur_access & vkapi::MemoryAccessType::WRITE) != 0;
520521

521522
VkImageLayout cur_layout = VK_IMAGE_LAYOUT_UNDEFINED;
522523
VkImageLayout new_layout = VK_IMAGE_LAYOUT_UNDEFINED;
@@ -528,7 +529,13 @@ void vTensorStorage::transition(
528529
layout_changed = cur_layout != new_layout;
529530
}
530531

531-
if (prev_written || layout_changed) {
532+
// RAW: need to make sure current read sees previous writes
533+
// WAW: need to make sure the current write occurs after previous write so
534+
// the final value is correct.
535+
// WAR: need to make sure previous read does not read the value from the
536+
// current write.
537+
// RAR: no need for synchronization
538+
if (prev_written || cur_written || layout_changed) {
532539
VkPipelineStageFlags src_stage = vkapi::vk_stage(prev_stage);
533540
if (0u == src_stage) {
534541
src_stage = VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT;

backends/vulkan/runtime/graph/ops/ExecuteNode.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ class ExecuteNode {
4343
friend class ComputeGraph;
4444

4545
public:
46-
using ResizeFunction = const std::function<void(
46+
using ResizeFunction = std::function<void(
4747
ComputeGraph*,
4848
const std::vector<ArgGroup>&,
4949
const std::vector<ValueRef>&)>;

backends/vulkan/runtime/graph/ops/glsl/concat_buffer.glsl

Lines changed: 44 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,21 @@ layout(std430) buffer;
2020

2121
#include "indexing_utils.h"
2222

23-
${layout_declare_tensor(B, "w", "t_out", DTYPE, "buffer")}
23+
${layout_declare_tensor(B, "rw", "t_out", DTYPE, "buffer")}
2424

2525
$for i in range(NUM_INPUTS):
26-
${layout_declare_tensor(B, "r", "t_in" + str(i + 1), DTYPE, "buffer")}
26+
${layout_declare_tensor(B, "r", "t_inp" + str(i), DTYPE, "buffer")}
27+
28+
${layout_declare_tensor(B, "r", "t_concat_offset", "int", "buffer")}
2729

2830
${layout_declare_ubo(B, "int", "concat_dim")}
2931

3032
${layout_declare_ubo(B, "ivec4", "out_sizes")}
3133
${layout_declare_ubo(B, "ivec4", "out_strides")}
3234

3335
$for i in range(NUM_INPUTS):
34-
${layout_declare_ubo(B, "ivec4", "in" + str(i+1) + "_sizes")}
35-
${layout_declare_ubo(B, "ivec4", "in" + str(i+1) + "_strides")}
36+
${layout_declare_ubo(B, "ivec4", "inp" + str(i) + "_sizes")}
37+
${layout_declare_ubo(B, "ivec4", "inp" + str(i) + "_strides")}
3638

3739
${layout_declare_ubo(B, "int", "out_numel")}
3840

@@ -42,28 +44,53 @@ const lowp ivec4 out_dim_order = unhash_dim_order(out_layout);
4244

4345
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
4446

47+
#define NUM_INPUTS ${NUM_INPUTS}
48+
49+
#include "concat_utils.glslh"
50+
51+
/*
52+
* This shader template concatenates up to NUM_INPUTS input tensors to the
53+
* output tensor along the concat_dim. Elements from the input tensor will
54+
* be inserted along the output's concat_dim starting at concat_offset.
55+
*/
4556
void main() {
46-
const int out_bufi = ivec3(gl_GlobalInvocationID).x;
47-
if (out_bufi >= out_numel) {
57+
const int tid = ivec3(gl_GlobalInvocationID).x;
58+
59+
// The 1-3 input tensors are interpreted as one concatenated tensor ("volume")
60+
// along the concat_dim for the purposes of tensor indexing. Each thread is
61+
// responsible for reading one item from this volume and writing it to the
62+
// appropriate output location.
63+
ivec4 inp_volume_sizes = out_sizes;
64+
inp_volume_sizes[concat_dim] = total_concat_dim_numel();
65+
66+
// Account for 0 size input tensors
67+
if (any(lessThanEqual(inp_volume_sizes, ivec4(0)))) {
68+
return;
69+
}
70+
71+
ivec4 inp_volume_tidx = nchwi_to_tidx(tid, inp_volume_sizes);
72+
73+
// bounds check
74+
if (any(greaterThanEqual(inp_volume_tidx, inp_volume_sizes))) {
4875
return;
4976
}
5077

51-
// Convert buffer linear index to 4-D tensor index for output
52-
const ivec4 out_tidx = bufi_to_tidx(out_bufi, out_strides, out_dim_order);
78+
int concat_offset = t_concat_offset[0];
79+
80+
ivec4 out_tidx = inp_volume_tidx;
81+
out_tidx[concat_dim] += concat_offset;
5382

54-
// Determine which input tensor to read from
55-
ivec4 in_tidx = out_tidx;
83+
const uint out_bufi = tidx_to_bufi(out_tidx, out_strides);
5684

85+
// Go through the list of input tensors, and find which input this output
86+
// element should be read from.
5787
$for i in range(NUM_INPUTS):
58-
// Check if the index at the concat dim is within bounds of the input tensor
59-
// If so, read from that input tensor and write to output
60-
if (in_tidx[concat_dim] < in${i+1}_sizes[concat_dim]) {
61-
int in_bufi = tidx_to_bufi(in_tidx, in${i+1}_strides);
62-
t_out[out_bufi] = t_in${i+1}[in_bufi];
88+
if (inp_volume_tidx[concat_dim] < inp${i}_sizes[concat_dim]) {
89+
int inp_bufi = tidx_to_bufi(inp_volume_tidx, inp${i}_strides);
90+
t_out[out_bufi] = t_inp${i}[inp_bufi];
6391
return;
6492
}
65-
// otherwise, decrement the index at the concat dim
6693
else {
67-
in_tidx[concat_dim] -= in${i+1}_sizes[concat_dim];
94+
inp_volume_tidx[concat_dim] -= inp${i}_sizes[concat_dim];
6895
}
6996
}

backends/vulkan/runtime/graph/ops/glsl/concat_texture.glsl

Lines changed: 120 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,18 @@ layout(std430) buffer;
1919

2020
#include "indexing_utils.h"
2121

22-
${layout_declare_tensor(B, "w", "t_out", DTYPE, "texture3d")}
22+
${layout_declare_tensor(B, "rw", "t_out", DTYPE, "texture3d")}
2323

2424
$for i in range(NUM_INPUTS):
25-
${layout_declare_tensor(B, "r", "t_in" + str(i + 1), DTYPE, "texture3d")}
25+
${layout_declare_tensor(B, "r", "t_inp" + str(i), DTYPE, "texture3d")}
26+
27+
${layout_declare_tensor(B, "r", "t_concat_offset", "int", "buffer")}
2628

2729
${layout_declare_ubo(B, "int", "concat_dim")}
2830

2931
$in_metadata = ""
3032
$for i in range(NUM_INPUTS):
31-
$in_metadata += "ivec4 in" + str(i + 1) + "_sizes;\n"
33+
$in_metadata += "ivec4 inp" + str(i) + "_sizes;\n"
3234

3335
layout(push_constant) uniform restrict Block {
3436
ivec4 out_sizes;
@@ -40,90 +42,135 @@ const lowp ivec4 out_axis_map = unhash_axis_map(out_layout);
4042
const lowp int out_packed_dim = unhash_packed_dim(out_layout);
4143

4244
$for i in range(NUM_INPUTS):
43-
${layout_declare_spec_const(C, "int", "in" + str(i+1) + "_layout", "DEFAULT_LAYOUT")}
44-
const lowp ivec4 in${i+1}_axis_map = unhash_axis_map(in${i+1}_layout);
45-
const lowp int in${i+1}_packed_dim = unhash_packed_dim(in${i+1}_layout);
45+
${layout_declare_spec_const(C, "int", "inp" + str(i) + "_layout", "DEFAULT_LAYOUT")}
46+
const lowp ivec4 inp${i}_axis_map = unhash_axis_map(inp${i}_layout);
47+
const lowp int inp${i}_packed_dim = unhash_packed_dim(inp${i}_layout);
4648

4749
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
4850

49-
// Check if we can use the fast path (no texel merging required)
50-
bool can_use_fast_path() {
51-
// Fast path is possible when:
52-
// 1. The concat dimension is not the packed dimension, or
53-
// 2. The concat dimension is the packed dimension but both input tensors have dimensions
54-
// that are multiples of 4 along the packed dimension
55-
if (concat_dim != out_packed_dim) {
56-
return true;
57-
}
58-
59-
// Check if all input tensors have dimensions that are multiples of 4 along the packed dimension
60-
bool all_concat_dim_size_multiple_of_4 = true;
61-
$for i in range(NUM_INPUTS):
62-
all_concat_dim_size_multiple_of_4 =
63-
all_concat_dim_size_multiple_of_4 &&
64-
(in${i+1}_sizes[concat_dim] % 4 == 0);
51+
#define NUM_INPUTS ${NUM_INPUTS}
6552

66-
return all_concat_dim_size_multiple_of_4;
67-
}
53+
#include "concat_utils.glslh"
6854

55+
/*
56+
* This shader template concatenates up to NUM_INPUTS input tensors to the
57+
* output tensor along the concat_dim. Elements from the input tensor will
58+
* be inserted along the output's concat_dim starting at concat_offset.
59+
*
60+
* Each thread is responsible for writing out one output texel. The data
61+
* required for the output texel may be read from multiple input texels of one
62+
* input tensor.
63+
*/
6964
void main() {
70-
const ivec3 lpos = ivec3(gl_GlobalInvocationID);
71-
ivec4 out_tidx = lpos_to_tidx(lpos, out_sizes, out_axis_map.w, out_packed_dim);
72-
73-
if (any(greaterThanEqual(out_tidx, out_sizes))) {
65+
const int tid = ivec3(gl_GlobalInvocationID).x;
66+
67+
// Sum of the sizes of all input tensors along the concat_dim
68+
const int concat_numel = total_concat_dim_numel();
69+
70+
// The 1-3 input tensors are interpreted as one concatenated tensor ("volume")
71+
// along the concat_dim for the purposes of tensor indexing. Each thread is
72+
// responsible for writing out 4 elements along the packed dim of the output
73+
// tensor by reading the source data from the input tensor(s).
74+
ivec4 inp_volume_sizes = out_sizes;
75+
inp_volume_sizes[concat_dim] = total_concat_dim_numel();
76+
77+
// Reconstruct inp_volume_texel_sizes from Concat.cpp
78+
ivec4 inp_volume_texel_sizes = inp_volume_sizes;
79+
inp_volume_texel_sizes[out_packed_dim] = DIV_UP_4(
80+
inp_volume_texel_sizes[out_packed_dim]
81+
) + 1;
82+
83+
// tensor index of the first element that will be read from the input volume
84+
ivec4 inp_volume_start_tidx = nchwi_to_tidx(tid, inp_volume_texel_sizes);
85+
inp_volume_start_tidx[out_packed_dim] = MUL_4(
86+
inp_volume_start_tidx[out_packed_dim]
87+
);
88+
89+
int concat_offset = t_concat_offset[0];
90+
91+
// tensor index of the first element that will be written to the output tensor
92+
ivec4 out_write_start_tidx = inp_volume_start_tidx;
93+
out_write_start_tidx[concat_dim] += concat_offset;
94+
95+
// To write to the the desired output element, we will need to load the texel
96+
// to which the element belongs. Calculate the tensor index of the first
97+
// element of that texel.
98+
ivec4 out_read_start_tidx = out_write_start_tidx;
99+
out_read_start_tidx[out_packed_dim] = ALIGN_DOWN_4(
100+
out_write_start_tidx[out_packed_dim]);
101+
102+
// bounds check
103+
if (any(greaterThanEqual(out_read_start_tidx, out_sizes))) {
74104
return;
75105
}
76106

77-
if (can_use_fast_path()) {
78-
// Fast path: No texel merging required
79-
ivec4 in_tidx = out_tidx;
107+
ivec3 out_pos = tidx_to_pos(
108+
out_read_start_tidx,
109+
out_sizes,
110+
out_axis_map,
111+
out_packed_dim
112+
);
80113

81-
$for i in range(NUM_INPUTS):
82-
// For each input tensor, check if the tensor index is within bounds. If
83-
// so, read the texel from the input tensor and write it to the output
84-
if (in_tidx[concat_dim] < in${i+1}_sizes[concat_dim]) {
85-
const ivec3 in_pos = tidx_to_pos(in_tidx, in${i+1}_sizes, in${i+1}_axis_map, in${i+1}_packed_dim);
86-
const VEC4_T in_texel = load_texel(t_in${i+1}, in_pos);
87-
write_texel_lpos(t_out, lpos, in_texel, out_axis_map);
88-
return;
89-
}
90-
// Otherwise, adjust the index along the concat dimension and try the next
91-
// input tensor.
92-
else {
93-
in_tidx[concat_dim] -= in${i+1}_sizes[concat_dim];
94-
}
95-
}
96-
else {
97-
// Slow path: Texel merging required
98-
VEC4_T out_texel = VEC4_T(0);
114+
VEC4_T out_texel = imageLoad(t_out, out_pos);
99115

100-
// Process each element in the output texel individually
101-
for (int texel_i = 0; texel_i < 4; ++texel_i) {
102-
ivec4 curr_out_tidx = out_tidx;
103-
curr_out_tidx[out_packed_dim] += texel_i;
116+
VEC4_T test_texel = VEC4_T(-1.0);
104117

105-
// Skip if we're out of bounds
106-
if (curr_out_tidx[out_packed_dim] >= out_sizes[out_packed_dim]) {
107-
continue;
108-
}
118+
for (int comp = 0; comp < 4; ++comp) {
119+
ivec4 out_tidx = out_read_start_tidx;
120+
out_tidx[out_packed_dim] += comp;
109121

110-
ivec4 in_tidx = curr_out_tidx;
111-
$for i in range(NUM_INPUTS):
112-
// For each input tensor, check if the tensor index is within bounds. If
113-
// so, read the corresponding texel element from the input tensor and
114-
// write it to the output texel.
115-
if (in_tidx[concat_dim] < in${i+1}_sizes[concat_dim]) {
116-
const ivec4 in_posi = tidx_to_posi(in_tidx, in${i+1}_sizes, in${i+1}_axis_map, in${i+1}_packed_dim);
117-
out_texel[texel_i] = load_texel(t_in${i+1}, in_posi.xyz)[in_posi.w];
118-
continue;
119-
}
120-
// Otherwise, adjust the index along the concat dimension and try the
121-
// next input tensor.
122-
else {
123-
in_tidx[concat_dim] -= in${i+1}_sizes[concat_dim];
124-
}
122+
123+
// It's possible that the current texel element has been written to as part
124+
// of the previous input batch; if so, then don't overwrite this texel
125+
// element
126+
if (out_tidx[concat_dim] < concat_offset) {
127+
test_texel[comp] = -5.0;
128+
continue;
125129
}
126130

127-
write_texel_lpos(t_out, lpos, out_texel, out_axis_map);
131+
// Calculate the tidx of the input volume that corresponds to this output
132+
// element
133+
ivec4 inp_volume_tidx = out_tidx;
134+
inp_volume_tidx[concat_dim] -= concat_offset;
135+
136+
// go through the list of input tensors, and figure out which input this
137+
// output element should be read from.
138+
$for i in range(NUM_INPUTS):
139+
if (inp_volume_tidx[concat_dim] < inp${i}_sizes[concat_dim]) {
140+
// Special fast path case if, for the first output texel element, the
141+
// corresponding input element is at the start of the texel it belongs
142+
// to. In this case, the input texel can be written as-is to the output
143+
// texel. Also require that the entire input texel is valid and does not
144+
// contain any padding elements.
145+
if (comp == 0 &&
146+
out_tidx[out_packed_dim] % 4 == 0 &&
147+
inp_volume_tidx[inp${i}_packed_dim] % 4 == 0 &&
148+
inp_volume_tidx[inp${i}_packed_dim] + 3 < inp${i}_sizes[inp${i}_packed_dim]) {
149+
const ivec3 in_pos = tidx_to_pos(
150+
inp_volume_tidx,
151+
inp${i}_sizes,
152+
inp${i}_axis_map,
153+
inp${i}_packed_dim);
154+
155+
out_texel = texelFetch(t_inp${i}, in_pos, 0);
156+
break;
157+
}
158+
159+
// Otherwise, locate the specific input element required
160+
const ivec4 in_posi = tidx_to_posi(
161+
inp_volume_tidx,
162+
inp${i}_sizes,
163+
inp${i}_axis_map,
164+
inp${i}_packed_dim);
165+
166+
out_texel[comp] = texelFetch(t_inp${i}, in_posi.xyz, 0)[in_posi.w];
167+
test_texel[comp] = out_texel[comp];
168+
continue;
169+
}
170+
else {
171+
inp_volume_tidx[concat_dim] -= inp${i}_sizes[concat_dim];
172+
}
128173
}
174+
175+
imageStore(t_out, out_pos, out_texel);
129176
}

0 commit comments

Comments
 (0)