Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 31 additions & 33 deletions backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,6 @@ void main() {
return;
}

T mat1[TILE_ROWS][4];
VEC4_T qmat2[4][TILE_TXCOLS];
VEC4_T sums[TILE_ROWS][TILE_TXCOLS];

VEC4_T scales[TILE_TXCOLS];
Expand All @@ -86,12 +84,35 @@ void main() {
for (uint16_t pos = uint16_t(0), txpos = uint16_t(0);
pos < uint16_t(in_sizes.x);
pos += uint16_t(4), txpos += uint16_t(1)) {

T mat1[TILE_ROWS][4];

$if IN_STORAGE == "buffer":
uint in_row_txstride = div4(in_sizes.x);

// Preload input tensor
for (int i = 0; i < TILE_ROWS; i++) {
$if IN_STORAGE == "buffer":
VEC4_T tmp = t_in[(out_row + i) * in_row_txstride + txpos];
mat1[i][0] = tmp.x;
mat1[i][1] = tmp.y;
mat1[i][2] = tmp.z;
mat1[i][3] = tmp.w;
$else:
VEC4_T tmp = VEC4_T(texelFetch(t_in, u16vec3(txpos, out_row + i, 0), 0));
mat1[i][0] = tmp.x;
mat1[i][1] = tmp.y;
mat1[i][2] = tmp.z;
mat1[i][3] = tmp.w;
}

$if WEIGHT_STORAGE == "buffer":
uint qmat2_bufi;
uint weight_row_txstride = div4(weight_sizes.x);

// Preload weight tensor
for (int r = 0; r < 4; r++) {
VEC4_T qmat2[TILE_TXCOLS];
$if QUANT_NBITS == 4:
$for c in range(0, TILE_TXCOLS, 2):
$if WEIGHT_STORAGE == "buffer":
Expand All @@ -101,44 +122,21 @@ void main() {
const uvec4 packed_weight_tex = texelFetch(
t_weight, u16vec2(weight_txcol + ${c}, pos + r), 0);

qmat2[r][${c}] = (VEC4_T((packed_weight_tex & 0xF0) >> 4) - 8.0);
qmat2[r][${c + 1}] = (VEC4_T(packed_weight_tex & 0x0F) - 8.0);
qmat2[${c}] = (VEC4_T((packed_weight_tex & 0xF0) >> 4) - 8.0);
qmat2[${c + 1}] = (VEC4_T(packed_weight_tex & 0x0F) - 8.0);
$else:
$for c in range(TILE_TXCOLS):
$if WEIGHT_STORAGE == "buffer":
qmat2_bufi = (pos + r) * weight_row_txstride + out_txcol;
qmat2[r][${c}] = t_weight[qmat2_bufi + ${c}];
qmat2[${c}] = t_weight[qmat2_bufi + ${c}];
$else:
qmat2[r][${c}] = VEC4_T(
qmat2[${c}] = VEC4_T(
texelFetch(t_weight, u16vec2(out_txcol + ${c}, pos + r), 0));
}

$if IN_STORAGE == "buffer":
uint in_row_txstride = div4(in_sizes.x);

// Preload input tensor
for (int i = 0; i < TILE_ROWS; i++) {
$if IN_STORAGE == "buffer":
VEC4_T tmp = t_in[(out_row + i) * in_row_txstride + txpos];
mat1[i][0] = tmp.x;
mat1[i][1] = tmp.y;
mat1[i][2] = tmp.z;
mat1[i][3] = tmp.w;
$else:
VEC4_T tmp = VEC4_T(texelFetch(t_in, u16vec3(txpos, out_row + i, 0), 0));
mat1[i][0] = tmp.x;
mat1[i][1] = tmp.y;
mat1[i][2] = tmp.z;
mat1[i][3] = tmp.w;
}

// Accumulate output
for (int r = 0; r < TILE_ROWS; ++r) {
$for c in range(TILE_TXCOLS):
sums[r][${c}] += mat1[r][0] * qmat2[0][${c}] +
mat1[r][1] * qmat2[1][${c}] +
mat1[r][2] * qmat2[2][${c}] +
mat1[r][3] * qmat2[3][${c}];
for (int tr = 0; tr < TILE_ROWS; ++tr) {
$for c in range(TILE_TXCOLS):
sums[tr][${c}] += qmat2[${c}] * mat1[tr][r];
}
}
}

Expand Down
Loading