Skip to content
19 changes: 10 additions & 9 deletions backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,13 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
void main() {
const int out_limits_scaled[2] = {out_limits.x + (TILE_SIZE_X - 1) * TILE_SIZE_X, out_limits.y + (TILE_SIZE_Y - 1) * TILE_SIZE_Y};

const int div_by_x = int(gl_GlobalInvocationID.x / out_limits_scaled[0]);
const int out_pos[3] = {int(gl_GlobalInvocationID.x % out_limits_scaled[0]), div_by_x, int(gl_GlobalInvocationID.y)};
const uint16_t div_by_x = uint16_t(gl_GlobalInvocationID.x / out_limits_scaled[0]);
const uint16_t out_pos_xy[2] = {uint16_t(gl_GlobalInvocationID.x % out_limits_scaled[0]), div_by_x};
const int out_pos_z = int(gl_GlobalInvocationID.y);

// If the top left position is out of bounds, then this invocation will have
// no work to do.
if (out_pos[1] >= out_limits_scaled[1] || out_pos[2] >= out_limits.z) {
if (out_pos_xy[1] >= out_limits_scaled[1] || out_pos_z >= out_limits.z) {
return;
}

Expand All @@ -68,8 +69,8 @@ void main() {
uint16_t pos[TILE_SIZE_X * TILE_SIZE_Y * 2];
for (uint16_t y = uint16_t(0), i = uint16_t(0); y < TILE_SIZE_Y; ++y) {
for (uint16_t x = uint16_t(0); x < TILE_SIZE_X; ++x) {
pos[i * 2] = uint16_t(out_pos[0]) * TILE_SIZE_X + x;
pos[i * 2 + 1] = uint16_t(out_pos[1]) * TILE_SIZE_Y + y;
pos[i * 2] = out_pos_xy[0] * TILE_SIZE_X + x;
pos[i * 2 + 1] = out_pos_xy[1] * TILE_SIZE_Y + y;
i++;
}
}
Expand All @@ -78,7 +79,7 @@ void main() {
// Tuple of consecutive 4 elements represents a single output texel.
float sum[TILE_SIZE_X * TILE_SIZE_Y * 4];

const vec4 bias = texelFetch(t_bias, ivec2(out_pos[2], 0), 0);
const vec4 bias = texelFetch(t_bias, ivec2(out_pos_z, 0), 0);

// Initialize the output array with the bias value
for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y * 4; i += 4) {
Expand All @@ -98,7 +99,7 @@ void main() {

// Load kernel values from texels to array
[[unroll]] for (int i = 0; i < 4; ++i) {
const vec4 k_tex = texelFetch(t_kernel, ivec2(z + i, out_pos[2]), 0);
const vec4 k_tex = texelFetch(t_kernel, ivec2(z + i, out_pos_z), 0);
kernel_values[i * 4 + 0] = k_tex.x;
kernel_values[i * 4 + 1] = k_tex.y;
kernel_values[i * 4 + 2] = k_tex.z;
Expand Down Expand Up @@ -157,8 +158,8 @@ void main() {
}

for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) {
const ivec3 pos_l = ivec3(pos[i * 2], pos[i * 2 + 1], out_pos[2]);
if (all(lessThan(pos_l, out_limits.xyz))) {
const ivec3 pos_l = ivec3(pos[i * 2], pos[i * 2 + 1], out_pos_z);
if (all(lessThan(pos_l.xy, out_limits.xy))) {
imageStore(t_out, pos_l, op(vec4(sum[i * 4], sum[i * 4 + 1], sum[i * 4 + 2], sum[i * 4 + 3]), out_min, out_max));
}
}
Expand Down
Loading