Skip to content

Commit d210661

Browse files
ikawrakowIwan Kawrakow
andauthored
Improved IQ1_M quantization (#327)
* Much faster and it looks like better iq1_m quantiation * Cleanup * Minor --------- Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent c01449a commit d210661

File tree

1 file changed

+83
-85
lines changed

1 file changed

+83
-85
lines changed

ggml/src/ggml-quants.c

Lines changed: 83 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -14391,6 +14391,8 @@ void iq1m_process_1block(const float * xb, const float * weight, int8_t * L, flo
1439114391
const float x_m[3] = {-1 - IQ1M_DELTA, -IQ1M_DELTA, 1 - IQ1M_DELTA};
1439214392

1439314393
float sumqx[4], sumq2[4];
14394+
float sumw1[IQ1M_BLOCK_SIZE+1], sumw2[IQ1M_BLOCK_SIZE+1];
14395+
float sumx1[IQ1M_BLOCK_SIZE+1], sumx2[IQ1M_BLOCK_SIZE+1];
1439414396

1439514397
const int gindex = iq2_data_index(GGML_TYPE_IQ1_M);
1439614398

@@ -14414,82 +14416,45 @@ void iq1m_process_1block(const float * xb, const float * weight, int8_t * L, flo
1441414416
idx[2*j] = j;
1441514417
}
1441614418
qsort(pairs, block_size, 2*sizeof(float), iq1_sort_helper);
14417-
float best_score = -FLT_MIN, scale = 0.f;
14419+
sumw1[0] = sumw2[0] = sumx1[0] = sumx2[0] = 0;
14420+
for (int j = 0; j < block_size; ++j) {
14421+
int i = idx[2*j];
14422+
if (i < block_size/2) {
14423+
sumw1[j+1] = sumw1[j] + weight[i];
14424+
sumx1[j+1] = sumx1[j] + weight[i]*xb[i];
14425+
sumw2[j+1] = sumw2[j];
14426+
sumx2[j+1] = sumx2[j];
14427+
} else {
14428+
sumw2[j+1] = sumw2[j] + weight[i];
14429+
sumx2[j+1] = sumx2[j] + weight[i]*xb[i];
14430+
sumw1[j+1] = sumw1[j];
14431+
sumx1[j+1] = sumx1[j];
14432+
}
14433+
}
14434+
float best_score = 0, scale = 0.f;
1441814435
int besti1 = -1, besti2 = -1, best_k = -1;
1441914436
// 0: +, +
1442014437
// 1: +, -
1442114438
// 2: -, +
1442214439
// 3: -, -
1442314440
for (int i1 = 0; i1 <= block_size; ++i1) {
1442414441
for (int i2 = i1; i2 <= block_size; ++i2) {
14425-
memset(sumqx, 0, 4*sizeof(float));
14426-
memset(sumq2, 0, 4*sizeof(float));
14427-
for (int j = 0; j < i1; ++j) {
14428-
int i = idx[2*j];
14429-
if (i < block_size/2) {
14430-
sumqx[0] += weight[i]*x_p[0]*xb[i];
14431-
sumqx[1] += weight[i]*x_p[0]*xb[i];
14432-
sumqx[2] += weight[i]*x_m[0]*xb[i];
14433-
sumqx[3] += weight[i]*x_m[0]*xb[i];
14434-
sumq2[0] += weight[i]*x_p[0]*x_p[0];
14435-
sumq2[1] += weight[i]*x_p[0]*x_p[0];
14436-
sumq2[2] += weight[i]*x_m[0]*x_m[0];
14437-
sumq2[3] += weight[i]*x_m[0]*x_m[0];
14438-
} else {
14439-
sumqx[0] += weight[i]*x_p[0]*xb[i];
14440-
sumqx[2] += weight[i]*x_p[0]*xb[i];
14441-
sumqx[1] += weight[i]*x_m[0]*xb[i];
14442-
sumqx[3] += weight[i]*x_m[0]*xb[i];
14443-
sumq2[0] += weight[i]*x_p[0]*x_p[0];
14444-
sumq2[2] += weight[i]*x_p[0]*x_p[0];
14445-
sumq2[1] += weight[i]*x_m[0]*x_m[0];
14446-
sumq2[3] += weight[i]*x_m[0]*x_m[0];
14447-
}
14448-
}
14449-
for (int j = i1; j < i2; ++j) {
14450-
int i = idx[2*j];
14451-
if (i < block_size/2) {
14452-
sumqx[0] += weight[i]*x_p[1]*xb[i];
14453-
sumqx[1] += weight[i]*x_p[1]*xb[i];
14454-
sumqx[2] += weight[i]*x_m[1]*xb[i];
14455-
sumqx[3] += weight[i]*x_m[1]*xb[i];
14456-
sumq2[0] += weight[i]*x_p[1]*x_p[1];
14457-
sumq2[1] += weight[i]*x_p[1]*x_p[1];
14458-
sumq2[2] += weight[i]*x_m[1]*x_m[1];
14459-
sumq2[3] += weight[i]*x_m[1]*x_m[1];
14460-
} else {
14461-
sumqx[0] += weight[i]*x_p[1]*xb[i];
14462-
sumqx[2] += weight[i]*x_p[1]*xb[i];
14463-
sumqx[1] += weight[i]*x_m[1]*xb[i];
14464-
sumqx[3] += weight[i]*x_m[1]*xb[i];
14465-
sumq2[0] += weight[i]*x_p[1]*x_p[1];
14466-
sumq2[2] += weight[i]*x_p[1]*x_p[1];
14467-
sumq2[1] += weight[i]*x_m[1]*x_m[1];
14468-
sumq2[3] += weight[i]*x_m[1]*x_m[1];
14469-
}
14470-
}
14471-
for (int j = i2; j < block_size; ++j) {
14472-
int i = idx[2*j];
14473-
if (i < block_size/2) {
14474-
sumqx[0] += weight[i]*x_p[2]*xb[i];
14475-
sumqx[1] += weight[i]*x_p[2]*xb[i];
14476-
sumqx[2] += weight[i]*x_m[2]*xb[i];
14477-
sumqx[3] += weight[i]*x_m[2]*xb[i];
14478-
sumq2[0] += weight[i]*x_p[2]*x_p[2];
14479-
sumq2[1] += weight[i]*x_p[2]*x_p[2];
14480-
sumq2[2] += weight[i]*x_m[2]*x_m[2];
14481-
sumq2[3] += weight[i]*x_m[2]*x_m[2];
14482-
} else {
14483-
sumqx[0] += weight[i]*x_p[2]*xb[i];
14484-
sumqx[2] += weight[i]*x_p[2]*xb[i];
14485-
sumqx[1] += weight[i]*x_m[2]*xb[i];
14486-
sumqx[3] += weight[i]*x_m[2]*xb[i];
14487-
sumq2[0] += weight[i]*x_p[2]*x_p[2];
14488-
sumq2[2] += weight[i]*x_p[2]*x_p[2];
14489-
sumq2[1] += weight[i]*x_m[2]*x_m[2];
14490-
sumq2[3] += weight[i]*x_m[2]*x_m[2];
14491-
}
14492-
}
14442+
sumqx[0] = (sumx1[i1] - sumx1[0])*x_p[0] + (sumx1[i2] - sumx1[i1])*x_p[1] + (sumx1[block_size]-sumx1[i2])*x_p[2] +
14443+
(sumx2[i1] - sumx2[0])*x_p[0] + (sumx2[i2] - sumx2[i1])*x_p[1] + (sumx2[block_size]-sumx2[i2])*x_p[2];
14444+
sumqx[1] = (sumx1[i1] - sumx1[0])*x_p[0] + (sumx1[i2] - sumx1[i1])*x_p[1] + (sumx1[block_size]-sumx1[i2])*x_p[2] +
14445+
(sumx2[i1] - sumx2[0])*x_m[0] + (sumx2[i2] - sumx2[i1])*x_m[1] + (sumx2[block_size]-sumx2[i2])*x_m[2];
14446+
sumqx[2] = (sumx1[i1] - sumx1[0])*x_m[0] + (sumx1[i2] - sumx1[i1])*x_m[1] + (sumx1[block_size]-sumx1[i2])*x_m[2] +
14447+
(sumx2[i1] - sumx2[0])*x_p[0] + (sumx2[i2] - sumx2[i1])*x_p[1] + (sumx2[block_size]-sumx2[i2])*x_p[2];
14448+
sumqx[3] = (sumx1[i1] - sumx1[0])*x_m[0] + (sumx1[i2] - sumx1[i1])*x_m[1] + (sumx1[block_size]-sumx1[i2])*x_m[2] +
14449+
(sumx2[i1] - sumx2[0])*x_m[0] + (sumx2[i2] - sumx2[i1])*x_m[1] + (sumx2[block_size]-sumx2[i2])*x_m[2];
14450+
sumq2[0] = (sumw1[i1] - sumw1[0])*x_p[0]*x_p[0] + (sumw1[i2] - sumw1[i1])*x_p[1]*x_p[1] + (sumw1[block_size]-sumw1[i2])*x_p[2]*x_p[2] +
14451+
(sumw2[i1] - sumw2[0])*x_p[0]*x_p[0] + (sumw2[i2] - sumw2[i1])*x_p[1]*x_p[1] + (sumw2[block_size]-sumw2[i2])*x_p[2]*x_p[2];
14452+
sumq2[1] = (sumw1[i1] - sumw1[0])*x_p[0]*x_p[0] + (sumw1[i2] - sumw1[i1])*x_p[1]*x_p[1] + (sumw1[block_size]-sumw1[i2])*x_p[2]*x_p[2] +
14453+
(sumw2[i1] - sumw2[0])*x_m[0]*x_m[0] + (sumw2[i2] - sumw2[i1])*x_m[1]*x_m[1] + (sumw2[block_size]-sumw2[i2])*x_m[2]*x_m[2];
14454+
sumq2[2] = (sumw1[i1] - sumw1[0])*x_m[0]*x_m[0] + (sumw1[i2] - sumw1[i1])*x_m[1]*x_m[1] + (sumw1[block_size]-sumw1[i2])*x_m[2]*x_m[2] +
14455+
(sumw2[i1] - sumw2[0])*x_p[0]*x_p[0] + (sumw2[i2] - sumw2[i1])*x_p[1]*x_p[1] + (sumw2[block_size]-sumw2[i2])*x_p[2]*x_p[2];
14456+
sumq2[3] = (sumw1[i1] - sumw1[0])*x_m[0]*x_m[0] + (sumw1[i2] - sumw1[i1])*x_m[1]*x_m[1] + (sumw1[block_size]-sumw1[i2])*x_m[2]*x_m[2] +
14457+
(sumw2[i1] - sumw2[0])*x_m[0]*x_m[0] + (sumw2[i2] - sumw2[i1])*x_m[1]*x_m[1] + (sumw2[block_size]-sumw2[i2])*x_m[2]*x_m[2];
1449314458
for (int k = 0; k < 4; ++k) {
1449414459
if (sumq2[k] > 0 && sumqx[k]*sumqx[k] > best_score*sumq2[k]) {
1449514460
scale = sumqx[k]/sumq2[k]; best_score = scale*sumqx[k];
@@ -14524,19 +14489,34 @@ void iq1m_process_1block(const float * xb, const float * weight, int8_t * L, flo
1452414489
the_index[k] = grid_index;
1452514490
}
1452614491
if (!all_on_grid) {
14527-
float sumqx_f = 0, sumq2_f = 0;
14528-
for (int k = 0; k < block_size/8; ++k) {
14529-
if (k == 0) xx = best_k < 2 ? x_p : x_m;
14530-
else xx = best_k%2 == 0 ? x_p : x_m;
14531-
const int8_t * pg = (const int8_t *)(kgrid_q2xs + the_index[k]);
14532-
for (int j = 0; j < 8; ++j) {
14533-
float w = weight[8*k + j];
14534-
float q = xx[(pg[j] - 1)/2];
14535-
sumqx_f += w*q*xb[8*k+j];
14536-
sumq2_f += w*q*q;
14492+
sumqx[0] = sumqx[1] = sumqx[2] = sumqx[3] = 0;
14493+
sumq2[0] = sumq2[1] = sumq2[2] = sumq2[3] = 0;
14494+
for (int j = 0; j < block_size; ++j) {
14495+
float w = weight[j];
14496+
float qp = x_p[L[j]];
14497+
float qm = x_m[L[j]];
14498+
sumqx[0] += w*xb[j]*qp;
14499+
sumq2[0] += w*qp*qp;
14500+
sumqx[3] += w*xb[j]*qm;
14501+
sumq2[3] += w*qm*qm;
14502+
if (j < 8) {
14503+
sumqx[1] += w*xb[j]*qp;
14504+
sumq2[1] += w*qp*qp;
14505+
sumqx[2] += w*xb[j]*qm;
14506+
sumq2[2] += w*qm*qm;
14507+
} else {
14508+
sumqx[2] += w*xb[j]*qp;
14509+
sumq2[2] += w*qp*qp;
14510+
sumqx[1] += w*xb[j]*qm;
14511+
sumq2[1] += w*qm*qm;
14512+
}
14513+
}
14514+
best_score = 0;
14515+
for (int k = 0; k < 4; ++k) {
14516+
if (sumqx[k] > 0 && sumq2[k] > 0 && sumqx[k]*sumqx[k] > best_score*sumq2[k]) {
14517+
scale = sumqx[k]/sumq2[k]; best_score = scale*sumqx[k]; best_k = k;
1453714518
}
1453814519
}
14539-
if (sumqx_f > 0 && sumq2_f > 0) scale = sumqx_f/sumq2_f;
1454014520
}
1454114521
*the_scale = scale;
1454214522
*the_shift = best_k;
@@ -14570,6 +14550,7 @@ static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy
1457014550
const float x_p[3] = {-1 + IQ1M_DELTA, IQ1M_DELTA, 1 + IQ1M_DELTA};
1457114551
const float x_m[3] = {-1 - IQ1M_DELTA, -IQ1M_DELTA, 1 - IQ1M_DELTA};
1457214552
const uint8_t masks[4] = {0x00, 0x80, 0x08, 0x88};
14553+
float all_sigma2[QK_K/32];
1457314554

1457414555
iq1m_scale_t s;
1457514556
const float * xx;
@@ -14582,11 +14563,18 @@ static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy
1458214563
float max_scale = 0;
1458314564

1458414565
const float * xbl = x + QK_K*ibl;
14585-
float sumx2 = 0;
14586-
for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
14587-
float sigma2 = 2*sumx2/QK_K;
14566+
for (int ib = 0; ib < QK_K/32; ++ib) {
14567+
const float * xb = xbl + 32*ib;
14568+
float sumx2 = 0;
14569+
for (int i = 0; i < 32; ++i) sumx2 += xb[i]*xb[i];
14570+
all_sigma2[ib] = 1.5f*sumx2/32;
14571+
}
14572+
//float sumx2 = 0;
14573+
//for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
14574+
//float sigma2 = 1.5f*sumx2/QK_K;
1458814575

1458914576
for (int ib = 0; ib < QK_K/block_size; ++ib) {
14577+
float sigma2 = all_sigma2[ib/2];
1459014578
const float * xb = xbl + block_size*ib;
1459114579
if (quant_weights) {
1459214580
const float * qw = quant_weights + QK_K*ibl + block_size*ib;
@@ -14595,12 +14583,21 @@ static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy
1459514583
for (int i = 0; i < block_size; ++i) weight[i] = xb[i]*xb[i];
1459614584
}
1459714585
float max = fabsf(xb[0]);
14598-
for (int i = 1; i < block_size; ++i) max = MAX(max, fabsf(xb[i]));
14586+
float sumwx = 0;
14587+
for (int i = 1; i < block_size; ++i) {
14588+
float ax = fabsf(xb[i]);
14589+
max = MAX(max, ax);
14590+
sumwx += weight[i]*ax;
14591+
}
1459914592
if (max < GROUP_MAX_EPS_IQ1_M) {
1460014593
scales[ib] = 0;
1460114594
memset(L, 1, block_size);
1460214595
continue;
1460314596
}
14597+
if (sumwx == 0) {
14598+
// weight is zero everywhere where xb is not zero => ignore
14599+
for (int i = 0; i < block_size; ++i) weight[i] = xb[i]*xb[i];
14600+
}
1460414601

1460514602
int best_k = -1;
1460614603
iq1m_process_1block(xb, weight, L, &scales[ib], index, &best_k, pairs);
@@ -14621,6 +14618,7 @@ static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy
1462114618
float id = 1/d;
1462214619
float sumqx_f = 0, sumq2_f = 0;
1462314620
for (int ib = 0; ib < QK_K/block_size; ++ib) {
14621+
float sigma2 = all_sigma2[ib/2];
1462414622
int l = nearest_int(0.5f*(id*scales[ib+0]-1));
1462514623
l = MAX(0, MIN(7, l));
1462614624
sc[ib/4] |= (l << 3*(ib%4));
@@ -14645,7 +14643,7 @@ static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy
1464514643
}
1464614644
}
1464714645
if (sumq2_f > 0) d = sumqx_f/sumq2_f;
14648-
s.f16 = GGML_FP32_TO_FP16(d*1.1125f); // 1.1125f is another fudge factor. Don't ask me why it is needed.
14646+
s.f16 = GGML_FP32_TO_FP16(d*1.085f); // 1.085f is another fudge factor. Don't ask me why it is needed.
1464914647
sc[0] |= ((s.u16 & 0x000f) << 12);
1465014648
sc[1] |= ((s.u16 & 0x00f0) << 8);
1465114649
sc[2] |= ((s.u16 & 0x0f00) << 4);

0 commit comments

Comments
 (0)