@@ -32,24 +32,26 @@ ${layout_declare_ubo(8, "float", "out_min", "float", "out_max")}
3232
3333layout (local_size_x_id = 0 , local_size_y_id = 1 , local_size_z_id = 2 ) in ;
3434
35+ #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
36+
3537/*
3638 * Computes a 2D pointwise convolution of an NxN output tile. Calculating an
3739 * output tile for pointwise convolution is more efficient because the kernel
3840 * size is only 1x1, making it easier to re-use loaded texels from t_kernel.
3941 */
4042void main() {
41- const ivec3 gpos = ivec3 (gl_GlobalInvocationID);
43+ const u16vec3 gpos = u16vec3 (gl_GlobalInvocationID);
4244
4345 // Output position for TILE_SIZE = 2
4446 // +--------+--------+
4547 // | pos[0] | pos[1] |
4648 // +--------+--------+
4749 // | pos[2] | pos[3] |
4850 // +--------+--------+
49- ivec3 pos[TILE_SIZE * TILE_SIZE];
51+ u16vec3 pos[TILE_SIZE * TILE_SIZE];
5052 for (int y = 0 , i = 0 ; y < TILE_SIZE; ++ y) {
5153 for (int x = 0 ; x < TILE_SIZE; ++ x) {
52- pos[i] = ivec3 (
54+ pos[i] = u16vec3 (
5355 gpos.x * TILE_SIZE + x, gpos.y * TILE_SIZE + y, gpos.z);
5456 i++ ;
5557 }
@@ -64,13 +66,13 @@ void main() {
6466 // Compute the index of the input texture that needs to be loaded for each
6567 // output position. Note that negative indices can be produced indicating that
6668 // the top-left element is in a region added by padding.
67- ivec2 ipos[TILE_SIZE * TILE_SIZE];
69+ u16vec2 ipos[TILE_SIZE * TILE_SIZE];
6870 for (int i = 0 ; i < TILE_SIZE * TILE_SIZE; ++ i) {
69- ipos[i] = pos[i].xy * stride - padding;
71+ ipos[i] = pos[i].xy * u16vec2( stride) - u16vec2( padding) ;
7072 }
7173
7274 vec4 sum[TILE_SIZE * TILE_SIZE];
73- sum[0 ] = texelFetch(t_bias, ivec2 (gpos.z, 0 ), 0 );
75+ sum[0 ] = texelFetch(t_bias, u16vec2 (gpos.z, 0 ), 0 );
7476 for (int i = 1 ; i < TILE_SIZE * TILE_SIZE; ++ i) {
7577 sum[i] = sum[0 ];
7678 }
@@ -81,13 +83,13 @@ void main() {
8183 // channel (IC) dim is along the x-axis, and the batch (OC) dim is along
8284 // the z-axis.
8385 vec4 in_tex[TILE_SIZE * TILE_SIZE];
84- const vec4 ktex_0 = texelFetch(t_kernel, ivec2 (z + 0 , gpos.z), 0 );
85- const vec4 ktex_1 = texelFetch(t_kernel, ivec2 (z + 1 , gpos.z), 0 );
86- const vec4 ktex_2 = texelFetch(t_kernel, ivec2 (z + 2 , gpos.z), 0 );
87- const vec4 ktex_3 = texelFetch(t_kernel, ivec2 (z + 3 , gpos.z), 0 );
86+ const vec4 ktex_0 = texelFetch(t_kernel, u16vec2 (z + 0 , gpos.z), 0 );
87+ const vec4 ktex_1 = texelFetch(t_kernel, u16vec2 (z + 1 , gpos.z), 0 );
88+ const vec4 ktex_2 = texelFetch(t_kernel, u16vec2 (z + 2 , gpos.z), 0 );
89+ const vec4 ktex_3 = texelFetch(t_kernel, u16vec2 (z + 3 , gpos.z), 0 );
8890
8991 for (int i = 0 ; i < TILE_SIZE * TILE_SIZE; ++ i) {
90- in_tex[i] = texelFetch(t_in, ivec3 (ipos[i], z4), 0 );
92+ in_tex[i] = texelFetch(t_in, u16vec3 (ipos[i], z4), 0 );
9193 }
9294
9395 for (int i = 0 ; i < TILE_SIZE * TILE_SIZE; ++ i) {
0 commit comments