1414
1515#define TILE_SIZE ${TILE_SIZE}
1616
17+ #define BATCH_SIZE_Y ${BATCH_SIZE_Y}
18+
1719#define op(X, A, B) ${OPERATOR}
1820
1921#include "indexing_utils.h"
@@ -39,12 +41,20 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
3941 * output at a single output location.
4042 */
4143void main() {
42- const u16vec3 pos = u16vec3(
44+ // y divided up by batch size is used to determine 3d position
45+ // since work size is calculated by x * ((y + B_Y - 1) / B_Y) * z
46+ const uint out_limits_y_scaled = (out_limits.y + BATCH_SIZE_Y - 1 ) / BATCH_SIZE_Y;
47+
48+ u16vec3 pos = u16vec3(
4349 gl_GlobalInvocationID.x % out_limits.x,
44- (gl_GlobalInvocationID.x / out_limits.x) % out_limits.y ,
45- gl_GlobalInvocationID.x / (out_limits.x * out_limits.y ));
50+ (( gl_GlobalInvocationID.x / out_limits.x) % out_limits_y_scaled) ,
51+ gl_GlobalInvocationID.x / (out_limits.x * out_limits_y_scaled ));
4652
47- if (any (greaterThanEqual (pos, out_limits))) {
53+ // scale pos.y by batch size, because that's the top pixel to be processed
54+ pos.y *= uint16_t(BATCH_SIZE_Y);
55+
56+ // do not process if top pixel does not fit within the output range
57+ if (any (greaterThanEqual (u16vec3(pos.x, pos.y, pos.z), out_limits))) {
4858 return ;
4959 }
5060
@@ -57,18 +67,47 @@ void main() {
5767 const u16vec2 start = ipos;
5868 const u16vec2 end = ipos + u16vec2(overlay_region.xy);
5969
60- VEC4_T sum = texelFetch(t_bias, u16vec2(pos.z, 0 ), 0 );
70+ // sum outputs
71+ VEC4_T sum[BATCH_SIZE_Y];
72+
73+ sum[0 ] = texelFetch(t_bias, u16vec2(pos.z, 0 ), 0 );
74+ for (int i = 1 ; i < BATCH_SIZE_Y; i++ ) {
75+ sum[i] = sum[0 ];
76+ }
77+
78+ // array to store input texels
79+ VEC4_T in_texels[TILE_SIZE];
80+
81+ // array to store kernel data of previous y
82+ VEC4_T prev_kernel_line[TILE_SIZE];
83+
6184 uint16_t kx = uint16_t(0 );
62- for (uint16_t y = start.y, i = uint16_t(0 ); i < uint16_t(TILE_SIZE); y += uint16_t(dilation.y), i++ ) {
85+ for (uint16_t y = start.y, i = uint16_t(0 ); i < uint16_t(TILE_SIZE + BATCH_SIZE_Y - 1 ); y += uint16_t(dilation.y), i++ ) {
6386 for (uint16_t x = start.x, j = uint16_t(0 ); j < uint16_t(TILE_SIZE); x += uint16_t(dilation.x), j++ ) {
64- // The weight kernel was rearranged such that every NxN filter is
65- // flattened to fit in one row. Each filter was then stacked on top of
66- // each other vertically.
67- const vec4 in_texel = texelFetch(t_in, u16vec3(x, y, pos.z), 0 );
68- sum = fma(in_texel, texelFetch(t_kernel, u16vec2(kx, pos.z), 0 ), sum);
69- kx++ ;
87+ in_texels[int (j)] = texelFetch(t_in, u16vec3(x, y, pos.z), 0 );
88+ }
89+
90+ // from 2nd iteration onwards accumulate dot product in 2nd sum
91+ // based on kernel line data fetched in previous iteration and input texel from this iteration
92+ if (i > uint16_t(0 )) {
93+ for (uint16_t j = uint16_t(0 ); j < uint16_t(TILE_SIZE); j++ ) {
94+ sum[1 ] = fma(in_texels[int (j)], prev_kernel_line[int (j)], sum[1 ]);
95+ }
96+ }
97+
98+ // accumulate dot product in 1st sum only until tile size
99+ if (i < uint16_t(TILE_SIZE)) {
100+ for (uint16_t j = uint16_t(0 ); j < uint16_t(TILE_SIZE); j++ , kx++ ) {
101+ prev_kernel_line[int (j)] = texelFetch(t_kernel, u16vec2(kx, pos.z), 0 );
102+ sum[0 ] = fma(in_texels[int (j)], prev_kernel_line[int (j)], sum[0 ]);
103+ }
70104 }
71105 }
72106
73- imageStore(t_out, pos, op(sum, out_min, out_max));
107+ for (int i = 0 ; i < BATCH_SIZE_Y; i++ ) {
108+ if (any (greaterThanEqual (u16vec3(pos.x, pos.y + i, pos.z), out_limits))) {
109+ continue ;
110+ }
111+ imageStore(t_out, u16vec3(pos.x, pos.y + i, pos.z), op(sum[i], out_min, out_max));
112+ }
74113}
0 commit comments