@@ -32,35 +32,37 @@ ${layout_declare_ubo(8, "float", "out_min", "float", "out_max")}
3232
3333layout (local_size_x_id =  0 , local_size_y_id =  1 , local_size_z_id =  2 ) in ;
3434
35+ #extension  GL_EXT_shader_explicit_arithmetic_types_int16 :  require
36+ 
3537/* 
3638 * Computes a depthwise convolution. Each shader invocation calculates the 
3739 * output at a single output location. 
3840 */  
3941void  main() {
40-   const  ivec3  pos =  ivec3 (gl_GlobalInvocationID);
42+   const  u16vec3  pos =  u16vec3 (gl_GlobalInvocationID);
4143
4244  if  (any (greaterThanEqual (pos, out_limits))) {
4345    return ;
4446  }
4547
4648  //  Compute the index of the top-left element of the overlay region. Negative
4749  //  indices indicate that the top-left element is in a region added by padding.
48-   const  ivec2  ipos =  pos.xy *  stride -  padding;
50+   const  u16vec2  ipos =  pos.xy *  u16vec2( stride)  -  u16vec2( padding) ;
4951
5052  //  Compute the start and end of the input indices to load. Padding is assumed
5153  //  to be constant 0 padding, so any reads from the padding region is skipped.
52-   const  ivec2  start =  ipos;
53-   const  ivec2  end =  ipos +  overlay_region.xy;
54+   const  u16vec2  start =  ipos;
55+   const  u16vec2  end =  ipos +  u16vec2( overlay_region.xy) ;
5456
55-   VEC4_T sum =  texelFetch(t_bias, ivec2 (pos.z, 0 ), 0 );
56-   int  kx =  0 ;
57-   for  (int  y =  start.y, i =  0 ; i <  TILE_SIZE; y +=  dilation.y, i++ ) {
58-     for  (int  x =  start.x, j =  0 ; j <  TILE_SIZE; x +=  dilation.x, j++ ) {
57+   VEC4_T sum =  texelFetch(t_bias, u16vec2 (pos.z, 0 ), 0 );
58+   uint16_t  kx =  uint16_t( 0 ) ;
59+   for  (uint16_t  y =  start.y, i =  uint16_t( 0 ) ; i <  uint16_t( TILE_SIZE) ; y +=  uint16_t( dilation.y) , i++ ) {
60+     for  (uint16_t  x =  start.x, j =  uint16_t( 0 ) ; j <  uint16_t( TILE_SIZE) ; x +=  uint16_t( dilation.x) , j++ ) {
5961      //  The weight kernel was rearranged such that every NxN filter is
6062      //  flattened to fit in one row. Each filter was then stacked on top of
6163      //  each other vertically.
62-       const  vec4  in_texel =  texelFetch(t_in, ivec3 (x, y, pos.z), 0 );
63-       sum =  fma(in_texel, texelFetch(t_kernel, ivec2 (kx, pos.z), 0 ), sum);
64+       const  vec4  in_texel =  texelFetch(t_in, u16vec3 (x, y, pos.z), 0 );
65+       sum =  fma(in_texel, texelFetch(t_kernel, u16vec2 (kx, pos.z), 0 ), sum);
6466      kx++ ;
6567    }
6668  }
0 commit comments