diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl index 88b054e2cb2..d2e6b4688eb 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl @@ -62,36 +62,30 @@ void main() { return; } - VEC4_T sums[TILE_ROWS][TILE_TXCOLS]; + T sums[TILE_ROWS * TILE_TXCOLS * 4]; for (int r = 0; r < TILE_ROWS; ++r) { $for c in range(TILE_TXCOLS): - sums[r][${c}] = VEC4_T(0.0); + $for j in range(4): + sums[r * TILE_TXCOLS * 4 + ${c} * 4 + ${j}] = T(0.0); } + const int in_row_txstride = div4(in_sizes.x); + for (int pos = 0, txpos = 0; - pos < in_sizes.x; + txpos < in_row_txstride; pos += 4, txpos += 1) { - T mat1[TILE_ROWS][4]; - - $if IN_STORAGE == "buffer": - uint in_row_txstride = div4(in_sizes.x); + T mat1[TILE_ROWS * 4]; // Preload input tensor for (int i = 0; i < TILE_ROWS; i++) { $if IN_STORAGE == "buffer": - VEC4_T tmp = t_in[(out_row + i) * in_row_txstride + txpos]; - mat1[i][0] = tmp.x; - mat1[i][1] = tmp.y; - mat1[i][2] = tmp.z; - mat1[i][3] = tmp.w; + VEC4_T mat1_vec4 = t_in[(out_row + i) * in_row_txstride + txpos]; $else: - VEC4_T tmp = VEC4_T(texelFetch(t_in, ivec3(txpos, out_row + i, 0), 0)); - mat1[i][0] = tmp.x; - mat1[i][1] = tmp.y; - mat1[i][2] = tmp.z; - mat1[i][3] = tmp.w; + VEC4_T mat1_vec4 = VEC4_T(texelFetch(t_in, ivec3(txpos, out_row + i, 0), 0)); + $for j in range(4): + mat1[i * 4 + ${j}] = mat1_vec4[${j}]; } $if WEIGHT_STORAGE == "buffer": @@ -100,7 +94,9 @@ void main() { // Preload weight tensor for (int r = 0; r < 4; r++) { - VEC4_T qmat2[TILE_TXCOLS]; + T qmat2[TILE_TXCOLS * 4]; + VEC4_T qmat2_vec4; + $if QUANT_NBITS == 4: $if WEIGHT_STORAGE == "buffer": u8vec4 packed_weight_tex; @@ -115,20 +111,31 @@ void main() { packed_weight_tex = texelFetch( t_weight, ivec2(weight_txcol + ${c}, pos + r), 0); - qmat2[${c}] = (VEC4_T(packed_weight_tex >> 4) - 8.0); - qmat2[${c + 1}] = (VEC4_T(packed_weight_tex & 0x0F) - 8.0); + qmat2_vec4 = (VEC4_T(packed_weight_tex >> 4) - 8.0); + qmat2[${c} * 4 * TILE_TXCOLS + 0] = qmat2_vec4.x; + qmat2[${c} * 4 * TILE_TXCOLS + 1] = qmat2_vec4.y; + qmat2[${c} * 4 * TILE_TXCOLS + 2] = qmat2_vec4.z; + qmat2[${c} * 4 * TILE_TXCOLS + 3] = qmat2_vec4.w; + + qmat2_vec4 = (VEC4_T(packed_weight_tex & 0x0F) - 8.0); + qmat2[${c} * 4 * TILE_TXCOLS + 4] = qmat2_vec4.x; + qmat2[${c} * 4 * TILE_TXCOLS + 5] = qmat2_vec4.y; + qmat2[${c} * 4 * TILE_TXCOLS + 6] = qmat2_vec4.z; + qmat2[${c} * 4 * TILE_TXCOLS + 7] = qmat2_vec4.w; $else: $for c in range(TILE_TXCOLS): $if WEIGHT_STORAGE == "buffer": qmat2_bufi = (pos + r) * weight_row_txstride + out_txcol; - qmat2[${c}] = t_weight[qmat2_bufi + ${c}]; + qmat2_vec4 = t_weight[qmat2_bufi + ${c}]; $else: - qmat2[${c}] = VEC4_T( - texelFetch(t_weight, ivec2(out_txcol + ${c}, pos + r), 0)); + qmat2_vec4 = VEC4_T(texelFetch(t_weight, ivec2(out_txcol + ${c}, pos + r), 0)); + $for j in range(4): + qmat2[${c} * 4 + ${j}] = qmat2_vec4[${j}]; for (int tr = 0; tr < TILE_ROWS; ++tr) { $for c in range(TILE_TXCOLS): - sums[tr][${c}] += qmat2[${c}] * mat1[tr][r]; + $for j in range(4): + sums[tr * TILE_TXCOLS * 4 + ${c} * 4 + ${j}] += qmat2[${c} * 4 + ${j}] * mat1[tr * 4 + r]; } } } @@ -147,16 +154,22 @@ void main() { uint out_row_txstride = div4(out_sizes.x); for (int r = 0; r < TILE_ROWS; ++r) { + VEC4_T scaled_sums; $for c in range(TILE_TXCOLS): + scaled_sums.x = sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 0] * scales[${c}].x; + scaled_sums.y = sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 1] * scales[${c}].y; + scaled_sums.z = sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 2] * scales[${c}].z; + scaled_sums.w = sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 3] * scales[${c}].w; + $if OUT_STORAGE == "buffer": if (out_row + r < out_sizes.y) { out_bufi = (out_row + r) * out_row_txstride + out_txcol; - t_out[out_bufi + ${c}] = sums[r][${c}] * scales[${c}]; + t_out[out_bufi + ${c}] = scaled_sums; } $else: imageStore( t_out, ivec3(out_txcol + ${c}, out_row + r, 0), - sums[r][${c}] * scales[${c}]); + scaled_sums); } }