Skip to content

Commit c1910fe

Browse files
committed
Add fp16 to conv2d pw s1p0
1 parent 7f4345d commit c1910fe

File tree

1 file changed

+49
-44
lines changed

1 file changed

+49
-44
lines changed

backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl

Lines changed: 49 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,12 @@
1212

1313
#define PRECISION ${PRECISION}
1414

15-
#define VEC4_T ${texel_type(DTYPE)}
15+
$if DTYPE == "half":
16+
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
17+
#define VEC4_T f16vec4
18+
$else:
19+
#define VEC4_T ${texel_type(DTYPE)}
20+
1621

1722
#define op(X, A, B) ${OPERATOR}
1823

@@ -65,72 +70,72 @@ void main() {
6570
return;
6671
}
6772

68-
vec4 outputTexel = texelFetch(t_bias, ivec2(threadOutChannel, 0), 0);
73+
VEC4_T outputTexel = VEC4_T(texelFetch(t_bias, ivec2(threadOutChannel, 0), 0));
6974

70-
vec4 inputVec;
71-
vec4 weight1OutputChannelPacked;
72-
vec4 weight2OutputChannelPacked;
73-
vec4 weight3OutputChannelPacked;
74-
vec4 weight4OutputChannelPacked;
75+
VEC4_T inputVec;
76+
VEC4_T weight1OutputChannelPacked;
77+
VEC4_T weight2OutputChannelPacked;
78+
VEC4_T weight3OutputChannelPacked;
79+
VEC4_T weight4OutputChannelPacked;
7580

7681
// By unrolling the loop in sets of 4, this significantly reduces the number of branching instructions
7782
// and enables the compiler to rearrange instructions for more efficient memory retrieval and compute
7883
for (int inputC = 0; inputC < inputChannel; inputC += 1) {
7984

80-
inputVec = texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0);
85+
inputVec = VEC4_T(texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0));
8186

82-
weight1OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0);
83-
weight2OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0);
84-
weight3OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0);
85-
weight4OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0);
87+
weight1OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0));
88+
weight2OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0));
89+
weight3OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0));
90+
weight4OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0));
8691

87-
outputTexel[0] += dot(inputVec, vec4(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0]));
88-
outputTexel[1] += dot(inputVec, vec4(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1]));
89-
outputTexel[2] += dot(inputVec, vec4(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2]));
90-
outputTexel[3] += dot(inputVec, vec4(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3]));
92+
outputTexel[0] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0]));
93+
outputTexel[1] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1]));
94+
outputTexel[2] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2]));
95+
outputTexel[3] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3]));
9196

9297
inputC += 1;
9398

94-
inputVec = texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0);
99+
inputVec = VEC4_T(texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0));
95100

96-
weight1OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0);
97-
weight2OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0);
98-
weight3OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0);
99-
weight4OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0);
101+
weight1OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0));
102+
weight2OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0));
103+
weight3OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0));
104+
weight4OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0));
100105

101-
outputTexel[0] += dot(inputVec, vec4(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0]));
102-
outputTexel[1] += dot(inputVec, vec4(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1]));
103-
outputTexel[2] += dot(inputVec, vec4(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2]));
104-
outputTexel[3] += dot(inputVec, vec4(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3]));
106+
outputTexel[0] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0]));
107+
outputTexel[1] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1]));
108+
outputTexel[2] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2]));
109+
outputTexel[3] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3]));
105110

106111
inputC += 1;
107112

108-
inputVec = texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0);
113+
inputVec = VEC4_T(texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0));
109114

110-
weight1OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0);
111-
weight2OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0);
112-
weight3OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0);
113-
weight4OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0);
115+
weight1OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0));
116+
weight2OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0));
117+
weight3OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0));
118+
weight4OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0));
114119

115-
outputTexel[0] += dot(inputVec, vec4(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0]));
116-
outputTexel[1] += dot(inputVec, vec4(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1]));
117-
outputTexel[2] += dot(inputVec, vec4(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2]));
118-
outputTexel[3] += dot(inputVec, vec4(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3]));
120+
outputTexel[0] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0]));
121+
outputTexel[1] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1]));
122+
outputTexel[2] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2]));
123+
outputTexel[3] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3]));
119124

120125
inputC += 1;
121126

122-
inputVec = texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0);
127+
inputVec = VEC4_T(texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0));
123128

124-
weight1OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0);
125-
weight2OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0);
126-
weight3OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0);
127-
weight4OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0);
129+
weight1OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0));
130+
weight2OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0));
131+
weight3OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0));
132+
weight4OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0));
128133

129-
outputTexel[0] += dot(inputVec, vec4(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0]));
130-
outputTexel[1] += dot(inputVec, vec4(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1]));
131-
outputTexel[2] += dot(inputVec, vec4(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2]));
132-
outputTexel[3] += dot(inputVec, vec4(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3]));
134+
outputTexel[0] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0]));
135+
outputTexel[1] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1]));
136+
outputTexel[2] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2]));
137+
outputTexel[3] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3]));
133138
}
134139

135-
imageStore(t_out, ivec3(xIdx, yIdx, threadOutChannel), op(outputTexel, out_min, out_max));
140+
imageStore(t_out, ivec3(xIdx, yIdx, threadOutChannel), op(vec4(outputTexel), out_min, out_max));
136141
}

0 commit comments

Comments
 (0)