@@ -64,6 +64,12 @@ void main() {
6464
6565 T sums[TILE_ROWS * TILE_TXCOLS * 4 ];
6666
67+ $if QUANT_NBITS == 4 :
68+ // accumulate mat1 elements sums so -8 bias can be applied to it later
69+ T mat1_accum[TILE_ROWS];
70+ $for r in range(TILE_ROWS):
71+ mat1_accum[${r}] = T(0.0 );
72+
6773 for (int r = 0 ; r < TILE_ROWS; ++ r) {
6874 $for c in range(TILE_TXCOLS):
6975 $for j in range(4 ):
@@ -86,6 +92,11 @@ void main() {
8692 VEC4_T mat1_vec4 = VEC4_T(texelFetch(t_in, ivec3 (txpos, out_row + i, 0 ), 0 ));
8793 $for j in range(4 ):
8894 mat1[i * 4 + ${j}] = mat1_vec4[${j}];
95+
96+ $if QUANT_NBITS == 4 :
97+ // Apply -8 * mat1 bias here rather then below, to effectively reduce overall number of math operations performed during runtime.
98+ // Accumulate mat1 element sum, this will be multiplied with -8 later for converting 4 bit data to a signed number.
99+ mat1_accum[i] += mat1[i * 4 + 0 ] + mat1[i * 4 + 1 ] + mat1[i * 4 + 2 ] + mat1[i * 4 + 3 ];
89100 }
90101
91102 $if WEIGHT_STORAGE == "buffer ":
@@ -109,13 +120,13 @@ void main() {
109120 packed_weight_tex = texelFetch(
110121 t_weight, ivec2 (weight_txcol + ${c}, pos + r), 0 );
111122
112- qmat2_vec4 = ( VEC4_T(packed_weight_tex >> 4 ) - 8.0 );
123+ qmat2_vec4 = VEC4_T(packed_weight_tex >> 4 );
113124 qmat2[${c} * 4 * TILE_TXCOLS + 0 ] = qmat2_vec4.x;
114125 qmat2[${c} * 4 * TILE_TXCOLS + 1 ] = qmat2_vec4.y;
115126 qmat2[${c} * 4 * TILE_TXCOLS + 2 ] = qmat2_vec4.z;
116127 qmat2[${c} * 4 * TILE_TXCOLS + 3 ] = qmat2_vec4.w;
117128
118- qmat2_vec4 = ( VEC4_T(packed_weight_tex & 0x0F) - 8.0 );
129+ qmat2_vec4 = VEC4_T(packed_weight_tex & 0x0F);
119130 qmat2[${c} * 4 * TILE_TXCOLS + 4 ] = qmat2_vec4.x;
120131 qmat2[${c} * 4 * TILE_TXCOLS + 5 ] = qmat2_vec4.y;
121132 qmat2[${c} * 4 * TILE_TXCOLS + 6 ] = qmat2_vec4.z;
@@ -156,10 +167,16 @@ void main() {
156167 for (int r = 0 ; r < TILE_ROWS; ++ r) {
157168 VEC4_T scaled_sums;
158169 $for c in range(TILE_TXCOLS):
159- scaled_sums.x = sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 0 ] * scales[${c}].x;
160- scaled_sums.y = sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 1 ] * scales[${c}].y;
161- scaled_sums.z = sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 2 ] * scales[${c}].z;
162- scaled_sums.w = sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 3 ] * scales[${c}].w;
170+ $if QUANT_NBITS == 4 :
171+ scaled_sums.x = (sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 0 ] + mat1_accum[r] * - 8.0 ) * scales[${c}].x;
172+ scaled_sums.y = (sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 1 ] + mat1_accum[r] * - 8.0 ) * scales[${c}].y;
173+ scaled_sums.z = (sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 2 ] + mat1_accum[r] * - 8.0 ) * scales[${c}].z;
174+ scaled_sums.w = (sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 3 ] + mat1_accum[r] * - 8.0 ) * scales[${c}].w;
175+ $else :
176+ scaled_sums.x = sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 0 ] * scales[${c}].x;
177+ scaled_sums.y = sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 1 ] * scales[${c}].y;
178+ scaled_sums.z = sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 2 ] * scales[${c}].z;
179+ scaled_sums.w = sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 3 ] * scales[${c}].w;
163180
164181 $if OUT_STORAGE == "buffer ":
165182 if (out_row + r < out_sizes.y) {
0 commit comments