@@ -2199,8 +2199,9 @@ static float make_qkx3_quants(int n, int nmax, const float * restrict x, const f
21992199 float rmin, float rdelta, int nstep, bool use_mad) {
22002200 float min = x[0];
22012201 float max = x[0];
2202- float sum_w = weights ? weights[0] : x[0]*x[0];
2203- float sum_x = sum_w * x[0];
2202+ double sum_w = weights ? (double)weights[0] : (double)(x[0]*x[0]);
2203+ double sum_x = sum_w * (double)x[0];
2204+ double sum_x2 = sum_w * (double)x[0] * (double)x[0];
22042205#ifdef HAVE_BUGGY_APPLE_LINKER
22052206 // use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7
22062207 for (volatile int i = 1; i < n; ++i) {
@@ -2210,8 +2211,9 @@ static float make_qkx3_quants(int n, int nmax, const float * restrict x, const f
22102211 if (x[i] < min) min = x[i];
22112212 if (x[i] > max) max = x[i];
22122213 float w = weights ? weights[i] : x[i]*x[i];
2213- sum_w += w;
2214- sum_x += w * x[i];
2214+ sum_w += (double)w;
2215+ sum_x += (double)w * (double)x[i];
2216+ sum_x2 += (double)w * (double)x[i] * (double)x[i];
22152217 }
22162218 if (min > 0) {
22172219 min = 0;
@@ -2223,13 +2225,13 @@ static float make_qkx3_quants(int n, int nmax, const float * restrict x, const f
22232225 }
22242226 float iscale = nmax/(max - min);
22252227 float scale = 1/iscale;
2226- float best_mad = 0;
2228+ double best_mad = 0;
22272229 for (int i = 0; i < n; ++i) {
22282230 int l = nearest_int(iscale*(x[i] - min));
22292231 L[i] = MAX(0, MIN(nmax, l));
2230- float diff = scale * L[i] + min - x[i];
2231- diff = use_mad ? fabsf (diff) : diff*diff;
2232- float w = weights ? weights[i] : x[i]*x[i];
2232+ double diff = (double) scale * L[i] + (double) min - (double) x[i];
2233+ diff = use_mad ? fabs (diff) : diff*diff;
2234+ double w = weights ? (double) weights[i] : (double)( x[i]*x[i]) ;
22332235 best_mad += w * diff;
22342236 }
22352237 if (nstep < 1) {
@@ -2238,30 +2240,35 @@ static float make_qkx3_quants(int n, int nmax, const float * restrict x, const f
22382240 }
22392241 for (int is = 0; is <= nstep; ++is) {
22402242 iscale = (rmin + rdelta*is + nmax)/(max - min);
2241- float sum_l = 0, sum_l2 = 0, sum_xl = 0;
2243+ double sum_l = 0, sum_l2 = 0, sum_xl = 0;
22422244 for (int i = 0; i < n; ++i) {
22432245 int l = nearest_int(iscale*(x[i] - min));
22442246 l = MAX(0, MIN(nmax, l));
22452247 Laux[i] = l;
22462248 float w = weights ? weights[i] : x[i]*x[i];
2247- sum_l += w*l;
2248- sum_l2 += w*l*l;
2249- sum_xl += w*l*x[i];
2249+ sum_l += (double) w*l;
2250+ sum_l2 += (double) w*l*l;
2251+ sum_xl += (double) w*l*(double) x[i];
22502252 }
2251- float D = sum_w * sum_l2 - sum_l * sum_l;
2253+ double D = sum_w * sum_l2 - sum_l * sum_l;
22522254 if (D > 0) {
2253- float this_scale = (sum_w * sum_xl - sum_x * sum_l)/D;
2254- float this_min = (sum_l2 * sum_x - sum_l * sum_xl)/D;
2255+ double this_scale = (sum_w * sum_xl - sum_x * sum_l)/D;
2256+ double this_min = (sum_l2 * sum_x - sum_l * sum_xl)/D;
22552257 if (this_min > 0) {
22562258 this_min = 0;
22572259 this_scale = sum_xl / sum_l2;
22582260 }
2259- float mad = 0;
2260- for (int i = 0; i < n; ++i) {
2261- float diff = this_scale * Laux[i] + this_min - x[i];
2262- diff = use_mad ? fabsf(diff) : diff*diff;
2263- float w = weights ? weights[i] : x[i]*x[i];
2264- mad += w * diff;
2261+ double mad = 0;
2262+ if (use_mad) {
2263+ for (int i = 0; i < n; ++i) {
2264+ double diff = (double)this_scale * Laux[i] + (double)this_min - (double)x[i];
2265+ diff = fabs(diff);
2266+ double w = weights ? (double)weights[i] : (double)(x[i]*x[i]);
2267+ mad += w * diff;
2268+ }
2269+ } else {
2270+ mad = sum_x2 - 2*this_scale*sum_xl - 2*this_min*sum_x + 2*this_scale*this_min*sum_l
2271+ + this_scale*this_scale*sum_l2 + this_min*this_min*sum_w;
22652272 }
22662273 if (mad < best_mad) {
22672274 for (int i = 0; i < n; ++i) {
@@ -2273,6 +2280,57 @@ static float make_qkx3_quants(int n, int nmax, const float * restrict x, const f
22732280 }
22742281 }
22752282 }
2283+ if (use_mad) {
2284+ *the_min = -min;
2285+ return scale;
2286+ }
2287+
2288+ double sum_l = 0, sum_l2 = 0, sum_xl = 0;
2289+ for (int i = 0; i < n; ++i) {
2290+ int l = L[i];
2291+ double w = weights ? (double)weights[i] : (double)(x[i]*x[i]);
2292+ sum_l += w*l;
2293+ sum_l2 += w*l*l;
2294+ sum_xl += w*l*(double)x[i];
2295+ }
2296+ double best = 2*(double)scale*sum_xl + 2*(double)min*sum_x - 2*(double)scale*(double)min*sum_l
2297+ - (double)scale*(double)scale*sum_l2 - (double)min*(double)min*sum_w;
2298+ int last_j = -1, last_dir = 0;
2299+ for (int itry = 0; itry < nmax*n; ++itry) {
2300+ float gmax = 0;
2301+ int best_j = -1, dir = 0;
2302+ for (int j = 0; j < n; ++j) {
2303+ float g = x[j] - scale*L[j] - min;
2304+ if (g > 0 && L[j] < nmax && g > gmax) {
2305+ gmax = g; best_j = j; dir = 1;
2306+ }
2307+ else if (g < 0 && L[j] > 0 && -g > gmax) {
2308+ gmax = -g; best_j = j; dir = -1;
2309+ }
2310+ }
2311+ if (best_j < 0 || (best_j == last_j && dir == -last_dir)) break;
2312+ double w = weights ? (double)weights[best_j] : (double)(x[best_j]*x[best_j]);
2313+ sum_l += w*dir;
2314+ sum_l2 += w*(2*L[best_j]*dir + 1);
2315+ sum_xl += w*(double)x[best_j]*dir;
2316+ double D = (double)sum_w * sum_l2 - sum_l * sum_l;
2317+ if (D <= 0) break;
2318+ double this_scale = ((double)sum_w * sum_xl - (double)sum_x * sum_l)/D;
2319+ double this_min = (sum_l2 * (double)sum_x - sum_l * sum_xl)/D;
2320+ if (this_min > 0) {
2321+ this_min = 0;
2322+ this_scale = sum_xl / sum_l2;
2323+ }
2324+ if (this_scale < 0) break;
2325+ double score = 2*this_scale*sum_xl + 2*this_min*(double)sum_x - 2*this_scale*this_min*sum_l
2326+ - this_scale*this_scale*sum_l2 - this_min*this_min*(double)sum_w;
2327+ if (score <= best) break;
2328+ best = score;
2329+ scale = this_scale;
2330+ min = this_min;
2331+ L[best_j] += dir;
2332+ last_j = best_j; last_dir = dir;
2333+ }
22762334 *the_min = -min;
22772335 return scale;
22782336}
@@ -2354,7 +2412,6 @@ static void quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restri
23542412 GGML_ASSERT(quant_weights);
23552413 assert(k % QK_K == 0);
23562414 const int nb = k / QK_K;
2357- const bool requantize = true;
23582415
23592416 uint8_t L[QK_K];
23602417 uint8_t Laux[16];
@@ -2368,39 +2425,33 @@ static void quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restri
23682425 memset(sw, 0, QK_K/16*sizeof(float));
23692426 float sumx2 = 0;
23702427 for (int j = 0; j < QK_K; ++j) sumx2 += x[j]*x[j];
2371- float sigma2 = sumx2/QK_K;
2428+ float sigma2 = 0.75f* sumx2/QK_K;
23722429 for (int j = 0; j < QK_K/16; ++j) {
23732430 const float * restrict qw = quant_weights + QK_K * i + 16*j;
23742431 for (int l = 0; l < 16; ++l) weight[l] = qw[l] * sqrtf(sigma2 + x[16*j + l]*x[16*j + l]);
23752432 for (int l = 0; l < QK_K/16; ++l) sw[j] += weight[l];
23762433 scales[j] = make_qkx3_quants(16, 3, x + 16*j, weight, L + 16*j, &mins[j], Laux, -0.9f, 0.05f, 36, false);
23772434 }
23782435
2379- float dm, mm;
2380- dm = make_qp_quants(QK_K/16, 15, scales, Ls, sw);
2381- mm = make_qp_quants(QK_K/16, 15, mins, Lm, sw);
2436+ float dm = make_qp_quants(QK_K/16, 15, scales, Ls, sw);
2437+ float mm = make_qp_quants(QK_K/16, 15, mins, Lm, sw);
23822438
23832439 y[i].d = GGML_FP32_TO_FP16(dm);
23842440 y[i].dmin = GGML_FP32_TO_FP16(mm);
2385- dm = GGML_FP16_TO_FP32(y[i].d);
2386- mm = GGML_FP16_TO_FP32(y[i].dmin);
23872441
23882442 for (int j = 0; j < QK_K/16; ++j) {
2389- y[i].scales[j] = Ls[j] | (Lm[j] << 4);
2390- }
2391-
2392- if (requantize) {
2393- for (int j = 0; j < QK_K/16; ++j) {
2394- const float d = dm * (y[i].scales[j] & 0xF);
2395- if (!d) continue;
2396- const float m = mm * (y[i].scales[j] >> 4);
2397- for (int ii = 0; ii < 16; ++ii) {
2398- int l = nearest_int((x[16*j + ii] + m)/d);
2399- l = MAX(0, MIN(3, l));
2400- L[16*j + ii] = l;
2401- }
2443+ float d = dm*Ls[j];
2444+ float m = mm*Lm[j];
2445+ float id = d ? 1/d : 0.f;
2446+ for (int l = 0; l < QK_K/16; ++l) {
2447+ int q = nearest_int((x[16*j + l] + m)*id);
2448+ q = MAX(0, MIN(3, q));
2449+ L[16*j + l] = q;
24022450 }
24032451 }
2452+ for (int j = 0; j < QK_K/16; ++j) {
2453+ y[i].scales[j] = Ls[j] | (Lm[j] << 4);
2454+ }
24042455
24052456 for (int j = 0; j < QK_K; j += 128) {
24062457 for (int l = 0; l < 32; ++l) {
0 commit comments