12 | 12 |
13 | 13 | #define PRECISION ${PRECISION} |
14 | 14 |
15 | | -#define VEC4_T ${texel_type(DTYPE)} |
| 15 | +$if DTYPE == "half": |
| 16 | + #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require |
| 17 | + #define VEC4_T f16vec4 |
| 18 | +$else: |
| 19 | + #define VEC4_T ${texel_type(DTYPE)} |
16 | 20 |
17 | | -#define TILE_SIZE_X uint16_t(${TILE_SIZE_X}) |
18 | | -#define TILE_SIZE_Y uint16_t(${TILE_SIZE_Y}) |
19 | 21 |
20 | 22 | #define op(X, A, B) ${OPERATOR} |
21 | 23 |
@@ -50,119 +52,90 @@ ${layout_declare_spec_const(C, "int", "ngroups", "1")} |
50 | 52 | * size is only 1x1, making it easier to re-use loaded texels from t_kernel. |
51 | 53 | */ |
52 | 54 | void main() { |
53 | | - const int out_limits_scaled[2] = |
54 | | - {(out_limits.x + (TILE_SIZE_X - 1)) / TILE_SIZE_X, |
55 | | - (out_limits.y + (TILE_SIZE_Y - 1)) / TILE_SIZE_Y}; |
56 | 55 |
57 | | - const uint16_t div_by_x = uint16_t(gl_GlobalInvocationID.x / out_limits_scaled[0]); |
58 | | - const uint16_t out_pos_xy[2] = {uint16_t(gl_GlobalInvocationID.x % out_limits_scaled[0]), div_by_x}; |
59 | | - const int out_pos_z = int(gl_GlobalInvocationID.y); |
| 56 | + int inputAndOutputWidth = out_limits.x; |
| 57 | + int inputAndOutputHeight = out_limits.y; |
| 58 | + int outputChannel = out_limits.z*4; |
60 | 59 |
61 | | - // If the top left position is out of bounds, then this invocation will have |
62 | | - // no work to do. |
63 | | - if (out_pos_xy[1] >= out_limits_scaled[1] || out_pos_z >= out_limits.z) { |
| 60 | + // Divide in_group_size by 4 because the input channels are packed 4 per texel |
| 61 | + int inputChannel = in_group_size / 4; |
| 62 | + |
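| | + // One invocation per output texel: gl_GlobalInvocationID.x walks the W*H output |
| | + // positions and gl_GlobalInvocationID.y selects the group of 4 packed output channels. |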
| 63 | + int threadHW = int(gl_GlobalInvocationID.x); |
| 64 | + int threadOutChannel = int(gl_GlobalInvocationID.y); |
| 65 | + |
| 66 | + int xIdx = threadHW % inputAndOutputWidth; |
| 67 | + int yIdx = threadHW / inputAndOutputWidth; |
| 68 | + |
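| | + // Skip invocations that fall outside the output extents. |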
| 69 | + if (threadHW >= inputAndOutputWidth * inputAndOutputHeight || threadOutChannel >= out_limits.z) { |
64 | 70 | return; |
65 | 71 | } |
66 | 72 |
67 | | - // Output position for TILE_SIZE = 2 |
68 | | - // +--------+--------+ |
69 | | - // | pos[0] | pos[1] | |
70 | | - // +--------+--------+ |
71 | | - // | pos[2] | pos[3] | |
72 | | - // +--------+--------+ |
73 | | - uint16_t pos[TILE_SIZE_X * TILE_SIZE_Y * 2]; |
74 | | - for (uint16_t y = uint16_t(0), i = uint16_t(0); y < TILE_SIZE_Y; ++y) { |
75 | | - for (uint16_t x = uint16_t(0); x < TILE_SIZE_X; ++x) { |
76 | | - pos[i * 2] = out_pos_xy[0] * TILE_SIZE_X + x; |
77 | | - pos[i * 2 + 1] = out_pos_xy[1] * TILE_SIZE_Y + y; |
78 | | - i++; |
79 | | - } |
80 | | - } |
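| | + // Seed the accumulator with the bias texel for these 4 output channels. |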
| 73 | + VEC4_T outputTexel = VEC4_T(texelFetch(t_bias, ivec2(threadOutChannel, 0), 0)); |
81 | 74 |
82 | | - // Final output array where each element is a tensor value. |
83 | | - // Tuple of consecutive 4 elements represents a single output texel. |
84 | | - float sum[TILE_SIZE_X * TILE_SIZE_Y * 4]; |
| 75 | + VEC4_T inputVec; |
| 76 | + VEC4_T weight1OutputChannelPacked; |
| 77 | + VEC4_T weight2OutputChannelPacked; |
| 78 | + VEC4_T weight3OutputChannelPacked; |
| 79 | + VEC4_T weight4OutputChannelPacked; |
85 | 80 |
86 | | - // Initialize the output array with the bias value |
87 | | - for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y * 4; i++) { |
88 | | - sum[i] = 0; |
89 | | - } |
| 81 | + // Unrolling the loop in sets of 4 significantly reduces the number of branching instructions |
| 82 | + // and lets the compiler reorder instructions for more efficient memory access and compute. |
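| | + // Note: the manual 4x unroll assumes inputChannel (in_group_size / 4) is a multiple of 4; |
| | + // otherwise the tail iterations read past the valid input/weight range. |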
| 83 | + for (int inputC = 0; inputC < inputChannel; inputC += 1) { |
90 | 84 |
91 | | - int z4 = 0; |
92 | | - // Since the kernel is 1x1, we only have to loop over the depth dimension. |
93 | | - for (int z = 0; z < in_group_size; z += 4, ++z4) { |
94 | | - // During prepacking, the weight tensor has been permuted so that the |
95 | | - // channel (IC) dim is along the x-axis, and the batch (OC) dim is along |
96 | | - // the z-axis. |
97 | | - float kernel_values[4 * 4]; // 4 channels, 4 elements per channel |
98 | | - |
99 | | - // Load kernel values from texels to array |
100 | | - [[unroll]] for (int i = 0; i < 4; ++i) { |
101 | | - const vec4 k_tex = texelFetch(t_kernel, ivec2(z + i, out_pos_z), 0); |
102 | | - kernel_values[i * 4 + 0] = k_tex.x; |
103 | | - kernel_values[i * 4 + 1] = k_tex.y; |
104 | | - kernel_values[i * 4 + 2] = k_tex.z; |
105 | | - kernel_values[i * 4 + 3] = k_tex.w; |
106 | | - } |
107 | | - |
108 | | - for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) { |
109 | | - const vec4 in_tex = texelFetch(t_in, ivec3(pos[i * 2], pos[i * 2 + 1], z4), 0); |
110 | | - // Load the input texel into an array |
111 | | - float tex_values[4]; |
112 | | - tex_values[0] = in_tex.x; |
113 | | - tex_values[1] = in_tex.y; |
114 | | - tex_values[2] = in_tex.z; |
115 | | - tex_values[3] = in_tex.w; |
116 | | - |
117 | | - // For 2x2 tile size algorithm works as follows. |
118 | | - // To explain the calculations below, the contents of one in_tex and the |
119 | | - // group of 4 texels loaded from t_kernel are shown: |
120 | | - // |
121 | | - // in_tex t_kernel |
122 | | - // -x-> ---x---> |
123 | | - // +---+ +----+----+----+----+ |
124 | | - // ^ | w | ^ | D0 | D1 | D2 | D3 | |
125 | | - // | +---+ | +----+----+----+----+ |
126 | | - // | | z | | | C0 | C1 | C2 | C3 | |
127 | | - // z +---+ z +----+----+----+----+ |
128 | | - // | | y | | | B0 | B2 | B2 | B3 | |
129 | | - // | +---+ | +----+----+----+----+ |
130 | | - // | x | | A0 | A1 | A2 | A3 | |
131 | | - // +---+ +----+----+----+----+ |
132 | | - // |
133 | | - // In the t_kernel graphic, cells sharing the same letter are from |
134 | | - // the same batch/output channel index, and the number denotes a unique |
135 | | - // channel index. To calculate the output texel, the following |
136 | | - // calculation is performed: |
137 | | - // |
138 | | - // +---+ +----+ +---+ +----+ +---+ +----+ +---+ +----+ |
139 | | - // | x | | D0 | | y | | D1 | | z | | D2 | | w | | D3 | |
140 | | - // +---+ +----+ +---+ +----+ +---+ +----+ +---+ +----+ |
141 | | - // | x | | C0 | | y | | C1 | | z | | C2 | | w | | C3 | |
142 | | - // +---+X+----+ + +---+X+----+ + +---+X+----+ + +---+X+----+ |
143 | | - // | x | | B0 | | y | | B1 | | z | | B2 | | w | | B3 | |
144 | | - // +---+ +----+ +---+ +----+ +---+ +----+ +---+ +----+ |
145 | | - // | x | | A0 | | y | | A1 | | z | | A2 | | w | | A3 | |
146 | | - // +---+ +----+ +---+ +----+ +---+ +----+ +---+ +----+ |
147 | | - // |
148 | | - // which is what is expressed in the following calculations. This is done |
149 | | - // for each output position. |
150 | | - for (int j = 0; j < 4; ++j) { |
151 | | - sum[i * 4 + j] = tex_values[0] * kernel_values[0 + j] + sum[i * 4 + j]; |
152 | | - sum[i * 4 + j] = tex_values[1] * kernel_values[4 + j] + sum[i * 4 + j]; |
153 | | - sum[i * 4 + j] = tex_values[2] * kernel_values[8 + j] + sum[i * 4 + j]; |
154 | | - sum[i * 4 + j] = tex_values[3] * kernel_values[12 + j] + sum[i * 4 + j]; |
155 | | - } |
156 | | - } |
157 | | - } |
| 85 | + inputVec = VEC4_T(texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0)); |
| 86 | + |
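| | + // t_kernel layout: input channel along x, output-channel texel along y. weightN below |
| | + // holds the weights of input channel inputC * 4 + (N - 1) for the 4 output channels |
| | + // packed in this texel. |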
| 87 | + weight1OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0)); |
| 88 | + weight2OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0)); |
| 89 | + weight3OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0)); |
| 90 | + weight4OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0)); |
| 91 | + |
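| | + // Gather output channel j's weights across the 4 input channels and accumulate its |
| | + // partial sum; together the four dot()s apply one 4x4 weight block to inputVec. |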
| 92 | + outputTexel[0] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0])); |
| 93 | + outputTexel[1] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1])); |
| 94 | + outputTexel[2] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2])); |
| 95 | + outputTexel[3] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3])); |
| 96 | + |
| 97 | + inputC += 1; |
| 98 | + |
| 99 | + inputVec = VEC4_T(texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0)); |
158 | 100 |
159 | | - const vec4 bias = texelFetch(t_bias, ivec2(out_pos_z, 0), 0); |
| 101 | + weight1OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0)); |
| 102 | + weight2OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0)); |
| 103 | + weight3OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0)); |
| 104 | + weight4OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0)); |
160 | 105 |
161 | | - for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) { |
162 | | - const ivec3 pos_l = ivec3(pos[i * 2], pos[i * 2 + 1], out_pos_z); |
163 | | - if (all(lessThan(pos_l.xy, out_limits.xy))) { |
164 | | - const vec4 out_sum = vec4(sum[i * 4], sum[i * 4 + 1], sum[i * 4 + 2], sum[i * 4 + 3]); |
165 | | - imageStore(t_out, pos_l, op(out_sum + bias, out_min, out_max)); |
166 | | - } |
| 106 | + outputTexel[0] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0])); |
| 107 | + outputTexel[1] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1])); |
| 108 | + outputTexel[2] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2])); |
| 109 | + outputTexel[3] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3])); |
| 110 | + |
| 111 | + inputC += 1; |
| 112 | + |
| 113 | + inputVec = VEC4_T(texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0)); |
| 114 | + |
| 115 | + weight1OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0)); |
| 116 | + weight2OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0)); |
| 117 | + weight3OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0)); |
| 118 | + weight4OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0)); |
| 119 | + |
| 120 | + outputTexel[0] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0])); |
| 121 | + outputTexel[1] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1])); |
| 122 | + outputTexel[2] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2])); |
| 123 | + outputTexel[3] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3])); |
| 124 | + |
| 125 | + inputC += 1; |
| 126 | + |
| 127 | + inputVec = VEC4_T(texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0)); |
| 128 | + |
| 129 | + weight1OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0)); |
| 130 | + weight2OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0)); |
| 131 | + weight3OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0)); |
| 132 | + weight4OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0)); |
| 133 | + |
| 134 | + outputTexel[0] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0])); |
| 135 | + outputTexel[1] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1])); |
| 136 | + outputTexel[2] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2])); |
| 137 | + outputTexel[3] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3])); |
167 | 138 | } |
| 139 | + |
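| | + // Apply the templated op() (typically a clamp to [out_min, out_max]) and write the output texel. |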
| 140 | + imageStore(t_out, ivec3(xIdx, yIdx, threadOutChannel), op(vec4(outputTexel), out_min, out_max)); |
168 | 141 | } |