@@ -30,6 +30,8 @@ ${layout_declare_ubo(8, "float", "out_min", "float", "out_max")}
3030
3131layout (local_size_x_id = 0 , local_size_y_id = 1 , local_size_z_id = 2 ) in ;
3232
33+ ${layout_declare_spec_const(C, "int ", "ngroups", "1 ")}
34+
3335/*
3436 * Computes a 2D convolution. Each shader invocation calculates the output at
3537 * a single output location.
@@ -74,7 +76,18 @@ void main() {
7476 // Perform the convolution by iterating over the overlay region.
7577 VEC4_T sum = texelFetch(t_bias, ivec2 (pos.z, 0 ), 0 );
7678 const int ic4 = in_group_size / 4 ;
77- for (int z4 = 0 ; z4 < ic4; ++ z4, kstart.x += kernel_size.x * 4 ) {
79+
80+ int z_start = 0 ;
81+ int z_end = ic4;
82+ if (ngroups > 1 ) {
83+ const int group_size = (out_limits.z) / ngroups;
84+ const int group_idx = pos.z / group_size;
85+
86+ z_start = ic4 * group_idx;
87+ z_end = z_start + ic4;
88+ }
89+
90+ for (int z4 = z_start; z4 < z_end; ++ z4, kstart.x += kernel_size.x * 4 ) {
7891 for (int y = start.y, ky = kstart.y; y < end.y; y += dilation.y, ++ ky) {
7992 for (int x = start.x, kx = kstart.x; x < end.x; x += dilation.x, kx += 4 ) {
8093 const VEC4_T in_texel = texelFetch(t_in, ivec3 (x, y, z4), 0 );
0 commit comments