1212
1313#define VEC4_T ${texel_type(DTYPE)}
1414
15- #define TILE_SIZE ${TILE_SIZE}
15+ #define TILE_SIZE_X ${TILE_SIZE_X}
16+ #define TILE_SIZE_Y ${TILE_SIZE_Y}
17+ #define LOCAL_WG_SIZE 64
1618
1719#define op(X, A, B) ${OPERATOR}
1820
@@ -24,27 +26,36 @@ ${layout_declare_tensor(0, "w", "t_out", DTYPE, "texture3d")}
2426${layout_declare_tensor(1 , "r", "t_in", DTYPE, "texture3d")}
2527${layout_declare_tensor(2 , "r", "t_kernel", DTYPE, "texture2d")}
2628${layout_declare_tensor(3 , "r", "t_bias", DTYPE, "texture2d")}
27- ${layout_declare_ubo(4 , "ivec3 ", "out_limits")}
28- ${layout_declare_ubo(5 , "ivec4 ", "in_sizes")}
29- ${layout_declare_ubo(6 , "ivec2 ", "kernel_size", "ivec2 ", "stride", "ivec2 ", "padding", "ivec2 ", "dilation")}
30- ${layout_declare_ubo(7 , "ivec2 ", "overlay_region", "int ", "in_group_size")}
31- ${layout_declare_ubo(8 , "float ", "out_min", "float ", "out_max")}
29+
30+ layout (push_constant) uniform restrict Block {
31+ ivec4 out_limits;
32+ ivec4 in_sizes;
33+ ivec2 kernel_size;
34+ ivec2 stride;
35+ ivec2 padding;
36+ ivec2 dilation;
37+ ivec2 overlay_region;
38+ int in_group_size;
39+ int dummy_padding;
40+ float out_min;
41+ float out_max;
42+ };
3243
3344layout (local_size_x_id = 0 , local_size_y_id = 1 , local_size_z_id = 2 ) in ;
3445
35- // shared memory to hold calculated positions, this would reduce register usage thus improving performance .
36- // 64 is the number of threads in the local wg
37- $num_shared = 64 * TILE_SIZE * TILE_SIZE
38- shared ivec2 pos_shared[${num_shared} ];
46+ // For performance improvement, reduce register usage by caching positions in shared memory .
47+ // Offset index by 1 every 16 points to avoid bank access conflict.
48+ #define offset_pos_index(index) (index + ((index) >> 4 ))
49+ shared ivec3 pos_shared[offset_pos_index(LOCAL_WG_SIZE * TILE_SIZE_X * TILE_SIZE_Y) ];
3950
4051/*
4152 * Computes a 2D pointwise convolution of an NxN output tile. Calculating an
4253 * output tile for pointwise convolution is more efficient because the kernel
4354 * size is only 1x1, making it easier to re-use loaded texels from t_kernel.
4455 */
4556void main() {
46- const ivec2 out_limits_scaled = (out_limits.xy + TILE_SIZE - 1 ) / TILE_SIZE ;
47- const uint shared_mem_stride = gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z ;
57+ const ivec2 out_limits_scaled = (out_limits.xy + ivec2 (TILE_SIZE_X - 1 , TILE_SIZE_Y - 1 )) / ivec2 (TILE_SIZE_X, TILE_SIZE_Y) ;
58+ const uint shared_mem_stride = LOCAL_WG_SIZE ;
4859
4960 const uint div_by_x = gl_GlobalInvocationID.x / out_limits_scaled.x;
5061 const ivec3 gpos = ivec3 (
@@ -58,33 +69,32 @@ void main() {
5869 // +--------+--------+
5970 // | pos[2] | pos[3] |
6071 // +--------+--------+
61- ivec2 pos[TILE_SIZE * TILE_SIZE];
62- for (int y = 0 , i = 0 ; y < TILE_SIZE; ++ y) {
63- for (int x = 0 ; x < TILE_SIZE; ++ x) {
64- pos[i] = ivec2 (
65- gpos.x * TILE_SIZE + x, gpos.y * TILE_SIZE + y);
66- pos_shared[(shared_mem_stride * i) + gl_LocalInvocationIndex] = pos[i];
72+ ivec2 pos[TILE_SIZE_X * TILE_SIZE_Y];
73+ for (int y = 0 , i = 0 ; y < TILE_SIZE_Y; ++ y) {
74+ for (int x = 0 ; x < TILE_SIZE_X; ++ x) {
75+ pos[i] = ivec2 (gpos.x * TILE_SIZE_X + x, gpos.y * TILE_SIZE_Y + y);
76+ pos_shared[offset_pos_index((shared_mem_stride * i) + gl_LocalInvocationIndex)] = ivec3 (pos[i], gpos.z);
6777 i++ ;
6878 }
6979 }
7080
7181 // If the top left position is out of bounds, then this invocation will have
7282 // no work to do.
73- if (any ( greaterThanEqual ( ivec3 (pos[ 0 ], gpos.z), out_limits)) ) {
83+ if (gpos.z >= out_limits.z ) {
7484 return ;
7585 }
7686
7787 // Compute the index of the input texture that needs to be loaded for each
7888 // output position. Note that negative indices can be produced indicating that
7989 // the top-left element is in a region added by padding.
80- ivec2 ipos[TILE_SIZE * TILE_SIZE ];
81- for (int i = 0 ; i < TILE_SIZE * TILE_SIZE ; ++ i) {
90+ ivec2 ipos[TILE_SIZE_X * TILE_SIZE_Y ];
91+ for (int i = 0 ; i < TILE_SIZE_X * TILE_SIZE_Y ; ++ i) {
8292 ipos[i] = pos[i] * stride - padding;
8393 }
8494
85- vec4 sum[TILE_SIZE * TILE_SIZE ];
95+ vec4 sum[TILE_SIZE_X * TILE_SIZE_Y ];
8696 sum[0 ] = texelFetch(t_bias, ivec2 (gpos.z, 0 ), 0 );
87- for (int i = 1 ; i < TILE_SIZE * TILE_SIZE ; ++ i) {
97+ for (int i = 1 ; i < TILE_SIZE_X * TILE_SIZE_Y ; ++ i) {
8898 sum[i] = sum[0 ];
8999 }
90100
@@ -100,7 +110,7 @@ void main() {
100110 const vec4 ktex_3 = texelFetchOffset(t_kernel, ivec2 (z, gpos.z), 0 , ivec2 (3 , 0 ));
101111
102112#pragma unroll
103- for (int i = 0 ; i < TILE_SIZE * TILE_SIZE ; ++ i) {
113+ for (int i = 0 ; i < TILE_SIZE_X * TILE_SIZE_Y ; ++ i) {
104114 const vec4 in_tex = texelFetch(t_in, ivec3 (ipos[i], z4), 0 );
105115 // For 2x2 tile size algorithm works as follows.
106116 // To explain the calculations below, the contents of one in_tex and the
@@ -142,10 +152,11 @@ void main() {
142152 }
143153 }
144154
145- for (int i = 0 ; i < TILE_SIZE * TILE_SIZE; ++ i) {
146- const ivec2 pos = pos_shared[(shared_mem_stride * i) + gl_LocalInvocationIndex];
147- if (all (lessThan (ivec3 (pos, gpos.z), out_limits))) {
148- imageStore(t_out, ivec3 (pos, gpos.z), op(sum[i], out_min, out_max));
155+ for (int i = 0 ; i < TILE_SIZE_X * TILE_SIZE_Y; ++ i) {
156+ const uint index = (shared_mem_stride * i) + gl_LocalInvocationIndex;
157+ const ivec3 pos = pos_shared[offset_pos_index(index)];
158+ if (all (lessThan (pos, out_limits.xyz))) {
159+ imageStore(t_out, pos, op(sum[i], out_min, out_max));
149160 }
150161 }
151162}
0 commit comments