diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl index e63e267a4d7..dd1596cfa35 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl @@ -66,8 +66,6 @@ void main() { return; } - T mat1[TILE_ROWS][4]; - VEC4_T qmat2[4][TILE_TXCOLS]; VEC4_T sums[TILE_ROWS][TILE_TXCOLS]; VEC4_T scales[TILE_TXCOLS]; @@ -86,12 +84,35 @@ void main() { for (uint16_t pos = uint16_t(0), txpos = uint16_t(0); pos < uint16_t(in_sizes.x); pos += uint16_t(4), txpos += uint16_t(1)) { + + T mat1[TILE_ROWS][4]; + + $if IN_STORAGE == "buffer": + uint in_row_txstride = div4(in_sizes.x); + + // Preload input tensor + for (int i = 0; i < TILE_ROWS; i++) { + $if IN_STORAGE == "buffer": + VEC4_T tmp = t_in[(out_row + i) * in_row_txstride + txpos]; + mat1[i][0] = tmp.x; + mat1[i][1] = tmp.y; + mat1[i][2] = tmp.z; + mat1[i][3] = tmp.w; + $else: + VEC4_T tmp = VEC4_T(texelFetch(t_in, u16vec3(txpos, out_row + i, 0), 0)); + mat1[i][0] = tmp.x; + mat1[i][1] = tmp.y; + mat1[i][2] = tmp.z; + mat1[i][3] = tmp.w; + } + $if WEIGHT_STORAGE == "buffer": uint qmat2_bufi; uint weight_row_txstride = div4(weight_sizes.x); // Preload weight tensor for (int r = 0; r < 4; r++) { + VEC4_T qmat2[TILE_TXCOLS]; $if QUANT_NBITS == 4: $for c in range(0, TILE_TXCOLS, 2): $if WEIGHT_STORAGE == "buffer": @@ -101,44 +122,21 @@ void main() { const uvec4 packed_weight_tex = texelFetch( t_weight, u16vec2(weight_txcol + ${c}, pos + r), 0); - qmat2[r][${c}] = (VEC4_T((packed_weight_tex & 0xF0) >> 4) - 8.0); - qmat2[r][${c + 1}] = (VEC4_T(packed_weight_tex & 0x0F) - 8.0); + qmat2[${c}] = (VEC4_T((packed_weight_tex & 0xF0) >> 4) - 8.0); + qmat2[${c + 1}] = (VEC4_T(packed_weight_tex & 0x0F) - 8.0); $else: $for c in range(TILE_TXCOLS): $if WEIGHT_STORAGE == "buffer": qmat2_bufi = (pos + r) * weight_row_txstride + out_txcol; - qmat2[r][${c}] = t_weight[qmat2_bufi + ${c}]; + qmat2[${c}] = t_weight[qmat2_bufi + ${c}]; $else: - qmat2[r][${c}] = VEC4_T( + qmat2[${c}] = VEC4_T( texelFetch(t_weight, u16vec2(out_txcol + ${c}, pos + r), 0)); - } - - $if IN_STORAGE == "buffer": - uint in_row_txstride = div4(in_sizes.x); - // Preload input tensor - for (int i = 0; i < TILE_ROWS; i++) { - $if IN_STORAGE == "buffer": - VEC4_T tmp = t_in[(out_row + i) * in_row_txstride + txpos]; - mat1[i][0] = tmp.x; - mat1[i][1] = tmp.y; - mat1[i][2] = tmp.z; - mat1[i][3] = tmp.w; - $else: - VEC4_T tmp = VEC4_T(texelFetch(t_in, u16vec3(txpos, out_row + i, 0), 0)); - mat1[i][0] = tmp.x; - mat1[i][1] = tmp.y; - mat1[i][2] = tmp.z; - mat1[i][3] = tmp.w; - } - - // Accumulate output - for (int r = 0; r < TILE_ROWS; ++r) { - $for c in range(TILE_TXCOLS): - sums[r][${c}] += mat1[r][0] * qmat2[0][${c}] + - mat1[r][1] * qmat2[1][${c}] + - mat1[r][2] * qmat2[2][${c}] + - mat1[r][3] * qmat2[3][${c}]; + for (int tr = 0; tr < TILE_ROWS; ++tr) { + $for c in range(TILE_TXCOLS): + sums[tr][${c}] += qmat2[${c}] * mat1[tr][r]; + } } }