diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl index db627681a3a..204352656c9 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl @@ -64,6 +64,12 @@ void main() { T sums[TILE_ROWS * TILE_TXCOLS * 4]; + $if QUANT_NBITS == 4: + // Running sum of mat1 elements for each tile row, kept so that the -8 bias can be applied later. + T mat1_accum[TILE_ROWS]; + $for r in range(TILE_ROWS): + mat1_accum[${r}] = T(0.0); + for (int r = 0; r < TILE_ROWS; ++r) { $for c in range(TILE_TXCOLS): $for j in range(4): @@ -86,6 +92,10 @@ void main() { VEC4_T mat1_vec4 = VEC4_T(texelFetch(t_in, ivec3(txpos, out_row + i, 0), 0)); $for j in range(4): mat1[i * 4 + ${j}] = mat1_vec4[${j}]; + + $if QUANT_NBITS == 4: + // Accumulate the sum of mat1 elements; it is multiplied by -8 later to account for converting the 4-bit data to signed values. + mat1_accum[i] += mat1[i * 4 + 0] + mat1[i * 4 + 1] + mat1[i * 4 + 2] + mat1[i * 4 + 3]; } $if WEIGHT_STORAGE == "buffer": @@ -109,13 +119,13 @@ void main() { packed_weight_tex = texelFetch( t_weight, ivec2(weight_txcol + ${c}, pos + r), 0); - qmat2_vec4 = (VEC4_T(packed_weight_tex >> 4) - 8.0); + qmat2_vec4 = VEC4_T(packed_weight_tex >> 4); qmat2[${c} * 4 * TILE_TXCOLS + 0] = qmat2_vec4.x; qmat2[${c} * 4 * TILE_TXCOLS + 1] = qmat2_vec4.y; qmat2[${c} * 4 * TILE_TXCOLS + 2] = qmat2_vec4.z; qmat2[${c} * 4 * TILE_TXCOLS + 3] = qmat2_vec4.w; - qmat2_vec4 = (VEC4_T(packed_weight_tex & 0x0F) - 8.0); + qmat2_vec4 = VEC4_T(packed_weight_tex & 0x0F); qmat2[${c} * 4 * TILE_TXCOLS + 4] = qmat2_vec4.x; qmat2[${c} * 4 * TILE_TXCOLS + 5] = qmat2_vec4.y; qmat2[${c} * 4 * TILE_TXCOLS + 6] = qmat2_vec4.z; @@ -156,10 +166,16 @@ void main() { for (int r = 0; r < TILE_ROWS; ++r) { VEC4_T scaled_sums; $for c in range(TILE_TXCOLS): - scaled_sums.x = sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 0] * scales[${c}].x; - scaled_sums.y = 
sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 1] * scales[${c}].y; - scaled_sums.z = sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 2] * scales[${c}].z; - scaled_sums.w = sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 3] * scales[${c}].w; + $if QUANT_NBITS == 4: + scaled_sums.x = (sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 0] + mat1_accum[r] * -8.0) * scales[${c}].x; + scaled_sums.y = (sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 1] + mat1_accum[r] * -8.0) * scales[${c}].y; + scaled_sums.z = (sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 2] + mat1_accum[r] * -8.0) * scales[${c}].z; + scaled_sums.w = (sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 3] + mat1_accum[r] * -8.0) * scales[${c}].w; + $else: + scaled_sums.x = sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 0] * scales[${c}].x; + scaled_sums.y = sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 1] * scales[${c}].y; + scaled_sums.z = sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 2] * scales[${c}].z; + scaled_sums.w = sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 3] * scales[${c}].w; $if OUT_STORAGE == "buffer": if (out_row + r < out_sizes.y) {