layout(std430) buffer;

#include "indexing_utils.h"

// Output is read-write ("rw") rather than write-only: when the concat is
// performed over multiple dispatches, previously written regions of t_out
// must be preserved.
${layout_declare_tensor(B, "rw", "t_out", DTYPE, "buffer")}

// Input tensors are numbered from 0 (t_inp0, t_inp1, ...).
$for i in range(NUM_INPUTS):
  ${layout_declare_tensor(B, "r", "t_inp" + str(i), DTYPE, "buffer")}

// Scalar buffer holding the offset along concat_dim at which this dispatch
// starts writing into the output tensor.
${layout_declare_tensor(B, "r", "t_concat_offset", "int", "buffer")}

${layout_declare_ubo(B, "int", "concat_dim")}

${layout_declare_ubo(B, "ivec4", "out_sizes")}
${layout_declare_ubo(B, "ivec4", "out_strides")}

// Per-input sizes and strides, numbered to match the t_inp${"${i}"} bindings.
$for i in range(NUM_INPUTS):
  ${layout_declare_ubo(B, "ivec4", "inp" + str(i) + "_sizes")}
  ${layout_declare_ubo(B, "ivec4", "inp" + str(i) + "_strides")}

${layout_declare_ubo(B, "int", "out_numel")}
const lowp ivec4 out_dim_order = unhash_dim_order(out_layout);

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

// Expose the template parameter to the included helper header, which sizes
// its loops/tables based on the number of inputs.
#define NUM_INPUTS ${NUM_INPUTS}

#include "concat_utils.glslh"

/*
 * This shader template concatenates up to NUM_INPUTS input tensors to the
 * output tensor along the concat_dim. Elements from the input tensors will
 * be inserted along the output's concat_dim starting at concat_offset.
 */
void main() {
  // One thread per element of the combined input volume.
  const int tid = ivec3(gl_GlobalInvocationID).x;

  // The 1-3 input tensors are interpreted as one concatenated tensor
  // ("volume") along the concat_dim for the purposes of tensor indexing.
  // Each thread reads one item from this volume and writes it to the
  // corresponding output location.
  ivec4 inp_volume_sizes = out_sizes;
  inp_volume_sizes[concat_dim] = total_concat_dim_numel();

  // Account for 0 size input tensors: an empty volume means there is
  // nothing to copy.
  if (any(lessThanEqual(inp_volume_sizes, ivec4(0)))) {
    return;
  }

  // Convert the flat thread id to a 4-D index within the input volume.
  ivec4 inp_volume_tidx = nchwi_to_tidx(tid, inp_volume_sizes);

  // Bounds check: dispatches may be padded past the volume's extent.
  if (any(greaterThanEqual(inp_volume_tidx, inp_volume_sizes))) {
    return;
  }

  int concat_offset = t_concat_offset[0];

  // The output index is the volume index shifted by the running offset
  // along the concatenation dimension.
  ivec4 out_tidx = inp_volume_tidx;
  out_tidx[concat_dim] += concat_offset;

  const uint out_bufi = tidx_to_bufi(out_tidx, out_strides);

  // Go through the list of input tensors, and find which input this output
  // element should be read from. If the index along concat_dim falls past
  // the current input, subtract that input's extent and try the next one.
  $for i in range(NUM_INPUTS):
    if (inp_volume_tidx[concat_dim] < inp${i}_sizes[concat_dim]) {
      int inp_bufi = tidx_to_bufi(inp_volume_tidx, inp${i}_strides);
      t_out[out_bufi] = t_inp${i}[inp_bufi];
      return;
    }
    else {
      inp_volume_tidx[concat_dim] -= inp${i}_sizes[concat_dim];
    }
}