Skip to content
16 changes: 9 additions & 7 deletions backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@

#version 450 core

#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require

#define PRECISION ${PRECISION}

#define VEC4_T ${texel_type(DTYPE)}

#define TILE_SIZE_X ${TILE_SIZE_X}
#define TILE_SIZE_Y ${TILE_SIZE_Y}
#define TILE_SIZE_X uint16_t(${TILE_SIZE_X})
#define TILE_SIZE_Y uint16_t(${TILE_SIZE_Y})

#define op(X, A, B) ${OPERATOR}

Expand Down Expand Up @@ -63,11 +65,11 @@ void main() {
// +--------+--------+
// | pos[2] | pos[3] |
// +--------+--------+
int pos[TILE_SIZE_X * TILE_SIZE_Y * 2];
for (int y = 0, i = 0; y < TILE_SIZE_Y; ++y) {
for (int x = 0; x < TILE_SIZE_X; ++x) {
pos[i * 2] = out_pos[0] * TILE_SIZE_X + x;
pos[i * 2 + 1] = out_pos[1] * TILE_SIZE_Y + y;
uint16_t pos[TILE_SIZE_X * TILE_SIZE_Y * 2];
for (uint16_t y = uint16_t(0), i = uint16_t(0); y < TILE_SIZE_Y; ++y) {
for (uint16_t x = uint16_t(0); x < TILE_SIZE_X; ++x) {
pos[i * 2] = uint16_t(out_pos[0]) * TILE_SIZE_X + x;
pos[i * 2 + 1] = uint16_t(out_pos[1]) * TILE_SIZE_Y + y;
i++;
}
}
Expand Down
Loading