 
 #define PRECISION ${PRECISION}
 
-#define VEC4_T ${texel_type(DTYPE)}
+$if DTYPE == "half":
+  #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
+  #define VEC4_T f16vec4
+$else:
+  #define VEC4_T ${texel_type(DTYPE)}
 
-#define TILE_SIZE_X uint16_t(${TILE_SIZE_X})
-#define TILE_SIZE_Y uint16_t(${TILE_SIZE_Y})
 
 #define op(X, A, B) ${OPERATOR}
 
@@ -50,119 +52,90 @@ ${layout_declare_spec_const(C, "int", "ngroups", "1")}
  * size is only 1x1, making it easier to re-use loaded texels from t_kernel.
  */
 void main() {
-  const int out_limits_scaled[2] =
-      {(out_limits.x + (TILE_SIZE_X - 1)) / TILE_SIZE_X,
-       (out_limits.y + (TILE_SIZE_Y - 1)) / TILE_SIZE_Y};
 
-  const uint16_t div_by_x = uint16_t(gl_GlobalInvocationID.x / out_limits_scaled[0]);
-  const uint16_t out_pos_xy[2] = {uint16_t(gl_GlobalInvocationID.x % out_limits_scaled[0]), div_by_x};
-  const int out_pos_z = int(gl_GlobalInvocationID.y);
+  int inputAndOutputWidth = out_limits.x;
+  int inputAndOutputHeight = out_limits.y;
+  int outputChannel = out_limits.z*4;
 
-  // If the top left position is out of bounds, then this invocation will have
-  // no work to do.
-  if (out_pos_xy[1] >= out_limits_scaled[1] || out_pos_z >= out_limits.z) {
+  // Divided by 4 because the input channels are packed
+  int inputChannel = in_group_size/4;
+
+  int threadHW = int(gl_GlobalInvocationID.x);
+  int threadOutChannel = int(gl_GlobalInvocationID.y);
+
+  int xIdx = threadHW % inputAndOutputWidth;
+  int yIdx = threadHW / inputAndOutputWidth;
+
+  if (threadHW >= inputAndOutputWidth * inputAndOutputHeight || threadOutChannel >= outputChannel) {
     return;
   }
 
-  // Output position for TILE_SIZE = 2
-  // +--------+--------+
-  // | pos[0] | pos[1] |
-  // +--------+--------+
-  // | pos[2] | pos[3] |
-  // +--------+--------+
-  uint16_t pos[TILE_SIZE_X * TILE_SIZE_Y * 2];
-  for (uint16_t y = uint16_t(0), i = uint16_t(0); y < TILE_SIZE_Y; ++y) {
-    for (uint16_t x = uint16_t(0); x < TILE_SIZE_X; ++x) {
-      pos[i * 2] = out_pos_xy[0] * TILE_SIZE_X + x;
-      pos[i * 2 + 1] = out_pos_xy[1] * TILE_SIZE_Y + y;
-      i++;
-    }
-  }
+  VEC4_T outputTexel = VEC4_T(texelFetch(t_bias, ivec2(threadOutChannel, 0), 0));
 
-  // Final output array where each element is a tensor value.
-  // Tuple of consecutive 4 elements represents a single output texel.
-  float sum[TILE_SIZE_X * TILE_SIZE_Y * 4];
+  VEC4_T inputVec;
+  VEC4_T weight1OutputChannelPacked;
+  VEC4_T weight2OutputChannelPacked;
+  VEC4_T weight3OutputChannelPacked;
+  VEC4_T weight4OutputChannelPacked;
 
-  // Initialize the output array with the bias value
-  for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y * 4; i++) {
-    sum[i] = 0;
-  }
+  // Unrolling the loop in sets of 4 significantly reduces the number of branch instructions
+  // and enables the compiler to rearrange instructions for more efficient memory retrieval and compute
+  for (int inputC = 0; inputC < inputChannel; inputC += 1) {
 
-  int z4 = 0;
-  // Since the kernel is 1x1, we only have to loop over the depth dimension.
-  for (int z = 0; z < in_group_size; z += 4, ++z4) {
-    // During prepacking, the weight tensor has been permuted so that the
-    // channel (IC) dim is along the x-axis, and the batch (OC) dim is along
-    // the z-axis.
-    float kernel_values[4 * 4]; // 4 channels, 4 elements per channel
-
-    // Load kernel values from texels to array
-    [[unroll]] for (int i = 0; i < 4; ++i) {
-      const vec4 k_tex = texelFetch(t_kernel, ivec2(z + i, out_pos_z), 0);
-      kernel_values[i * 4 + 0] = k_tex.x;
-      kernel_values[i * 4 + 1] = k_tex.y;
-      kernel_values[i * 4 + 2] = k_tex.z;
-      kernel_values[i * 4 + 3] = k_tex.w;
-    }
-
-    for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) {
-      const vec4 in_tex = texelFetch(t_in, ivec3(pos[i * 2], pos[i * 2 + 1], z4), 0);
-      // Load the input texel into an array
-      float tex_values[4];
-      tex_values[0] = in_tex.x;
-      tex_values[1] = in_tex.y;
-      tex_values[2] = in_tex.z;
-      tex_values[3] = in_tex.w;
-
-      // For 2x2 tile size algorithm works as follows.
-      // To explain the calculations below, the contents of one in_tex and the
-      // group of 4 texels loaded from t_kernel are shown:
-      //
-      //   in_tex       t_kernel
-      //    -x->       ---x--->
-      //   +---+      +----+----+----+----+
-      // ^ | w |    ^ | D0 | D1 | D2 | D3 |
-      // | +---+    | +----+----+----+----+
-      // | | z |    | | C0 | C1 | C2 | C3 |
-      // z +---+    z +----+----+----+----+
-      // | | y |    | | B0 | B2 | B2 | B3 |
-      // | +---+    | +----+----+----+----+
-      //   | x |      | A0 | A1 | A2 | A3 |
-      //   +---+      +----+----+----+----+
-      //
-      // In the t_kernel graphic, cells sharing the same letter are from
-      // the same batch/output channel index, and the number denotes a unique
-      // channel index. To calculate the output texel, the following
-      // calculation is performed:
-      //
-      //  +---+ +----+   +---+ +----+   +---+ +----+   +---+ +----+
-      //  | x | | D0 |   | y | | D1 |   | z | | D2 |   | w | | D3 |
-      //  +---+ +----+   +---+ +----+   +---+ +----+   +---+ +----+
-      //  | x | | C0 |   | y | | C1 |   | z | | C2 |   | w | | C3 |
-      //  +---+X+----+ + +---+X+----+ + +---+X+----+ + +---+X+----+
-      //  | x | | B0 |   | y | | B1 |   | z | | B2 |   | w | | B3 |
-      //  +---+ +----+   +---+ +----+   +---+ +----+   +---+ +----+
-      //  | x | | A0 |   | y | | A1 |   | z | | A2 |   | w | | A3 |
-      //  +---+ +----+   +---+ +----+   +---+ +----+   +---+ +----+
-      //
-      // which is what is expressed in the following calculations. This is done
-      // for each output position.
-      for (int j = 0; j < 4; ++j) {
-        sum[i * 4 + j] = tex_values[0] * kernel_values[0 + j] + sum[i * 4 + j];
-        sum[i * 4 + j] = tex_values[1] * kernel_values[4 + j] + sum[i * 4 + j];
-        sum[i * 4 + j] = tex_values[2] * kernel_values[8 + j] + sum[i * 4 + j];
-        sum[i * 4 + j] = tex_values[3] * kernel_values[12 + j] + sum[i * 4 + j];
-      }
-    }
-  }
+    inputVec = VEC4_T(texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0));
+
+    weight1OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0));
+    weight2OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0));
+    weight3OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0));
+    weight4OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0));
+
+    outputTexel[0] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0]));
+    outputTexel[1] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1]));
+    outputTexel[2] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2]));
+    outputTexel[3] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3]));
+
+    inputC += 1;
+
+    inputVec = VEC4_T(texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0));
 
-  const vec4 bias = texelFetch(t_bias, ivec2(out_pos_z, 0), 0);
+    weight1OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0));
+    weight2OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0));
+    weight3OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0));
+    weight4OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0));
 
-  for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) {
-    const ivec3 pos_l = ivec3(pos[i * 2], pos[i * 2 + 1], out_pos_z);
-    if (all(lessThan(pos_l.xy, out_limits.xy))) {
-      const vec4 out_sum = vec4(sum[i * 4], sum[i * 4 + 1], sum[i * 4 + 2], sum[i * 4 + 3]);
-      imageStore(t_out, pos_l, op(out_sum + bias, out_min, out_max));
-    }
+    outputTexel[0] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0]));
+    outputTexel[1] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1]));
+    outputTexel[2] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2]));
+    outputTexel[3] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3]));
+
+    inputC += 1;
+
+    inputVec = VEC4_T(texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0));
+
+    weight1OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0));
+    weight2OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0));
+    weight3OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0));
+    weight4OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0));
+
+    outputTexel[0] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0]));
+    outputTexel[1] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1]));
+    outputTexel[2] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2]));
+    outputTexel[3] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3]));
+
+    inputC += 1;
+
+    inputVec = VEC4_T(texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0));
+
+    weight1OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0));
+    weight2OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0));
+    weight3OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0));
+    weight4OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0));
+
+    outputTexel[0] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0]));
+    outputTexel[1] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1]));
+    outputTexel[2] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2]));
+    outputTexel[3] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3]));
   }
+
+  imageStore(t_out, ivec3(xIdx, yIdx, threadOutChannel), op(vec4(outputTexel), out_min, out_max));
 }
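
Note on the preamble change in the first hunk: below is a minimal sketch of the GLSL the $if DTYPE == "half" template branch is expected to generate. It is an illustration only; in particular, the assumption that texel_type(DTYPE) resolves to vec4 on the full-precision path is not shown in this diff.

    // DTYPE == "half": enable explicit fp16 arithmetic types and use half-precision texels
    #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
    #define VEC4_T f16vec4

    // any other DTYPE (assumed expansion): keep the full-precision texel type
    #define VEC4_T vec4

With VEC4_T defined as f16vec4, the accumulator and weight registers in main() become half precision, while texelFetch still returns 32-bit components that are narrowed by the VEC4_T(...) constructor casts in the loop body.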
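
The new indexing flattens the output's spatial positions onto gl_GlobalInvocationID.x and keeps the packed output-channel texel on gl_GlobalInvocationID.y. A quick worked example with hypothetical sizes (these numbers are illustrative, not from the patch):

    // Assume inputAndOutputWidth == 8 and threadHW == 19:
    //   xIdx = 19 % 8 = 3
    //   yIdx = 19 / 8 = 2
    // so this invocation produces the output texel at pixel (3, 2) for the
    // four output channels selected by threadOutChannel.

Invocations whose flat index lands past inputAndOutputWidth * inputAndOutputHeight have no pixel to write and are meant to take the early return at the top of main().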
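
The loop body in the second hunk repeats the same step four times per iteration, advancing inputC each time, so it assumes inputChannel (in_group_size / 4) is a multiple of 4. For reference, here is a rolled-up sketch of the single step being unrolled; it illustrates the same arithmetic with shortened variable names and is not part of the patch.

    for (int inputC = 0; inputC < inputChannel; ++inputC) {
      // One packed input texel: 4 input channels at pixel (xIdx, yIdx).
      VEC4_T inputVec = VEC4_T(texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0));

      // Weight texel j holds the weights from input channel (inputC * 4 + j)
      // to the 4 output channels handled by this invocation.
      VEC4_T w0 = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0));
      VEC4_T w1 = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0));
      VEC4_T w2 = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0));
      VEC4_T w3 = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0));

      // The four dot products are a 4x4 matrix-vector product written out per
      // component: equivalent to outputTexel += mat4(w0, w1, w2, w3) * inputVec
      // in column-major GLSL (f16mat4 for the half-precision variant).
      outputTexel[0] += dot(inputVec, VEC4_T(w0[0], w1[0], w2[0], w3[0]));
      outputTexel[1] += dot(inputVec, VEC4_T(w0[1], w1[1], w2[1], w3[1]));
      outputTexel[2] += dot(inputVec, VEC4_T(w0[2], w1[2], w2[2], w3[2]));
      outputTexel[3] += dot(inputVec, VEC4_T(w0[3], w1[3], w2[3], w3[3]));
    }

Unrolling this step by 4, as the patch does, lets the fetches for several input and weight texels be issued before the dependent dot products and cuts the number of loop-branch instructions, at the cost of a few extra registers.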