Skip to content

Commit 190e786

Browse files
ikawrakowIwan Kawrakow
andauthored
Quantization improvements (2) (#302)
* iq3_k: slightly better quantization Not much of a difference for most models, but this change avoids what it looks like a catastrophic failure for DeepSeek-Lite (PPL is now 7.041 vs 7.314 on main). * Small improvement for type-1 quants --------- Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent b07a337 commit 190e786

File tree

2 files changed

+141
-54
lines changed

2 files changed

+141
-54
lines changed

ggml/src/ggml-quants.c

Lines changed: 92 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -2199,8 +2199,9 @@ static float make_qkx3_quants(int n, int nmax, const float * restrict x, const f
21992199
float rmin, float rdelta, int nstep, bool use_mad) {
22002200
float min = x[0];
22012201
float max = x[0];
2202-
float sum_w = weights ? weights[0] : x[0]*x[0];
2203-
float sum_x = sum_w * x[0];
2202+
double sum_w = weights ? (double)weights[0] : (double)(x[0]*x[0]);
2203+
double sum_x = sum_w * (double)x[0];
2204+
double sum_x2 = sum_w * (double)x[0] * (double)x[0];
22042205
#ifdef HAVE_BUGGY_APPLE_LINKER
22052206
// use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7
22062207
for (volatile int i = 1; i < n; ++i) {
@@ -2210,8 +2211,9 @@ static float make_qkx3_quants(int n, int nmax, const float * restrict x, const f
22102211
if (x[i] < min) min = x[i];
22112212
if (x[i] > max) max = x[i];
22122213
float w = weights ? weights[i] : x[i]*x[i];
2213-
sum_w += w;
2214-
sum_x += w * x[i];
2214+
sum_w += (double)w;
2215+
sum_x += (double)w * (double)x[i];
2216+
sum_x2 += (double)w * (double)x[i] * (double)x[i];
22152217
}
22162218
if (min > 0) {
22172219
min = 0;
@@ -2223,13 +2225,13 @@ static float make_qkx3_quants(int n, int nmax, const float * restrict x, const f
22232225
}
22242226
float iscale = nmax/(max - min);
22252227
float scale = 1/iscale;
2226-
float best_mad = 0;
2228+
double best_mad = 0;
22272229
for (int i = 0; i < n; ++i) {
22282230
int l = nearest_int(iscale*(x[i] - min));
22292231
L[i] = MAX(0, MIN(nmax, l));
2230-
float diff = scale * L[i] + min - x[i];
2231-
diff = use_mad ? fabsf(diff) : diff*diff;
2232-
float w = weights ? weights[i] : x[i]*x[i];
2232+
double diff = (double)scale * L[i] + (double)min - (double)x[i];
2233+
diff = use_mad ? fabs(diff) : diff*diff;
2234+
double w = weights ? (double)weights[i] : (double)(x[i]*x[i]);
22332235
best_mad += w * diff;
22342236
}
22352237
if (nstep < 1) {
@@ -2238,30 +2240,35 @@ static float make_qkx3_quants(int n, int nmax, const float * restrict x, const f
22382240
}
22392241
for (int is = 0; is <= nstep; ++is) {
22402242
iscale = (rmin + rdelta*is + nmax)/(max - min);
2241-
float sum_l = 0, sum_l2 = 0, sum_xl = 0;
2243+
double sum_l = 0, sum_l2 = 0, sum_xl = 0;
22422244
for (int i = 0; i < n; ++i) {
22432245
int l = nearest_int(iscale*(x[i] - min));
22442246
l = MAX(0, MIN(nmax, l));
22452247
Laux[i] = l;
22462248
float w = weights ? weights[i] : x[i]*x[i];
2247-
sum_l += w*l;
2248-
sum_l2 += w*l*l;
2249-
sum_xl += w*l*x[i];
2249+
sum_l += (double)w*l;
2250+
sum_l2 += (double)w*l*l;
2251+
sum_xl += (double)w*l*(double)x[i];
22502252
}
2251-
float D = sum_w * sum_l2 - sum_l * sum_l;
2253+
double D = sum_w * sum_l2 - sum_l * sum_l;
22522254
if (D > 0) {
2253-
float this_scale = (sum_w * sum_xl - sum_x * sum_l)/D;
2254-
float this_min = (sum_l2 * sum_x - sum_l * sum_xl)/D;
2255+
double this_scale = (sum_w * sum_xl - sum_x * sum_l)/D;
2256+
double this_min = (sum_l2 * sum_x - sum_l * sum_xl)/D;
22552257
if (this_min > 0) {
22562258
this_min = 0;
22572259
this_scale = sum_xl / sum_l2;
22582260
}
2259-
float mad = 0;
2260-
for (int i = 0; i < n; ++i) {
2261-
float diff = this_scale * Laux[i] + this_min - x[i];
2262-
diff = use_mad ? fabsf(diff) : diff*diff;
2263-
float w = weights ? weights[i] : x[i]*x[i];
2264-
mad += w * diff;
2261+
double mad = 0;
2262+
if (use_mad) {
2263+
for (int i = 0; i < n; ++i) {
2264+
double diff = (double)this_scale * Laux[i] + (double)this_min - (double)x[i];
2265+
diff = fabs(diff);
2266+
double w = weights ? (double)weights[i] : (double)(x[i]*x[i]);
2267+
mad += w * diff;
2268+
}
2269+
} else {
2270+
mad = sum_x2 - 2*this_scale*sum_xl - 2*this_min*sum_x + 2*this_scale*this_min*sum_l
2271+
+ this_scale*this_scale*sum_l2 + this_min*this_min*sum_w;
22652272
}
22662273
if (mad < best_mad) {
22672274
for (int i = 0; i < n; ++i) {
@@ -2273,6 +2280,57 @@ static float make_qkx3_quants(int n, int nmax, const float * restrict x, const f
22732280
}
22742281
}
22752282
}
2283+
if (use_mad) {
2284+
*the_min = -min;
2285+
return scale;
2286+
}
2287+
2288+
double sum_l = 0, sum_l2 = 0, sum_xl = 0;
2289+
for (int i = 0; i < n; ++i) {
2290+
int l = L[i];
2291+
double w = weights ? (double)weights[i] : (double)(x[i]*x[i]);
2292+
sum_l += w*l;
2293+
sum_l2 += w*l*l;
2294+
sum_xl += w*l*(double)x[i];
2295+
}
2296+
double best = 2*(double)scale*sum_xl + 2*(double)min*sum_x - 2*(double)scale*(double)min*sum_l
2297+
- (double)scale*(double)scale*sum_l2 - (double)min*(double)min*sum_w;
2298+
int last_j = -1, last_dir = 0;
2299+
for (int itry = 0; itry < nmax*n; ++itry) {
2300+
float gmax = 0;
2301+
int best_j = -1, dir = 0;
2302+
for (int j = 0; j < n; ++j) {
2303+
float g = x[j] - scale*L[j] - min;
2304+
if (g > 0 && L[j] < nmax && g > gmax) {
2305+
gmax = g; best_j = j; dir = 1;
2306+
}
2307+
else if (g < 0 && L[j] > 0 && -g > gmax) {
2308+
gmax = -g; best_j = j; dir = -1;
2309+
}
2310+
}
2311+
if (best_j < 0 || (best_j == last_j && dir == -last_dir)) break;
2312+
double w = weights ? (double)weights[best_j] : (double)(x[best_j]*x[best_j]);
2313+
sum_l += w*dir;
2314+
sum_l2 += w*(2*L[best_j]*dir + 1);
2315+
sum_xl += w*(double)x[best_j]*dir;
2316+
double D = (double)sum_w * sum_l2 - sum_l * sum_l;
2317+
if (D <= 0) break;
2318+
double this_scale = ((double)sum_w * sum_xl - (double)sum_x * sum_l)/D;
2319+
double this_min = (sum_l2 * (double)sum_x - sum_l * sum_xl)/D;
2320+
if (this_min > 0) {
2321+
this_min = 0;
2322+
this_scale = sum_xl / sum_l2;
2323+
}
2324+
if (this_scale < 0) break;
2325+
double score = 2*this_scale*sum_xl + 2*this_min*(double)sum_x - 2*this_scale*this_min*sum_l
2326+
- this_scale*this_scale*sum_l2 - this_min*this_min*(double)sum_w;
2327+
if (score <= best) break;
2328+
best = score;
2329+
scale = this_scale;
2330+
min = this_min;
2331+
L[best_j] += dir;
2332+
last_j = best_j; last_dir = dir;
2333+
}
22762334
*the_min = -min;
22772335
return scale;
22782336
}
@@ -2354,7 +2412,6 @@ static void quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restri
23542412
GGML_ASSERT(quant_weights);
23552413
assert(k % QK_K == 0);
23562414
const int nb = k / QK_K;
2357-
const bool requantize = true;
23582415

23592416
uint8_t L[QK_K];
23602417
uint8_t Laux[16];
@@ -2368,39 +2425,33 @@ static void quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restri
23682425
memset(sw, 0, QK_K/16*sizeof(float));
23692426
float sumx2 = 0;
23702427
for (int j = 0; j < QK_K; ++j) sumx2 += x[j]*x[j];
2371-
float sigma2 = sumx2/QK_K;
2428+
float sigma2 = 0.75f*sumx2/QK_K;
23722429
for (int j = 0; j < QK_K/16; ++j) {
23732430
const float * restrict qw = quant_weights + QK_K * i + 16*j;
23742431
for (int l = 0; l < 16; ++l) weight[l] = qw[l] * sqrtf(sigma2 + x[16*j + l]*x[16*j + l]);
23752432
for (int l = 0; l < QK_K/16; ++l) sw[j] += weight[l];
23762433
scales[j] = make_qkx3_quants(16, 3, x + 16*j, weight, L + 16*j, &mins[j], Laux, -0.9f, 0.05f, 36, false);
23772434
}
23782435

2379-
float dm, mm;
2380-
dm = make_qp_quants(QK_K/16, 15, scales, Ls, sw);
2381-
mm = make_qp_quants(QK_K/16, 15, mins, Lm, sw);
2436+
float dm = make_qp_quants(QK_K/16, 15, scales, Ls, sw);
2437+
float mm = make_qp_quants(QK_K/16, 15, mins, Lm, sw);
23822438

23832439
y[i].d = GGML_FP32_TO_FP16(dm);
23842440
y[i].dmin = GGML_FP32_TO_FP16(mm);
2385-
dm = GGML_FP16_TO_FP32(y[i].d);
2386-
mm = GGML_FP16_TO_FP32(y[i].dmin);
23872441

23882442
for (int j = 0; j < QK_K/16; ++j) {
2389-
y[i].scales[j] = Ls[j] | (Lm[j] << 4);
2390-
}
2391-
2392-
if (requantize) {
2393-
for (int j = 0; j < QK_K/16; ++j) {
2394-
const float d = dm * (y[i].scales[j] & 0xF);
2395-
if (!d) continue;
2396-
const float m = mm * (y[i].scales[j] >> 4);
2397-
for (int ii = 0; ii < 16; ++ii) {
2398-
int l = nearest_int((x[16*j + ii] + m)/d);
2399-
l = MAX(0, MIN(3, l));
2400-
L[16*j + ii] = l;
2401-
}
2443+
float d = dm*Ls[j];
2444+
float m = mm*Lm[j];
2445+
float id = d ? 1/d : 0.f;
2446+
for (int l = 0; l < QK_K/16; ++l) {
2447+
int q = nearest_int((x[16*j + l] + m)*id);
2448+
q = MAX(0, MIN(3, q));
2449+
L[16*j + l] = q;
24022450
}
24032451
}
2452+
for (int j = 0; j < QK_K/16; ++j) {
2453+
y[i].scales[j] = Ls[j] | (Lm[j] << 4);
2454+
}
24042455

24052456
for (int j = 0; j < QK_K; j += 128) {
24062457
for (int l = 0; l < 32; ++l) {

ggml/src/iqk/iqk_quantize.cpp

Lines changed: 49 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1555,12 +1555,13 @@ inline int best_index_iq3nl(const int8_t * values, float x) {
15551555

15561556
static void quantize_row_iq3_k_impl(const float * x, void * vy, int n_per_row, const float * quant_weights) {
15571557

1558-
const int ntry = 5;
1558+
constexpr int ntry = 3;
15591559

15601560
block_iq3_k * y = (block_iq3_k *)vy;
15611561

15621562
float scales[QK_K/16];
15631563
float weight[16];
1564+
uint8_t L[16];
15641565

15651566
const int8_t * shifted_values = iq3nl_values + 8;
15661567

@@ -1620,7 +1621,7 @@ static void quantize_row_iq3_k_impl(const float * x, void * vy, int n_per_row, c
16201621
}
16211622
bool is_shifted = false;
16221623
for (int itry = -ntry; itry <= ntry; ++itry) {
1623-
id = (itry + iq3nl_values[0])/max;
1624+
id = (2*itry + iq3nl_values[0])/max;
16241625
sumqx_p = sumq2_p = 0;
16251626
sumqx_m = sumq2_m = 0;
16261627
for (int j = 0; j < 16; ++j) {
@@ -1641,7 +1642,7 @@ static void quantize_row_iq3_k_impl(const float * x, void * vy, int n_per_row, c
16411642
if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) {
16421643
d = sumqx_m/sumq2_m; best = d * sumqx_m; is_shifted = false;
16431644
}
1644-
id = (itry + shifted_values[0])/max;
1645+
id = (2*itry + shifted_values[0])/max;
16451646
sumqx_p = sumq2_p = 0;
16461647
sumqx_m = sumq2_m = 0;
16471648
for (int j = 0; j < 16; ++j) {
@@ -1663,20 +1664,55 @@ static void quantize_row_iq3_k_impl(const float * x, void * vy, int n_per_row, c
16631664
d = sumqx_m/sumq2_m; best = d * sumqx_m; is_shifted = true;
16641665
}
16651666
}
1666-
if (d) {
1667-
const int8_t * block_values = is_shifted ? shifted_values : iq3nl_values;
1668-
float sumqx = 0, sumq2 = 0;
1669-
id = 1/d;
1667+
if (!d) {
1668+
scales[ib] = 0; continue;
1669+
}
1670+
1671+
const int8_t * block_values = is_shifted ? shifted_values : iq3nl_values;
1672+
float sumqx = 0, sumq2 = 0;
1673+
id = 1/d;
1674+
for (int j = 0; j < 16; ++j) {
1675+
float w = weight[j];
1676+
float al = id*xb[j];
1677+
int l = best_index_iq3nl(block_values, al);
1678+
L[j] = l;
1679+
float q = block_values[l];
1680+
sumqx += w*q*xb[j];
1681+
sumq2 += w*q*q;
1682+
}
1683+
if (sumq2 > 0) d = sumqx/sumq2;
1684+
1685+
float best_d = d;
1686+
for (int iter = 0; iter < 128; ++iter) {
1687+
float gmax = 0;
1688+
int best_j = -1, dir = 0;
16701689
for (int j = 0; j < 16; ++j) {
16711690
float w = weight[j];
1672-
float al = id*xb[j];
1673-
int l = best_index_iq3nl(block_values, al);
1674-
float q = block_values[l];
1675-
sumqx += w*q*xb[j];
1676-
sumq2 += w*q*q;
1691+
float g = d * w * (xb[j] - d*block_values[L[j]]);
1692+
if (g > 0 && L[j] < 7) {
1693+
if (g > gmax) {
1694+
gmax = g; best_j = j; dir = 1;
1695+
}
1696+
}
1697+
else if (g < 0 && L[j] > 0) {
1698+
if (-g > gmax) {
1699+
gmax = -g; best_j = j; dir = -1;
1700+
}
1701+
}
16771702
}
1678-
if (sumq2 > 0) d = sumqx/sumq2;
1703+
if (best_j < 0) break;
1704+
1705+
float w = weight[best_j];
1706+
sumqx += w*xb[best_j]*(block_values[L[best_j]+dir] - block_values[L[best_j]]);
1707+
sumq2 += w*(block_values[L[best_j]+dir]*block_values[L[best_j]+dir] - block_values[L[best_j]]*block_values[L[best_j]]);
1708+
L[best_j] += dir;
1709+
if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
1710+
best_d = sumqx/sumq2; best = best_d*sumqx;
1711+
}
1712+
else if (iter > 8) break;
1713+
16791714
}
1715+
16801716
scales[ib] = d;
16811717

16821718
if (is_shifted) extra |= (1 << ib);

0 commit comments

Comments
 (0)