|
12 | 12 |
|
13 | 13 | #define VEC4_T ${texel_type(DTYPE)} |
14 | 14 |
|
| 15 | +#define T ${texel_component_type(DTYPE)} |
| 16 | + |
15 | 17 | #define op(X, A, B) ${OPERATOR} |
16 | 18 |
|
17 | 19 | #include "indexing_utils.h" |
@@ -72,13 +74,24 @@ void main() { |
72 | 74 | kstart.y += pos.z * kernel_size.y; |
73 | 75 |
|
74 | 76 | // Perform the convolution by iterating over the overlay region. |
75 | | - VEC4_T sum = texelFetch(t_bias, ivec2(pos.z, 0), 0); |
| 77 | + T sum[4]; |
| 78 | + sum[0] = T(0); |
| 79 | + sum[1] = T(0); |
| 80 | + sum[2] = T(0); |
| 81 | + sum[3] = T(0); |
| 82 | + |
76 | 83 | const int ic4 = in_group_size / 4; |
77 | 84 | for (int z4 = 0; z4 < ic4; ++z4, kstart.x += kernel_size.x * 4) { |
78 | 85 | for (int y = start.y, ky = kstart.y; y < end.y; y += dilation.y, ++ky) { |
79 | 86 | for (int x = start.x, kx = kstart.x; x < end.x; x += dilation.x, kx += 4) { |
80 | | - const VEC4_T in_texel = texelFetch(t_in, ivec3(x, y, z4), 0); |
81 | | - const ivec4 kxs = kx + ivec4(0, 1, 2, 3); |
| 87 | + const VEC4_T in_texel_v = texelFetch(t_in, ivec3(x, y, z4), 0); |
| 88 | + |
| 89 | + T in_texel[4]; |
| 90 | + in_texel[0] = in_texel_v.x; |
| 91 | + in_texel[1] = in_texel_v.y; |
| 92 | + in_texel[2] = in_texel_v.z; |
| 93 | + in_texel[3] = in_texel_v.w; |
| 94 | + |
82 | 95 |
|
83 | 96 | // To explain the calculation below, the contents of in_texel and the |
84 | 97 | // group of 4 texels loaded from t_kernel are shown: |
@@ -112,13 +125,27 @@ void main() { |
112 | 125 | // |
113 | 126 | // which is expressed in the following statements. |
114 | 127 |
|
115 | | - sum = fma(in_texel.xxxx, texelFetch(t_kernel, ivec2(kxs.x, ky), 0), sum); |
116 | | - sum = fma(in_texel.yyyy, texelFetch(t_kernel, ivec2(kxs.y, ky), 0), sum); |
117 | | - sum = fma(in_texel.zzzz, texelFetch(t_kernel, ivec2(kxs.z, ky), 0), sum); |
118 | | - sum = fma(in_texel.wwww, texelFetch(t_kernel, ivec2(kxs.w, ky), 0), sum); |
| 128 | + T k_tex_arr[16]; |
| 129 | + for (int kc = 0; kc < 4; kc++) { |
| 130 | + const VEC4_T k_tex = texelFetch(t_kernel, ivec2(kx + kc, ky), 0); |
| 131 | + k_tex_arr[kc * 4 + 0] = k_tex.x; |
| 132 | + k_tex_arr[kc * 4 + 1] = k_tex.y; |
| 133 | + k_tex_arr[kc * 4 + 2] = k_tex.z; |
| 134 | + k_tex_arr[kc * 4 + 3] = k_tex.w; |
| 135 | + } |
| 136 | + |
| 137 | + for (int sc = 0; sc < 4; sc++) { |
| 138 | + sum[0] = fma(in_texel[sc], k_tex_arr[sc * 4 + 0], sum[0]); |
| 139 | + sum[1] = fma(in_texel[sc], k_tex_arr[sc * 4 + 1], sum[1]); |
| 140 | + sum[2] = fma(in_texel[sc], k_tex_arr[sc * 4 + 2], sum[2]); |
| 141 | + sum[3] = fma(in_texel[sc], k_tex_arr[sc * 4 + 3], sum[3]); |
| 142 | + } |
119 | 143 | } |
120 | 144 | } |
121 | 145 | } |
122 | 146 |
|
123 | | - imageStore(t_out, pos, op(sum, out_min, out_max)); |
| 147 | + const VEC4_T bias = texelFetch(t_bias, ivec2(pos.z, 0), 0); |
| 148 | + const VEC4_T out_sum = VEC4_T(sum[0], sum[1], sum[2], sum[3]) + bias; |
| 149 | + |
| 150 | + imageStore(t_out, pos, op(out_sum, out_min, out_max)); |
124 | 151 | } |
0 commit comments