Skip to content
14 changes: 6 additions & 8 deletions backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,9 @@ void main() {
// Tuple of consecutive 4 elements represents a single output texel.
float sum[TILE_SIZE_X * TILE_SIZE_Y * 4];

const vec4 bias = texelFetch(t_bias, ivec2(out_pos_z, 0), 0);

// Initialize the output array with the bias value
for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y * 4; i += 4) {
sum[i] = bias.x;
sum[i + 1] = bias.y;
sum[i + 2] = bias.z;
sum[i + 3] = bias.w;
for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y * 4; i++) {
sum[i] = 0;
}

int z4 = 0;
Expand Down Expand Up @@ -157,10 +152,13 @@ void main() {
}
}

const vec4 bias = texelFetch(t_bias, ivec2(out_pos_z, 0), 0);

for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) {
const ivec3 pos_l = ivec3(pos[i * 2], pos[i * 2 + 1], out_pos_z);
if (all(lessThan(pos_l.xy, out_limits.xy))) {
imageStore(t_out, pos_l, op(vec4(sum[i * 4], sum[i * 4 + 1], sum[i * 4 + 2], sum[i * 4 + 3]), out_min, out_max));
const vec4 out_sum = vec4(sum[i * 4], sum[i * 4 + 1], sum[i * 4 + 2], sum[i * 4 + 3]);
imageStore(t_out, pos_l, op(out_sum + bias, out_min, out_max));
}
}
}
Loading