|
12 | 12 |
|
13 | 13 | #define PRECISION ${PRECISION} |
14 | 14 |
|
15 | | -#define VEC4_T ${texel_type(DTYPE)} |
| 15 | +$if DTYPE == "half": |
| 16 | + #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require |
| 17 | + #define VEC4_T f16vec4 |
| 18 | +$else: |
| 19 | + #define VEC4_T ${texel_type(DTYPE)} |
| 20 | + |
16 | 21 |
|
17 | 22 | #define op(X, A, B) ${OPERATOR} |
18 | 23 |
|
@@ -65,72 +70,72 @@ void main() { |
65 | 70 | return; |
66 | 71 | } |
67 | 72 |
|
68 | | - vec4 outputTexel = texelFetch(t_bias, ivec2(threadOutChannel, 0), 0); |
| 73 | + VEC4_T outputTexel = VEC4_T(texelFetch(t_bias, ivec2(threadOutChannel, 0), 0)); |
69 | 74 |
|
70 | | - vec4 inputVec; |
71 | | - vec4 weight1OutputChannelPacked; |
72 | | - vec4 weight2OutputChannelPacked; |
73 | | - vec4 weight3OutputChannelPacked; |
74 | | - vec4 weight4OutputChannelPacked; |
| 75 | + VEC4_T inputVec; |
| 76 | + VEC4_T weight1OutputChannelPacked; |
| 77 | + VEC4_T weight2OutputChannelPacked; |
| 78 | + VEC4_T weight3OutputChannelPacked; |
| 79 | + VEC4_T weight4OutputChannelPacked; |
75 | 80 |
|
76 | 81 | // By unrolling the loop in sets of 4, this significantly reduces the number of branching instructions |
77 | 82 | // and enables the compiler to rearrange instructions for more efficient memory retrieval and compute |
78 | 83 | for (int inputC = 0; inputC < inputChannel; inputC += 1) { |
79 | 84 |
|
80 | | - inputVec = texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0); |
| 85 | + inputVec = VEC4_T(texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0)); |
81 | 86 |
|
82 | | - weight1OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0); |
83 | | - weight2OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0); |
84 | | - weight3OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0); |
85 | | - weight4OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0); |
| 87 | + weight1OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0)); |
| 88 | + weight2OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0)); |
| 89 | + weight3OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0)); |
| 90 | + weight4OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0)); |
86 | 91 |
|
87 | | - outputTexel[0] += dot(inputVec, vec4(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0])); |
88 | | - outputTexel[1] += dot(inputVec, vec4(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1])); |
89 | | - outputTexel[2] += dot(inputVec, vec4(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2])); |
90 | | - outputTexel[3] += dot(inputVec, vec4(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3])); |
| 92 | + outputTexel[0] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0])); |
| 93 | + outputTexel[1] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1])); |
| 94 | + outputTexel[2] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2])); |
| 95 | + outputTexel[3] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3])); |
91 | 96 |
|
92 | 97 | inputC += 1; |
93 | 98 |
|
94 | | - inputVec = texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0); |
| 99 | + inputVec = VEC4_T(texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0)); |
95 | 100 |
|
96 | | - weight1OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0); |
97 | | - weight2OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0); |
98 | | - weight3OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0); |
99 | | - weight4OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0); |
| 101 | + weight1OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0)); |
| 102 | + weight2OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0)); |
| 103 | + weight3OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0)); |
| 104 | + weight4OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0)); |
100 | 105 |
|
101 | | - outputTexel[0] += dot(inputVec, vec4(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0])); |
102 | | - outputTexel[1] += dot(inputVec, vec4(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1])); |
103 | | - outputTexel[2] += dot(inputVec, vec4(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2])); |
104 | | - outputTexel[3] += dot(inputVec, vec4(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3])); |
| 106 | + outputTexel[0] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0])); |
| 107 | + outputTexel[1] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1])); |
| 108 | + outputTexel[2] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2])); |
| 109 | + outputTexel[3] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3])); |
105 | 110 |
|
106 | 111 | inputC += 1; |
107 | 112 |
|
108 | | - inputVec = texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0); |
| 113 | + inputVec = VEC4_T(texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0)); |
109 | 114 |
|
110 | | - weight1OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0); |
111 | | - weight2OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0); |
112 | | - weight3OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0); |
113 | | - weight4OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0); |
| 115 | + weight1OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0)); |
| 116 | + weight2OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0)); |
| 117 | + weight3OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0)); |
| 118 | + weight4OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0)); |
114 | 119 |
|
115 | | - outputTexel[0] += dot(inputVec, vec4(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0])); |
116 | | - outputTexel[1] += dot(inputVec, vec4(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1])); |
117 | | - outputTexel[2] += dot(inputVec, vec4(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2])); |
118 | | - outputTexel[3] += dot(inputVec, vec4(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3])); |
| 120 | + outputTexel[0] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0])); |
| 121 | + outputTexel[1] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1])); |
| 122 | + outputTexel[2] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2])); |
| 123 | + outputTexel[3] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3])); |
119 | 124 |
|
120 | 125 | inputC += 1; |
121 | 126 |
|
122 | | - inputVec = texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0); |
| 127 | + inputVec = VEC4_T(texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0)); |
123 | 128 |
|
124 | | - weight1OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0); |
125 | | - weight2OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0); |
126 | | - weight3OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0); |
127 | | - weight4OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0); |
| 129 | + weight1OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0)); |
| 130 | + weight2OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0)); |
| 131 | + weight3OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0)); |
| 132 | + weight4OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0)); |
128 | 133 |
|
129 | | - outputTexel[0] += dot(inputVec, vec4(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0])); |
130 | | - outputTexel[1] += dot(inputVec, vec4(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1])); |
131 | | - outputTexel[2] += dot(inputVec, vec4(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2])); |
132 | | - outputTexel[3] += dot(inputVec, vec4(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3])); |
| 134 | + outputTexel[0] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0])); |
| 135 | + outputTexel[1] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1])); |
| 136 | + outputTexel[2] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2])); |
| 137 | + outputTexel[3] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3])); |
133 | 138 | } |
134 | 139 |
|
135 | | - imageStore(t_out, ivec3(xIdx, yIdx, threadOutChannel), op(outputTexel, out_min, out_max)); |
| 140 | + imageStore(t_out, ivec3(xIdx, yIdx, threadOutChannel), op(vec4(outputTexel), out_min, out_max)); |
136 | 141 | } |
0 commit comments