@@ -30,6 +30,8 @@ ${layout_declare_ubo(8, "float", "out_min", "float", "out_max")}
30
30
31
31
// Workgroup (local) size is not hard-coded: each dimension is taken from a
// specialization constant (constant_id 0 for x, 1 for y, 2 for z).
layout (local_size_x_id = 0 , local_size_y_id = 1 , local_size_z_id = 2 ) in ;
32
32
33
// NOTE(review): presumably expands to an int specialization constant named
// `ngroups` with a default of 1 -- confirm against the definition of the
// layout_declare_spec_const codegen macro.
// main() tests `ngroups > 1` to decide whether the input-channel loop range is
// offset by the group index derived from pos.z.
+ ${layout_declare_spec_const(C, "int ", "ngroups", "1 ")}
34
+
33
35
/*
34
36
* Computes a 2D convolution. Each shader invocation calculates the output at
35
37
* a single output location.
@@ -74,7 +76,18 @@ void main() {
74
76
// Perform the convolution by iterating over the overlay region.
75
77
VEC4_T sum = texelFetch(t_bias, ivec2 (pos.z, 0 ), 0 );
76
78
const int ic4 = in_group_size / 4 ;
77
- for (int z4 = 0 ; z4 < ic4; ++ z4, kstart.x += kernel_size.x * 4 ) {
79
+
80
+ int z_start = 0 ;
81
+ int z_end = ic4;
82
+ if (ngroups > 1 ) {
83
+ const int group_size = (out_limits.z) / ngroups;
84
+ const int group_idx = pos.z / group_size;
85
+
86
+ z_start = ic4 * group_idx;
87
+ z_end = z_start + ic4;
88
+ }
89
+
90
+ for (int z4 = z_start; z4 < z_end; ++ z4, kstart.x += kernel_size.x * 4 ) {
78
91
for (int y = start.y, ky = kstart.y; y < end.y; y += dilation.y, ++ ky) {
79
92
for (int x = start.x, kx = kstart.x; x < end.x; x += dilation.x, kx += 4 ) {
80
93
const VEC4_T in_texel = texelFetch(t_in, ivec3 (x, y, z4), 0 );
0 commit comments