@@ -56,75 +56,63 @@ const lowp ivec4 bias_axis_map = unhash_axis_map(bias_layout);
 // weight = (out_C, in_C / G, K),
 // bias = (out_C,).
 //
-// This implementation performs out_C shader invocations, where each invocation
-// calculates the rolling kernel of the length dimension for each batch, i.e.,
-// computes out_L * N results.
-//
-// Note that we can rewrite this implementation as out_L * out_C * ceil(N / 4)
-// shader invocations, where each invocation computes 1 result. But that
-// performs worse.
+// This implementation performs N x out_C x out_L shader invocations, where
+// each invocation computes the kernel overlay at a single (batch, channel,
+// length) output position, i.e., 1 result.
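+//
+// Correspondingly, out_limits is expected to be (out_L, out_C, N), matching
+// the (lpos.x, lpos.y, lpos.z) mapping used below.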
 void main() {
   const ivec3 lpos = ivec3(gl_GlobalInvocationID);

   if (any(greaterThanEqual(lpos, out_limits))) {
     return;
   }

-  int in_length = in_sizes.x;
-  int batch_size = in_sizes.z;
-
7669  //  "out_c" is the output's channel index where we write our result.
7770  //  Across shader invocations, this is the only value that varies.
78-   int  out_c =  lpos.y;
79-   VEC4_T bias =  load_texel_lpos(bias_in, ivec3 (out_c, 0 , 0 ), bias_axis_map);
71+   const  int  out_c =  lpos.y;

   // "in_c" tracks the input's channel start index.
   // We iterate over the input group that corresponds to the output group.
-  int c_start = (out_c / out_group_size) * in_group_size;
-  int c_end = c_start + in_group_size;
+  const int c_start = (out_c / out_group_size) * in_group_size;
+  const int c_end = c_start + in_group_size;
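+  // e.g., with in_group_size = 4 and out_group_size = 8, out_c = 10 reads
+  // input channels [4, 8).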
+
+  // "out_l" tracks the output's length index where we write our result.
+  const int out_l = lpos.x;
+
+  // "N" is the batch index.
+  const int N = lpos.z;

   // "in_l" tracks the input's length start index for our input-kernel overlay
   // region.
-  int l_start = -padding;
-  int l_end = in_length + padding - dilation * (kernel_size - 1);
-
-  // Since the input/output tensors are channel-packed, which is along the
-  // batch dimension, we can batch-read/write four elements at a time.
-  for (int n = 0; n < batch_size; n += 4) {
-    // "out_l" tracks the output's length index where we write our result.
-    int out_l = 0;
-
-    for (int in_l = l_start; in_l < l_end; in_l += stride, ++out_l) {
-      VEC4_T sum = VEC4_T(0);
-
-      for (int in_c = c_start; in_c < c_end; ++in_c) {
-        // "k" tracks the kernel's index for our input-kernel computation.
-        // It reads out-of-bound zeros, but trying to avoid them complicates
-        // for-loop conditions, which results in worse performance.
-
-        // The weight tensor is channel-packed. This may not be the obvious
-        // choice for performance, since it requires more data fetches. For
-        // some sequence models, we found that the weight tensor
-        // (out_channel, in_channel / group, kernel) often has
-        // out_channel >> kernel, leading to non-optimal use of memory as the
-        // weight tensor gets very deep. As a mitigation, we use
-        // channel-packing for the weight tensor, yielding a 75% reduction in
-        // weight-tensor memory.
-
-        // It is possible to further reduce the memory footprint by swapping
-        // the dimensions, using the x extent for out_channel and y for kernel.
-        for (int k = 0; k < kernel_size; k += 1) {
-          const ivec3 w_lpos = ivec3(k, in_c % in_group_size, out_c / 4);
-          const VEC4_T weight_texel = load_texel_lpos(kernel_in, w_lpos, kernel_axis_map);
-          VEC4_T weight = VEC4_T(weight_texel[out_c % 4]);
-
-          ivec3 in_pos = lpos_to_pos(ivec3(in_l + k * dilation, in_c, n / 4), in_axis_map);
-          sum = fma(weight, load_texel(t_in, in_pos), sum);
-        }
-      }
-
-      const ivec3 out_lpos = ivec3(out_l, out_c, n / 4);
-      write_texel_lpos(t_out, out_lpos, op(sum + bias.x, out_min, out_max), out_axis_map);
+  const int in_l = out_l * stride - padding;
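+  // e.g., with stride = 2 and padding = 1, out_l = 3 starts its kernel
+  // overlay at in_l = 5.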
+  VEC4_T sum = VEC4_T(0);
+
+  for (int in_c = c_start; in_c < c_end; ++in_c) {
+    // "k" tracks the kernel's index for our input-kernel computation.
+    // It reads out-of-bound zeros, but trying to avoid them complicates
+    // for-loop conditions, which results in worse performance.
+
+    // The weight tensor is channel-packed. This may not be the obvious
+    // choice for performance, since it requires more data fetches. For
+    // some sequence models, we found that the weight tensor
+    // (out_channel, in_channel / group, kernel) often has
+    // out_channel >> kernel, leading to non-optimal use of memory as the
+    // weight tensor gets very deep. As a mitigation, we use channel-packing
+    // for the weight tensor, yielding a 75% reduction in weight-tensor
+    // memory.
+
+    // It is possible to further reduce the memory footprint by swapping
+    // the dimensions, using the x extent for out_channel and y for kernel.
+    for (int k = 0; k < kernel_size; k++) {
+      const ivec3 w_lpos = ivec3(k, in_c % in_group_size, out_c / 4);
+      const VEC4_T weight_texel = load_texel_lpos(kernel_in, w_lpos, kernel_axis_map);
+      VEC4_T weight = VEC4_T(weight_texel[out_c % 4]);
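+      // e.g., out_c = 10 reads the weight texel at z = 2 and broadcasts its
+      // component at index 2.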
+
+      const ivec3 in_pos = lpos_to_pos(ivec3(in_l + k * dilation, in_c, N), in_axis_map);
+      sum = fma(weight, load_texel(t_in, in_pos), sum);
     }
   }
+
+  const VEC4_T bias = load_texel_lpos(bias_in, ivec3(out_c, 0, 0), bias_axis_map);
+  const ivec3 out_lpos = ivec3(out_l, out_c, N);
+  write_texel_lpos(t_out, out_lpos, op(sum + bias.x, out_min, out_max), out_axis_map);
 }
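
As a sanity check on the indexing above, the per-invocation arithmetic can be compared against a scalar host-side reference for grouped, dilated conv1d. The sketch below is not part of this change; the name conv1d_ref and its assumptions (NCL input layout, zero padding, bias added at the end, no output clamp) are hypothetical, chosen only to mirror the shader's in_l = out_l * stride - padding overlay and its c_start / c_end group mapping.

#include <cstddef>
#include <vector>

// Hypothetical scalar reference for grouped, dilated conv1d (NCL layout).
// One result per (n, out_c, out_l), mirroring the shader's invocation grid.
std::vector<float> conv1d_ref(
    const std::vector<float>& in,     // size N * in_C * in_L
    const std::vector<float>& weight, // size out_C * (in_C / G) * K
    const std::vector<float>& bias,   // size out_C
    int N, int in_C, int in_L, int out_C, int K,
    int stride, int padding, int dilation, int G) {
  const int in_group_size = in_C / G;
  const int out_group_size = out_C / G;
  const int out_L =
      (in_L + 2 * padding - dilation * (K - 1) - 1) / stride + 1;
  std::vector<float> out(static_cast<size_t>(N) * out_C * out_L, 0.0f);
  for (int n = 0; n < N; ++n) {
    for (int out_c = 0; out_c < out_C; ++out_c) {
      // Same group mapping as the shader's c_start / c_end.
      const int c_start = (out_c / out_group_size) * in_group_size;
      for (int out_l = 0; out_l < out_L; ++out_l) {
        const int in_l = out_l * stride - padding;
        float sum = 0.0f;
        for (int c = 0; c < in_group_size; ++c) {
          for (int k = 0; k < K; ++k) {
            const int l = in_l + k * dilation;
            if (l < 0 || l >= in_L) continue; // zero padding
            sum += weight[(out_c * in_group_size + c) * K + k] *
                   in[(static_cast<size_t>(n) * in_C + c_start + c) * in_L + l];
          }
        }
        out[(static_cast<size_t>(n) * out_C + out_c) * out_L + out_l] =
            sum + bias[out_c];
      }
    }
  }
  return out;
}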