Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 35 additions & 8 deletions backends/vulkan/runtime/graph/ops/glsl/conv2d.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

// Texel vector type used for values read with texelFetch and for the value
// passed to imageStore; main() reads its .x/.y/.z/.w components and builds it
// from four scalars. The concrete type is a placeholder keyed on DTYPE
// (NOTE(review): presumably expanded by the shader codegen before
// compilation -- confirm).
#define VEC4_T ${texel_type(DTYPE)}

// Scalar type of a single texel component; main() uses it for the unpacked
// per-component arrays and the scalar accumulators. Also a placeholder keyed
// on DTYPE.
#define T ${texel_component_type(DTYPE)}

// Post-processing expression applied to the final value before it is stored;
// main() invokes it as op(value, out_min, out_max). The body is supplied by
// the OPERATOR placeholder, so whether A and B are actually used depends on
// that expression (NOTE(review): looks like an optional clamp -- confirm).
#define op(X, A, B) ${OPERATOR}

#include "indexing_utils.h"
Expand Down Expand Up @@ -72,13 +74,24 @@ void main() {
kstart.y += pos.z * kernel_size.y;

// Perform the convolution by iterating over the overlay region.
VEC4_T sum = texelFetch(t_bias, ivec2(pos.z, 0), 0);
T sum[4];
sum[0] = T(0);
sum[1] = T(0);
sum[2] = T(0);
sum[3] = T(0);

const int ic4 = in_group_size / 4;
for (int z4 = 0; z4 < ic4; ++z4, kstart.x += kernel_size.x * 4) {
for (int y = start.y, ky = kstart.y; y < end.y; y += dilation.y, ++ky) {
for (int x = start.x, kx = kstart.x; x < end.x; x += dilation.x, kx += 4) {
const VEC4_T in_texel = texelFetch(t_in, ivec3(x, y, z4), 0);
const ivec4 kxs = kx + ivec4(0, 1, 2, 3);
const VEC4_T in_texel_v = texelFetch(t_in, ivec3(x, y, z4), 0);

T in_texel[4];
in_texel[0] = in_texel_v.x;
in_texel[1] = in_texel_v.y;
in_texel[2] = in_texel_v.z;
in_texel[3] = in_texel_v.w;


// To explain the calculation below, the contents of in_texel and the
// group of 4 texels loaded from t_kernel are shown:
Expand Down Expand Up @@ -112,13 +125,27 @@ void main() {
//
// which is expressed in the following statements.

sum = fma(in_texel.xxxx, texelFetch(t_kernel, ivec2(kxs.x, ky), 0), sum);
sum = fma(in_texel.yyyy, texelFetch(t_kernel, ivec2(kxs.y, ky), 0), sum);
sum = fma(in_texel.zzzz, texelFetch(t_kernel, ivec2(kxs.z, ky), 0), sum);
sum = fma(in_texel.wwww, texelFetch(t_kernel, ivec2(kxs.w, ky), 0), sum);
T k_tex_arr[16];
for (int kc = 0; kc < 4; kc++) {
const VEC4_T k_tex = texelFetch(t_kernel, ivec2(kx + kc, ky), 0);
k_tex_arr[kc * 4 + 0] = k_tex.x;
k_tex_arr[kc * 4 + 1] = k_tex.y;
k_tex_arr[kc * 4 + 2] = k_tex.z;
k_tex_arr[kc * 4 + 3] = k_tex.w;
}

for (int sc = 0; sc < 4; sc++) {
sum[0] = fma(in_texel[sc], k_tex_arr[sc * 4 + 0], sum[0]);
sum[1] = fma(in_texel[sc], k_tex_arr[sc * 4 + 1], sum[1]);
sum[2] = fma(in_texel[sc], k_tex_arr[sc * 4 + 2], sum[2]);
sum[3] = fma(in_texel[sc], k_tex_arr[sc * 4 + 3], sum[3]);
}
}
}
}

imageStore(t_out, pos, op(sum, out_min, out_max));
const VEC4_T bias = texelFetch(t_bias, ivec2(pos.z, 0), 0);
const VEC4_T out_sum = VEC4_T(sum[0], sum[1], sum[2], sum[3]) + bias;

imageStore(t_out, pos, op(out_sum, out_min, out_max));
}
Loading