@@ -19,16 +19,18 @@ layout(std430) buffer;
#include "indexing_utils.h"
- ${layout_declare_tensor(B, "w", "t_out", DTYPE, "texture3d")}
+ ${layout_declare_tensor(B, "rw", "t_out", DTYPE, "texture3d")}
$for i in range(NUM_INPUTS):
- ${layout_declare_tensor(B, "r", "t_in" + str(i + 1), DTYPE, "texture3d")}
+ ${layout_declare_tensor(B, "r", "t_inp" + str(i), DTYPE, "texture3d")}
+
+ ${layout_declare_tensor(B, "r", "t_concat_offset", "int", "buffer")}
${layout_declare_ubo(B, "int", "concat_dim")}
$in_metadata = ""
$for i in range(NUM_INPUTS):
- $in_metadata += "ivec4 in" + str(i + 1) + "_sizes;\n"
+ $in_metadata += "ivec4 inp" + str(i) + "_sizes;\n"
layout(push_constant) uniform restrict Block {
ivec4 out_sizes;
@@ -40,90 +42,135 @@ const lowp ivec4 out_axis_map = unhash_axis_map(out_layout);
const lowp int out_packed_dim = unhash_packed_dim(out_layout);
$for i in range(NUM_INPUTS):
- ${layout_declare_spec_const(C, "int", "in" + str(i + 1) + "_layout", "DEFAULT_LAYOUT")}
- const lowp ivec4 in${i+1}_axis_map = unhash_axis_map(in${i+1}_layout);
- const lowp int in${i+1}_packed_dim = unhash_packed_dim(in${i+1}_layout);
+ ${layout_declare_spec_const(C, "int", "inp" + str(i) + "_layout", "DEFAULT_LAYOUT")}
+ const lowp ivec4 inp${i}_axis_map = unhash_axis_map(inp${i}_layout);
+ const lowp int inp${i}_packed_dim = unhash_packed_dim(inp${i}_layout);
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
- // Check if we can use the fast path (no texel merging required)
- bool can_use_fast_path() {
- // Fast path is possible when:
- // 1. The concat dimension is not the packed dimension, or
- // 2. The concat dimension is the packed dimension but both input tensors have dimensions
- // that are multiples of 4 along the packed dimension
- if (concat_dim != out_packed_dim) {
- return true;
- }
-
- // Check if all input tensors have dimensions that are multiples of 4 along the packed dimension
- bool all_concat_dim_size_multiple_of_4 = true;
- $for i in range(NUM_INPUTS):
- all_concat_dim_size_multiple_of_4 =
- all_concat_dim_size_multiple_of_4 &&
- (in${i+1}_sizes[concat_dim] % 4 == 0);
+ #define NUM_INPUTS ${NUM_INPUTS}
- return all_concat_dim_size_multiple_of_4;
- }
+ #include "concat_utils.glslh"
+ /*
+ * This shader template concatenates up to NUM_INPUTS input tensors to the
+ * output tensor along the concat_dim. Elements from the input tensors will
+ * be inserted along the output's concat_dim starting at concat_offset.
+ *
+ * Each thread is responsible for writing out one output texel. The data
+ * required for the output texel may be read from multiple input texels of one
+ * input tensor.
+ */
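+ // Illustrative example (hypothetical values): if concat_dim is also the
+ // packed dim and concat_offset == 2, the output texel covering output
+ // elements 4..7 along the packed dim sources input-volume elements 2..5,
+ // which straddle two input texels; the per-component loop below handles
+ // this case.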
void main() {
- const ivec3 lpos = ivec3(gl_GlobalInvocationID);
- ivec4 out_tidx = lpos_to_tidx(lpos, out_sizes, out_axis_map.w, out_packed_dim);
-
- if (any(greaterThanEqual(out_tidx, out_sizes))) {
+ const int tid = ivec3(gl_GlobalInvocationID).x;
+
+ // Sum of the sizes of all input tensors along the concat_dim
+ const int concat_numel = total_concat_dim_numel();
+
+ // The 1-3 input tensors are interpreted as one concatenated tensor ("volume")
+ // along the concat_dim for the purposes of tensor indexing. Each thread is
+ // responsible for writing out 4 elements along the packed dim of the output
+ // tensor by reading the source data from the input tensor(s).
+ ivec4 inp_volume_sizes = out_sizes;
+ inp_volume_sizes[concat_dim] = total_concat_dim_numel();
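+ // (Illustrative: with two inputs of extent 3 and 5 along concat_dim, the
+ // volume has extent 8 there; all other dims match out_sizes.)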
+
+ // Reconstruct inp_volume_texel_sizes from Concat.cpp
+ ivec4 inp_volume_texel_sizes = inp_volume_sizes;
+ inp_volume_texel_sizes[out_packed_dim] = DIV_UP_4(
+ inp_volume_texel_sizes[out_packed_dim]
+ ) + 1;
+
+ // tensor index of the first element that will be read from the input volume
+ ivec4 inp_volume_start_tidx = nchwi_to_tidx(tid, inp_volume_texel_sizes);
+ inp_volume_start_tidx[out_packed_dim] = MUL_4(
+ inp_volume_start_tidx[out_packed_dim]
+ );
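+ // (tid enumerates texels, so the packed-dim coordinate is scaled from texel
+ // units back to element units here.)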
+
+ int concat_offset = t_concat_offset[0];
+
+ // tensor index of the first element that will be written to the output tensor
+ ivec4 out_write_start_tidx = inp_volume_start_tidx;
+ out_write_start_tidx[concat_dim] += concat_offset;
+
+ // To write to the desired output element, we will need to load the texel
+ // to which the element belongs. Calculate the tensor index of the first
+ // element of that texel.
+ ivec4 out_read_start_tidx = out_write_start_tidx;
+ out_read_start_tidx[out_packed_dim] = ALIGN_DOWN_4(
+ out_write_start_tidx[out_packed_dim]);
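+ // (Illustrative: if the write starts at packed-dim element 6, the aligned
+ // read index is 4, and the output texel loaded below covers elements 4..7;
+ // this assumes ALIGN_DOWN_4(x) rounds x down to a multiple of 4, as its
+ // name suggests.)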
+
+ // bounds check
+ if (any(greaterThanEqual(out_read_start_tidx, out_sizes))) {
return;
}
- if (can_use_fast_path()) {
- // Fast path: No texel merging required
- ivec4 in_tidx = out_tidx;
+ ivec3 out_pos = tidx_to_pos(
+ out_read_start_tidx,
+ out_sizes,
+ out_axis_map,
+ out_packed_dim
+ );
- $for i in range(NUM_INPUTS):
- // For each input tensor, check if the tensor index is within bounds. If
- // so, read the texel from the input tensor and write it to the output
- if (in_tidx[concat_dim] < in${i+1}_sizes[concat_dim]) {
- const ivec3 in_pos = tidx_to_pos(in_tidx, in${i+1}_sizes, in${i+1}_axis_map, in${i+1}_packed_dim);
- const VEC4_T in_texel = load_texel(t_in${i+1}, in_pos);
- write_texel_lpos(t_out, lpos, in_texel, out_axis_map);
- return;
- }
- // Otherwise, adjust the index along the concat dimension and try the next
- // input tensor.
- else {
- in_tidx[concat_dim] -= in${i+1}_sizes[concat_dim];
- }
- }
- else {
- // Slow path: Texel merging required
- VEC4_T out_texel = VEC4_T(0);
+ VEC4_T out_texel = imageLoad(t_out, out_pos);
- // Process each element in the output texel individually
- for (int texel_i = 0; texel_i < 4; ++texel_i) {
- ivec4 curr_out_tidx = out_tidx;
- curr_out_tidx[out_packed_dim] += texel_i;
+ VEC4_T test_texel = VEC4_T(-1.0);
- // Skip if we're out of bounds
- if (curr_out_tidx[out_packed_dim] >= out_sizes[out_packed_dim]) {
- continue;
- }
+ for (int comp = 0; comp < 4; ++comp) {
+ ivec4 out_tidx = out_read_start_tidx;
+ out_tidx[out_packed_dim] += comp;
- ivec4 in_tidx = curr_out_tidx;
- $for i in range(NUM_INPUTS):
- // For each input tensor, check if the tensor index is within bounds. If
- // so, read the corresponding texel element from the input tensor and
- // write it to the output texel.
- if (in_tidx[concat_dim] < in${i+1}_sizes[concat_dim]) {
- const ivec4 in_posi = tidx_to_posi(in_tidx, in${i+1}_sizes, in${i+1}_axis_map, in${i+1}_packed_dim);
- out_texel[texel_i] = load_texel(t_in${i+1}, in_posi.xyz)[in_posi.w];
- continue;
- }
- // Otherwise, adjust the index along the concat dimension and try the
- // next input tensor.
- else {
- in_tidx[concat_dim] -= in${i+1}_sizes[concat_dim];
- }
+
+ // It's possible that the current texel element has been written to as part
+ // of the previous input batch; if so, then don't overwrite this texel
+ // element
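+ // (For example, with concat_offset == 2 and concat_dim equal to the packed
+ // dim, components 0 and 1 of the first output texel were filled by a
+ // previous dispatch; they are preserved because out_texel was initialized
+ // with imageLoad above.)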
+ if (out_tidx[concat_dim] < concat_offset) {
+ test_texel[comp] = -5.0;
+ continue;
}
- write_texel_lpos(t_out, lpos, out_texel, out_axis_map);
+ // Calculate the tidx of the input volume that corresponds to this output
+ // element
+ ivec4 inp_volume_tidx = out_tidx;
+ inp_volume_tidx[concat_dim] -= concat_offset;
+
+ // go through the list of input tensors, and figure out which input this
+ // output element should be read from.
+ $for i in range(NUM_INPUTS):
+ if (inp_volume_tidx[concat_dim] < inp${i}_sizes[concat_dim]) {
+ // Special fast path case if, for the first output texel element, the
+ // corresponding input element is at the start of the texel it belongs
+ // to. In this case, the input texel can be written as-is to the output
+ // texel. Also require that the entire input texel is valid and does not
+ // contain any padding elements.
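+ // (For instance, with a packed-dim extent of 10, the input texel starting
+ // at element 8 only holds 2 valid elements, so the check below falls
+ // through to the per-element path.)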
+ if (comp == 0 &&
+ out_tidx[out_packed_dim] % 4 == 0 &&
+ inp_volume_tidx[inp${i}_packed_dim] % 4 == 0 &&
+ inp_volume_tidx[inp${i}_packed_dim] + 3 < inp${i}_sizes[inp${i}_packed_dim]) {
+ const ivec3 in_pos = tidx_to_pos(
+ inp_volume_tidx,
+ inp${i}_sizes,
+ inp${i}_axis_map,
+ inp${i}_packed_dim);
+
+ out_texel = texelFetch(t_inp${i}, in_pos, 0);
+ break;
+ }
+
+ // Otherwise, locate the specific input element required
+ const ivec4 in_posi = tidx_to_posi(
+ inp_volume_tidx,
+ inp${i}_sizes,
+ inp${i}_axis_map,
+ inp${i}_packed_dim);
+
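+ // in_posi.xyz is the texel position; in_posi.w selects the element within
+ // that texel.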
+ out_texel[comp] = texelFetch(t_inp${i}, in_posi.xyz, 0)[in_posi.w];
+ test_texel[comp] = out_texel[comp];
+ continue;
+ }
+ else {
+ inp_volume_tidx[concat_dim] -= inp${i}_sizes[concat_dim];
+ }
}
+
+ imageStore(t_out, out_pos, out_texel);
}