Skip to content
23 changes: 10 additions & 13 deletions backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -58,32 +58,28 @@ void main() {
return;
}

// If the top left position is out of bounds, then this invocation will have
// no work to do.
if (gpos.z >= out_limits.z) {
return;
}

// Output position for TILE_SIZE = 2
// +--------+--------+
// | pos[0] | pos[1] |
// +--------+--------+
// | pos[2] | pos[3] |
// +--------+--------+
ivec3 pos[TILE_SIZE_X * TILE_SIZE_Y];
int pos[TILE_SIZE_X * TILE_SIZE_Y * 2];
for (int y = 0, i = 0; y < TILE_SIZE_Y; ++y) {
for (int x = 0; x < TILE_SIZE_X; ++x) {
pos[i] = ivec3(gpos.x * TILE_SIZE_X + x, gpos.y * TILE_SIZE_Y + y, gpos.z);
pos[i * 2] = gpos.x * TILE_SIZE_X + x;
pos[i * 2 + 1] = gpos.y * TILE_SIZE_Y + y;
i++;
}
}

// Compute the index of the input texture that needs to be loaded for each
// output position. Note that negative indices can be produced indicating that
// the top-left element is in a region added by padding.
ivec2 ipos[TILE_SIZE_X * TILE_SIZE_Y];
int ipos[TILE_SIZE_X * TILE_SIZE_Y * 2];
for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) {
ipos[i] = pos[i].xy * stride - padding;
ipos[i * 2] = pos[i * 2] * stride.x - padding.x;
ipos[i * 2 + 1] = pos[i * 2 + 1] * stride.y - padding.y;
}

// Final output array where each element is a tensor value.
Expand Down Expand Up @@ -118,7 +114,7 @@ void main() {
}

for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) {
const vec4 in_tex = texelFetch(t_in, ivec3(ipos[i], z4), 0);
const vec4 in_tex = texelFetch(t_in, ivec3(ipos[i * 2], ipos[i * 2 + 1], z4), 0);
// Load the input texel into an array
float tex_values[4];
tex_values[0] = in_tex.x;
Expand Down Expand Up @@ -169,8 +165,9 @@ void main() {
}

for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) {
if (all(lessThan(pos[i], out_limits.xyz))) {
imageStore(t_out, pos[i], op(vec4(sum[i * 4], sum[i * 4 + 1], sum[i * 4 + 2], sum[i * 4 + 3]), out_min, out_max));
const ivec3 pos_l = ivec3(pos[i * 2], pos[i * 2 + 1], gpos.z);
if (all(lessThan(pos_l, out_limits.xyz))) {
imageStore(t_out, pos_l, op(vec4(sum[i * 4], sum[i * 4 + 1], sum[i * 4 + 2], sum[i * 4 + 3]), out_min, out_max));
}
}
}
Loading