 
 #define TILE_SIZE ${TILE_SIZE}
 
+#define BATCH_SIZE_Y ${BATCH_SIZE_Y}
+
 #define op(X, A, B) ${OPERATOR}
 
 #include "indexing_utils.h"
@@ -39,12 +41,20 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
  * output at a single output location.
  */
 void main() {
-  const u16vec3 pos = u16vec3(
+  // y divided up by batch size is used to determine 3d position
+  // since work size is calculated by x * ((y + B_Y - 1) / B_Y) * z
+  const uint out_limits_y_scaled = (out_limits.y + BATCH_SIZE_Y - 1) / BATCH_SIZE_Y;
+
+  u16vec3 pos = u16vec3(
     gl_GlobalInvocationID.x % out_limits.x,
-    (gl_GlobalInvocationID.x / out_limits.x) % out_limits.y,
-    gl_GlobalInvocationID.x / (out_limits.x * out_limits.y));
+    ((gl_GlobalInvocationID.x / out_limits.x) % out_limits_y_scaled),
+    gl_GlobalInvocationID.x / (out_limits.x * out_limits_y_scaled));
 
-  if (any(greaterThanEqual(pos, out_limits))) {
+  // scale pos.y by batch size, because that's the top pixel to be processed
+  pos.y *= uint16_t(BATCH_SIZE_Y);
+
+  // do not process if top pixel does not fit within the output range
+  if (any(greaterThanEqual(u16vec3(pos.x, pos.y, pos.z), out_limits))) {
     return;
   }
 
@@ -57,18 +67,47 @@ void main() {
   const u16vec2 start = ipos;
   const u16vec2 end = ipos + u16vec2(overlay_region.xy);
 
-  VEC4_T sum = texelFetch(t_bias, u16vec2(pos.z, 0), 0);
+  // sum outputs
+  VEC4_T sum[BATCH_SIZE_Y];
+
+  sum[0] = texelFetch(t_bias, u16vec2(pos.z, 0), 0);
+  for (int i = 1; i < BATCH_SIZE_Y; i++) {
+    sum[i] = sum[0];
+  }
+
+  // array to store input texels
+  VEC4_T in_texels[TILE_SIZE];
+
+  // array to store kernel data of previous y
+  VEC4_T prev_kernel_line[TILE_SIZE];
+
   uint16_t kx = uint16_t(0);
-  for (uint16_t y = start.y, i = uint16_t(0); i < uint16_t(TILE_SIZE); y += uint16_t(dilation.y), i++) {
+  for (uint16_t y = start.y, i = uint16_t(0); i < uint16_t(TILE_SIZE + BATCH_SIZE_Y - 1); y += uint16_t(dilation.y), i++) {
     for (uint16_t x = start.x, j = uint16_t(0); j < uint16_t(TILE_SIZE); x += uint16_t(dilation.x), j++) {
-      // The weight kernel was rearranged such that every NxN filter is
-      // flattened to fit in one row. Each filter was then stacked on top of
-      // each other vertically.
-      const vec4 in_texel = texelFetch(t_in, u16vec3(x, y, pos.z), 0);
-      sum = fma(in_texel, texelFetch(t_kernel, u16vec2(kx, pos.z), 0), sum);
-      kx++;
+      in_texels[int(j)] = texelFetch(t_in, u16vec3(x, y, pos.z), 0);
+    }
+
+    // from 2nd iteration onwards accumulate dot product in 2nd sum
+    // based on kernel line data fetched in previous iteration and input texel from this iteration
+    if (i > uint16_t(0)) {
+      for (uint16_t j = uint16_t(0); j < uint16_t(TILE_SIZE); j++) {
+        sum[1] = fma(in_texels[int(j)], prev_kernel_line[int(j)], sum[1]);
+      }
+    }
+
+    // accumulate dot product in 1st sum only until tile size
+    if (i < uint16_t(TILE_SIZE)) {
+      for (uint16_t j = uint16_t(0); j < uint16_t(TILE_SIZE); j++, kx++) {
+        prev_kernel_line[int(j)] = texelFetch(t_kernel, u16vec2(kx, pos.z), 0);
+        sum[0] = fma(in_texels[int(j)], prev_kernel_line[int(j)], sum[0]);
+      }
     }
   }
 
-  imageStore(t_out, pos, op(sum, out_min, out_max));
+  for (int i = 0; i < BATCH_SIZE_Y; i++) {
+    if (any(greaterThanEqual(u16vec3(pos.x, pos.y + i, pos.z), out_limits))) {
+      continue;
+    }
+    imageStore(t_out, u16vec3(pos.x, pos.y + i, pos.z), op(sum[i], out_min, out_max));
+  }
 }
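
Note on the dispatch math implied by the new indexing: the shader comment states that the global work size is now x * ((y + B_Y - 1) / B_Y) * z, i.e. only every BATCH_SIZE_Y-th output row gets its own invocation and each invocation writes BATCH_SIZE_Y rows. A minimal C++ sketch of that ceil-division follows; the host-side change is not part of this diff, so the function and constant names below are hypothetical, not taken from the ExecuTorch sources.

#include <cstdint>

// Hypothetical host-side sketch: flattened work size with the y extent
// batched by BATCH_SIZE_Y, matching x * ((y + B_Y - 1) / B_Y) * z.
constexpr uint32_t kBatchSizeY = 2;

uint32_t flattened_global_work_size(uint32_t out_x, uint32_t out_y, uint32_t out_z) {
  const uint32_t y_batched = (out_y + kBatchSizeY - 1) / kBatchSizeY;  // ceil(out_y / B_Y)
  return out_x * y_batched * out_z;
}

For example, out_limits = (8, 5, 4) with kBatchSizeY = 2 gives 8 * 3 * 4 = 96 invocations instead of 160.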
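
For reference, the accumulation in the main loop reads TILE_SIZE + BATCH_SIZE_Y - 1 input rows but fetches each kernel row only once: sum[0] consumes kernel row i in the iteration it is fetched, while sum[1] reuses it (as prev_kernel_line) against the next input row. The scalar mock below reproduces that schedule for BATCH_SIZE_Y = 2, with plain floats standing in for the VEC4_T texels; the data and names are illustrative only, and the reuse implicitly assumes the vertical step between the two output rows equals the loop's y increment.

#include <array>
#include <cstdio>

constexpr int TILE_SIZE = 3;     // kernel height/width, as in the shader
constexpr int BATCH_SIZE_Y = 2;  // two output rows per invocation

int main() {
  // Toy data: TILE_SIZE + 1 input rows cover both output rows' windows.
  std::array<std::array<float, TILE_SIZE>, TILE_SIZE + BATCH_SIZE_Y - 1> input{};
  std::array<std::array<float, TILE_SIZE>, TILE_SIZE> kernel{};
  for (int y = 0; y < TILE_SIZE + BATCH_SIZE_Y - 1; y++)
    for (int x = 0; x < TILE_SIZE; x++) input[y][x] = float(y + 1);
  for (int y = 0; y < TILE_SIZE; y++)
    for (int x = 0; x < TILE_SIZE; x++) kernel[y][x] = 0.5f;

  float sum[BATCH_SIZE_Y] = {0.0f, 0.0f};  // bias would seed these in the shader
  std::array<float, TILE_SIZE> in_row{};
  std::array<float, TILE_SIZE> prev_kernel_line{};

  for (int i = 0; i < TILE_SIZE + BATCH_SIZE_Y - 1; i++) {
    in_row = input[i];                  // one input row fetched per iteration
    if (i > 0) {                        // 2nd output row reuses kernel row i-1
      for (int j = 0; j < TILE_SIZE; j++)
        sum[1] += in_row[j] * prev_kernel_line[j];
    }
    if (i < TILE_SIZE) {                // 1st output row uses kernel row i
      for (int j = 0; j < TILE_SIZE; j++) {
        prev_kernel_line[j] = kernel[i][j];
        sum[0] += in_row[j] * prev_kernel_line[j];
      }
    }
  }
  std::printf("row 0: %.1f, row 1: %.1f\n", sum[0], sum[1]);  // 9.0 and 13.5
  return 0;
}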