@@ -32,35 +32,37 @@ ${layout_declare_ubo(8, "float", "out_min", "float", "out_max")}
3232
3333layout (local_size_x_id = 0 , local_size_y_id = 1 , local_size_z_id = 2 ) in ;
3434
35+ #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
36+
3537/*
3638 * Computes a depthwise convolution. Each shader invocation calculates the
3739 * output at a single output location.
3840 */
3941void main() {
40- const ivec3 pos = ivec3 (gl_GlobalInvocationID);
42+ const u16vec3 pos = u16vec3 (gl_GlobalInvocationID);
4143
4244 if (any (greaterThanEqual (pos, out_limits))) {
4345 return ;
4446 }
4547
4648 // Compute the index of the top-left element of the overlay region. Negative
4749 // indices indicate that the top-left element is in a region added by padding.
48- const ivec2 ipos = pos.xy * stride - padding;
50+ const u16vec2 ipos = pos.xy * u16vec2( stride) - u16vec2( padding) ;
4951
5052 // Compute the start and end of the input indices to load. Padding is assumed
5153 // to be constant 0 padding, so any reads from the padding region is skipped.
52- const ivec2 start = ipos;
53- const ivec2 end = ipos + overlay_region.xy;
54+ const u16vec2 start = ipos;
55+ const u16vec2 end = ipos + u16vec2( overlay_region.xy) ;
5456
55- VEC4_T sum = texelFetch(t_bias, ivec2 (pos.z, 0 ), 0 );
56- int kx = 0 ;
57- for (int y = start.y, i = 0 ; i < TILE_SIZE; y += dilation.y, i++ ) {
58- for (int x = start.x, j = 0 ; j < TILE_SIZE; x += dilation.x, j++ ) {
57+ VEC4_T sum = texelFetch(t_bias, u16vec2 (pos.z, 0 ), 0 );
58+ uint16_t kx = uint16_t( 0 ) ;
59+ for (uint16_t y = start.y, i = uint16_t( 0 ) ; i < uint16_t( TILE_SIZE) ; y += uint16_t( dilation.y) , i++ ) {
60+ for (uint16_t x = start.x, j = uint16_t( 0 ) ; j < uint16_t( TILE_SIZE) ; x += uint16_t( dilation.x) , j++ ) {
5961 // The weight kernel was rearranged such that every NxN filter is
6062 // flattened to fit in one row. Each filter was then stacked on top of
6163 // each other vertically.
62- const vec4 in_texel = texelFetch(t_in, ivec3 (x, y, pos.z), 0 );
63- sum = fma(in_texel, texelFetch(t_kernel, ivec2 (kx, pos.z), 0 ), sum);
64+ const vec4 in_texel = texelFetch(t_in, u16vec3 (x, y, pos.z), 0 );
65+ sum = fma(in_texel, texelFetch(t_kernel, u16vec2 (kx, pos.z), 0 ), sum);
6466 kx++ ;
6567 }
6668 }
0 commit comments