1414
1515#define TILE_SIZE ${TILE_SIZE}
1616
17+ #define BATCH_SIZE_X ${BATCH_SIZE_X}
18+
1719#define BATCH_SIZE_Y ${BATCH_SIZE_Y}
1820
1921#define op(X, A, B) ${OPERATOR}
2022
21- #include "indexing_utils_u16 .h"
23+ #include "indexing_utils .h"
2224
2325layout (std430) buffer ;
2426
@@ -41,70 +43,79 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
4143 * output at a single output location.
4244 */
4345void main() {
44- // y divided up by batch size is used to determine 3d position
46+ // x and y are divided by batch size to determine 3d position
4547 // since work size is calculated by x * ((y + B_Y - 1) / B_Y) * z
46- const int out_limits_y_scaled = (out_limits.y + BATCH_SIZE_Y - 1 ) / BATCH_SIZE_Y;
48+ const ivec2 out_limits_xy_scaled = (out_limits.xy + ivec2 (BATCH_SIZE_X, BATCH_SIZE_Y) - 1 ) / ivec2 (BATCH_SIZE_X, BATCH_SIZE_Y) ;
4749
48- u16vec3 pos = idx_to_u16pos_x_wise (gl_GlobalInvocationID.x, out_limits .x, out_limits_y_scaled );
50+ ivec3 pos = idx_to_ipos_x_wise (gl_GlobalInvocationID.x, out_limits_xy_scaled .x, out_limits_xy_scaled.y );
4951
50- // scale pos.y by batch size, because that's the top pixel to be processed
51- pos.y *= uint16_t(BATCH_SIZE_Y);
52+ // scale pos.xy by batch sizes, because that's the top pixel to be processed
53+ pos.x *= BATCH_SIZE_X;
54+ pos.y *= BATCH_SIZE_Y;
5255
5356 // do not process if top pixel does not fit within the output range
54- if (any (greaterThanEqual (u16vec3( pos.x, pos.y, pos.z) , out_limits))) {
57+ if (any (greaterThanEqual (pos, out_limits))) {
5558 return ;
5659 }
5760
5861 // Compute the index of the top-left element of the overlay region. Negative
5962 // indices indicate that the top-left element is in a region added by padding.
60- const u16vec2 ipos = pos.xy * u16vec2( stride) - u16vec2( padding) ;
63+ const ivec2 ipos = pos.xy * stride - padding;
6164
6265 // Compute the start and end of the input indices to load. Padding is assumed
6366 // to be constant 0 padding, so any reads from the padding region is skipped.
64- const u16vec2 start = ipos;
65- const u16vec2 end = ipos + u16vec2( overlay_region.xy) ;
67+ const ivec2 start = ipos;
68+ const ivec2 end = ipos + overlay_region.xy;
6669
6770 // sum outputs
68- VEC4_T sum[BATCH_SIZE_Y];
71+ VEC4_T sum[BATCH_SIZE_Y][BATCH_SIZE_X] ;
6972
70- sum[0 ] = texelFetch(t_bias, u16vec2(pos.z, 0 ), 0 );
71- for (int i = 1 ; i < BATCH_SIZE_Y; i++ ) {
72- sum[i] = sum[0 ];
73+ sum[0 ][0 ] = texelFetch(t_bias, ivec2 (pos.z, 0 ), 0 );
74+ for (int y = 0 ; y < BATCH_SIZE_Y; y++ ) {
75+ for (int x = 0 ; x < BATCH_SIZE_X; x++ ) {
76+ sum[y][x] = sum[0 ][0 ];
77+ }
7378 }
7479
7580 // array to store input texels
76- VEC4_T in_texels[TILE_SIZE];
81+ VEC4_T in_texels[TILE_SIZE + BATCH_SIZE_X - 1 ];
7782
7883 // array to store kernel data of previous y
7984 VEC4_T prev_kernel_line[TILE_SIZE];
8085
81- uint16_t kx = uint16_t( 0 ) ;
82- for (uint16_t y = start.y, i = uint16_t( 0 ) ; i < uint16_t( TILE_SIZE + BATCH_SIZE_Y - 1 ) ; y += uint16_t( dilation.y) , i++ ) {
83- for (uint16_t x = start.x, j = uint16_t( 0 ) ; j < uint16_t( TILE_SIZE) ; x += uint16_t( dilation.x) , j++ ) {
84- in_texels[int (j) ] = texelFetch(t_in, u16vec3 (x, y, pos.z), 0 );
86+ int kx = 0 ;
87+ for (int y = start.y, i = 0 ; i < TILE_SIZE + BATCH_SIZE_Y - 1 ; y += dilation.y, i++ ) {
88+ for (int x = start.x, j = 0 ; j < TILE_SIZE + BATCH_SIZE_X - 1 ; x += dilation.x, j++ ) {
89+ in_texels[j ] = texelFetch(t_in, ivec3 (x, y, pos.z), 0 );
8590 }
8691
8792 // from 2nd iteration onwards accumulate dot product in 2nd sum
8893 // based on kernel line data fetched in previous iteration and input texel from this iteration
89- if (i > uint16_t(0 )) {
90- for (uint16_t j = uint16_t(0 ); j < uint16_t(TILE_SIZE); j++ ) {
91- sum[1 ] = fma(in_texels[int (j)], prev_kernel_line[int (j)], sum[1 ]);
94+ if (i > 0 ) {
95+ for (int j = 0 ; j < TILE_SIZE; j++ ) {
96+ for (int s = 0 ; s < BATCH_SIZE_X; s++ ) {
97+ sum[1 ][s] = fma(in_texels[j + s], prev_kernel_line[j], sum[1 ][s]);
98+ }
9299 }
93100 }
94101
95102 // accumulate dot product in 1st sum only until tile size
96- if (i < uint16_t(TILE_SIZE)) {
97- for (uint16_t j = uint16_t(0 ); j < uint16_t(TILE_SIZE); j++ , kx++ ) {
98- prev_kernel_line[int (j)] = texelFetch(t_kernel, u16vec2(kx, pos.z), 0 );
99- sum[0 ] = fma(in_texels[int (j)], prev_kernel_line[int (j)], sum[0 ]);
103+ if (i < TILE_SIZE) {
104+ for (int j = 0 ; j < TILE_SIZE; j++ , kx++ ) {
105+ prev_kernel_line[j] = texelFetch(t_kernel, ivec2 (kx, pos.z), 0 );
106+ for (int s = 0 ; s < BATCH_SIZE_X; s++ ) {
107+ sum[0 ][s] = fma(in_texels[j + s], prev_kernel_line[j], sum[0 ][s]);
108+ }
100109 }
101110 }
102111 }
103112
104- for (int i = 0 ; i < BATCH_SIZE_Y; i++ ) {
105- if (any (greaterThanEqual (u16vec3(pos.x, pos.y + i, pos.z), out_limits))) {
106- continue ;
113+ for (int y = 0 ; y < BATCH_SIZE_Y; y++ ) {
114+ for (int x = 0 ; x < BATCH_SIZE_X; x++ ) {
115+ if (any (greaterThanEqual (ivec3 (pos.x + x, pos.y + y, pos.z), out_limits))) {
116+ continue ;
117+ }
118+ imageStore(t_out, ivec3 (pos.x + x, pos.y + y, pos.z), op(sum[y][x], out_min, out_max));
107119 }
108- imageStore(t_out, u16vec3(pos.x, pos.y + i, pos.z), op(sum[i], out_min, out_max));
109120 }
110121}
0 commit comments