Skip to content

Commit d85c644

Browse files
ikawrakow (Iwan Kawrakow)
and a co-author authored
New IQ2_KT, IQ3_KT and IQ4_KT, V2 (#529)
* New iq4_kt trellis The new trellis generates int8_t values via sum_as_uint8_t[(ka * idx + kb) & 0x3f33f3f3f] - 126. CUDA dequantize works. AVX2 case Ny > 32 works, and we get 273 t/s for L3-8B. PPL is on par or even slightly lower than original QTIP trellis. * Something is not working with the AVX2 dot product * New iq4_kt: CUDA MMVQ * New iq4_kt: CUDA MMQ * For now have only iq4_kt use the new trellis * Fix iq2_kt that got broken along the way * New iq4_kt: AVX2 dot product finally works We get 13.6 t/s vs 8.4 t/s with the f16 trellis and f32 arithmetic. Still somewhat slower than other quants, but no longer pathetic. * New iq4_kt: fix vanilla AVX2 * New iq4_kt: NEON implementation We get very respectable PP-512 = 120 t/s. TG-128 is pathetic at 5.3 t/s, so 20+% slower than the f16 variant. * New iq4_kt: slightly faster NEON * New iq4_kt: slightly faster NEON * New iq4_kt: faster NEON We are now at 9.4 t/s, up from 6.6 t/s for the f16 trellis. * Minor * New iq4_kt trellis: not working Metal implementation * Remove the extra 4 bytes of row meta data that is no longer used * Cleanup * Adding forgottent file * Switching iq2_kt to new trellis - CUDA MMQ * New iq2_kt: CUDA GEMV * New iq2_kt: AVX2 dequantize * New iq2_kt: AVX2 GEMM/GEMV * Adding forgotten file * New iq2_kt: NEON GEMM/GEMV * New iq2_kt: slightly faster NEON GEMM * New iq2_kt: Metal - very slow. It seems Apple Silicon cannot quickly add 4 8-bit ints. Or I don't know how to do it - but I didn't find anything in the Metal Shading Language Specification. So, performance is quite a bit worse than the original trellis. * Add missing break * Trying @louiehelm's multiplier * CPU * iq3_kt: use integer trellis + CUDA dequantize and MMVQ * iq3_kt: MMQ * iq3_kt: AVX2 GEMM * iq3_kt: AVX2 GEMV * The trellis quants now need super-blocks of 256, so we need a check --------- Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent c410cc7 commit d85c644

File tree

16 files changed

+1668
-132
lines changed

16 files changed

+1668
-132
lines changed

ggml/src/ggml-cuda/common.cuh

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -578,6 +578,20 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ2_KT> {
578578
static constexpr int qi = QI4_XS;
579579
};
580580

581+
template<>
582+
struct ggml_cuda_type_traits<GGML_TYPE_IQ3_KT> {
583+
static constexpr int qk = QK_K;
584+
static constexpr int qr = QR4_XS;
585+
static constexpr int qi = QI4_XS;
586+
};
587+
588+
template<>
589+
struct ggml_cuda_type_traits<GGML_TYPE_IQ4_KT> {
590+
static constexpr int qk = QK_K;
591+
static constexpr int qr = QR4_XS;
592+
static constexpr int qi = QI4_XS;
593+
};
594+
581595
template<>
582596
struct ggml_cuda_type_traits<GGML_TYPE_IQ3_K> {
583597
static constexpr int qk = QK_K;

ggml/src/ggml-cuda/convert.cu

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,12 @@ inline __device__ int nearest_int(float fval) {
340340
return (i & 0x007fffff) - 0x00400000;
341341
}
342342

343+
int __device__ __forceinline__ trellis_next_int(uint32_t& val) {
344+
constexpr uint32_t ka = 0xCBAC1FED;
345+
val = ka*val;
346+
return ggml_cuda_dp4a(val & 0x3f3f3f3f, 0x01010101, -126);
347+
}
348+
343349
float __device__ __forceinline__ trellis_next(uint32_t& val) {
344350
constexpr uint32_t ka = 89226354;
345351
constexpr uint32_t kb = 64248484;
@@ -367,9 +373,9 @@ static __global__ void dequantize_block_iq2_kt(const void * __restrict__ vx, dst
367373
dst_t * y = yy + ii*QK_K + 8*ib;
368374
const uint16_t * ql = (const uint16_t *)x[i].ql;
369375
uint32_t idx = ql[ib] + 4096;
370-
const float dl = scale * iq4k_values[((x[i].scales[(ib/4)%4] >> 4*(ib/16)) & 0xf)] * 31.75f * 1.05f;
376+
const float dl = scale * iq4k_values[((x[i].scales[(ib/4)%4] >> 4*(ib/16)) & 0xf)] * 1.05f;
371377
for (int j = 0; j < 8; ++j) {
372-
y[j] = dl * trellis_next(idx);
378+
y[j] = dl * trellis_next_int(idx);
373379
}
374380
}
375381

@@ -388,10 +394,10 @@ static __global__ void dequantize_block_iq3_kt(const void * __restrict__ vx, dst
388394
dst_t * y = yy + ii*QK_K + 8*ib;
389395
const uint16_t * ql = (const uint16_t *)x[i].ql;
390396
uint32_t idx = ql[ib] + 4096;
391-
const float dl = scale * ((x[i].scales[(ib/4)%4] >> 4*(ib/16)) & 0xf) * 31.75f * 1.01f; //1.015f;
397+
const float dl = scale * ((x[i].scales[(ib/4)%4] >> 4*(ib/16)) & 0xf) * 1.01f; //1.015f;
392398
uint8_t mask = 1 << (ib/4);
393399
for (int j = 0; j < 8; ++j) {
394-
y[j] = dl * std::abs(trellis_next(idx)) * (x[i].qh[(8*ib+j)%32] & mask ? -1.f : 1.f);
400+
y[j] = dl * std::abs(trellis_next_int(idx)) * (x[i].qh[(8*ib+j)%32] & mask ? -1.f : 1.f);
395401
}
396402
}
397403

@@ -401,9 +407,8 @@ static __global__ void dequantize_block_iq4_kt(const void * __restrict__ vx, dst
401407
int64_t ii = blockIdx.x;
402408
int64_t row = (QK_K * ii) / n_per_row;
403409
const float * dptr = (const float *)((const char *)vx + row * row_size);
404-
float scale = dptr[0] * 31.75f * 1.01f;
405-
float row_av = dptr[1];
406-
const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2);
410+
float scale = dptr[0] * 1.00f;
411+
const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 1);
407412
const int64_t i = ii - (row*n_per_row)/QK_K;
408413

409414
constexpr int kNumGroups = 64;
@@ -423,8 +428,8 @@ static __global__ void dequantize_block_iq4_kt(const void * __restrict__ vx, dst
423428
int ls = ((shb[ib32] & 0xff) >> 1) - 64;
424429
const float dl = scale * ls;
425430
for (int j = 0; j < 4; ++j) {
426-
y[j+0] = dl * trellis_next(idx1) + row_av;
427-
y[j+4] = dl * trellis_next(idx2) + row_av;
431+
y[j+0] = dl * trellis_next_int(idx1);
432+
y[j+4] = dl * trellis_next_int(idx2);
428433
}
429434
}
430435

ggml/src/ggml-cuda/iqk_mmvq.cu

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,119 @@ __device__ __forceinline__ void vec_dot_iq4_ks_q8_1(
433433
*result += dl * __low2float(bq8_1[ib32].ds) * sumi;
434434
}
435435

436+
__device__ __forceinline__ void vec_dot_iq4_kt_q8_1(
437+
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float * result) {
438+
439+
constexpr uint32_t ka = 0xCBAC1FED;
440+
constexpr uint32_t km = 0x3f3f3f3f;
441+
442+
float scale = *(const float *)vbq;
443+
const block_iq4_kt * bq4 = (const block_iq4_kt *)((const char *)vbq + sizeof(float)) + kbx;
444+
445+
// iqs is 0...28
446+
const int ib32 = iqs/4; // Why iqs/4 ?
447+
const int32_t * q8 = (const int *)bq8_1[ib32].qs;
448+
//const int8_t * q8 = bq8_1[ib32].qs;
449+
const int ls = (bq4->qs[ib32] & 0xff) >> 1;
450+
const float dl = scale * (ls - 64);
451+
const uint32_t idx0 = ((bq4->qs[ib32] & 1) << 15) + 4096;
452+
auto ql = (const uint8_t *)(bq4->qs + 8);
453+
auto qh = ql + 64;
454+
ql += 8*ib32;
455+
qh += 8*(ib32%4);
456+
const int shift1 = 8 - 4*(ib32/4);
457+
int sumi = 0;
458+
for (int j = 0; j < 8; ++j) {
459+
const uint32_t sh = bq4->qs[ib32] >> (8 + 3*j);
460+
uint32_t val = ql[j] + ((qh[j] << shift1) & 0xf00) + ((sh & 7) << 12) + idx0;
461+
int v4 = 0;
462+
for (int k = 0; k < 4; ++k) {
463+
val *= ka;
464+
//int s = val & km;
465+
//sumi += q8[4*j+k] * ggml_cuda_dp4a(s, 0x01010101, -126);
466+
v4 |= (ggml_cuda_dp4a(val & km, 0x01010101, -126) & 0xff) << 8*k;
467+
}
468+
sumi = ggml_cuda_dp4a(v4, q8[j], sumi);
469+
}
470+
*result += dl * __low2float(bq8_1[ib32].ds) * sumi;
471+
}
472+
473+
__device__ __forceinline__ void vec_dot_iq2_kt_q8_1(
474+
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float * result) {
475+
476+
constexpr uint32_t ka = 0xCBAC1FED;
477+
constexpr uint32_t km = 0x3f3f3f3f;
478+
479+
float scale = *(const float *)vbq;
480+
const block_iq2_kt * bq2 = (const block_iq2_kt *)((const char *)vbq + sizeof(float)) + kbx;
481+
482+
// iqs is 0...28
483+
const int ib32 = iqs/4;
484+
const int32_t * q8 = (const int *)bq8_1[ib32].qs;
485+
const int ls = iq4k_values[(bq2->scales[ib32%4] >> 4*(ib32/4)) & 0xf];
486+
const float dl = scale * ls * 1.05f;
487+
auto ql = (const uint16_t *)bq2->ql;
488+
int sumi = 0;
489+
for (int j = 0; j < 4; ++j) {
490+
uint32_t val = ql[4*ib32+j] + 4096;
491+
int v4 = 0;
492+
for (int k = 0; k < 4; ++k) {
493+
val *= ka;
494+
v4 |= (ggml_cuda_dp4a(val & km, 0x01010101, -126) & 0xff) << 8*k;
495+
}
496+
sumi = ggml_cuda_dp4a(v4, q8[2*j+0], sumi);
497+
v4 = 0;
498+
for (int k = 0; k < 4; ++k) {
499+
val *= ka;
500+
v4 |= (ggml_cuda_dp4a(val & km, 0x01010101, -126) & 0xff) << 8*k;
501+
}
502+
sumi = ggml_cuda_dp4a(v4, q8[2*j+1], sumi);
503+
}
504+
*result += dl * __low2float(bq8_1[ib32].ds) * sumi;
505+
}
506+
507+
__device__ __forceinline__ void vec_dot_iq3_kt_q8_1(
508+
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float * result) {
509+
510+
constexpr uint32_t ka = 0xCBAC1FED;
511+
constexpr uint32_t km = 0x3f3f3f3f;
512+
513+
float scale = *(const float *)vbq;
514+
const block_iq3_kt * bq3 = (const block_iq3_kt *)((const char *)vbq + sizeof(float)) + kbx;
515+
516+
// iqs is 0...28
517+
const int ib32 = iqs/4;
518+
const int32_t * q8 = (const int *)bq8_1[ib32].qs;
519+
const int ls = (bq3->scales[ib32%4] >> 4*(ib32/4)) & 0xf;
520+
const float dl = scale * ls * 1.015f;
521+
auto ql = (const uint16_t *)bq3->ql;
522+
uint32_t mask = 0x01010101 << ib32;
523+
const uint32_t * qh = (const uint32_t *)bq3->qh;
524+
int sumi = 0;
525+
for (int j = 0; j < 4; ++j) {
526+
uint32_t val = ql[4*ib32+j] + 4096;
527+
int v4 = 0;
528+
for (int k = 0; k < 4; ++k) {
529+
val *= ka;
530+
int8_t q = std::abs(ggml_cuda_dp4a(val & km, 0x01010101, -126));
531+
v4 |= q << 8*k;
532+
}
533+
uint32_t signs = __vcmpne4(qh[2*j+0] & mask, 0);
534+
v4 = __vsub4(v4 ^ signs, signs);
535+
sumi = ggml_cuda_dp4a(v4, q8[2*j+0], sumi);
536+
v4 = 0;
537+
for (int k = 0; k < 4; ++k) {
538+
val *= ka;
539+
int8_t q = std::abs(ggml_cuda_dp4a(val & km, 0x01010101, -126));
540+
v4 |= q << 8*k;
541+
}
542+
signs = __vcmpne4(qh[2*j+1] & mask, 0);
543+
v4 = __vsub4(v4 ^ signs, signs);
544+
sumi = ggml_cuda_dp4a(v4, q8[2*j+1], sumi);
545+
}
546+
*result += dl * __low2float(bq8_1[ib32].ds) * sumi;
547+
}
548+
436549
#define VDR_IQ4_KSS_Q8_1_MMVQ 4
437550
#define VDR_IQ4_KSS_Q8_1_MMQ 4
438551

@@ -1217,6 +1330,30 @@ void mul_mat_vec_iq4_ks_q8_1_cuda(
12171330
iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ4_KS, VDR_IQ4_KS_Q8_1_MMVQ, vec_dot_iq4_ks_q8_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
12181331
}
12191332

1333+
void mul_mat_vec_iq4_kt_q8_1_cuda(
1334+
const void * vx, const void * vy, float * dst, const char * ids_data,
1335+
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
1336+
const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) {
1337+
1338+
iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ4_KT, VDR_IQ4_KS_Q8_1_MMVQ, vec_dot_iq4_kt_q8_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
1339+
}
1340+
1341+
void mul_mat_vec_iq2_kt_q8_1_cuda(
1342+
const void * vx, const void * vy, float * dst, const char * ids_data,
1343+
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
1344+
const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) {
1345+
1346+
iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ2_KT, VDR_IQ4_KS_Q8_1_MMVQ, vec_dot_iq2_kt_q8_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
1347+
}
1348+
1349+
void mul_mat_vec_iq3_kt_q8_1_cuda(
1350+
const void * vx, const void * vy, float * dst, const char * ids_data,
1351+
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
1352+
const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) {
1353+
1354+
iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ3_KT, VDR_IQ4_KS_Q8_1_MMVQ, vec_dot_iq3_kt_q8_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
1355+
}
1356+
12201357
void mul_mat_vec_iq4_kss_q8_1_cuda(
12211358
const void * vx, const void * vy, float * dst, const char * ids_data,
12221359
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,

ggml/src/ggml-cuda/iqk_mmvq.cuh

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,3 +100,18 @@ void mul_mat_vec_iq1_m_r4_q8_1_cuda(
100100
const void * vx, const void * vy, float * dst, const char * ids_data,
101101
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
102102
const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream);
103+
104+
void mul_mat_vec_iq2_kt_q8_1_cuda(
105+
const void * vx, const void * vy, float * dst, const char * ids_data,
106+
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
107+
const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream);
108+
109+
void mul_mat_vec_iq3_kt_q8_1_cuda(
110+
const void * vx, const void * vy, float * dst, const char * ids_data,
111+
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
112+
const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream);
113+
114+
void mul_mat_vec_iq4_kt_q8_1_cuda(
115+
const void * vx, const void * vy, float * dst, const char * ids_data,
116+
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
117+
const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream);

ggml/src/ggml-cuda/mmq.cu

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,15 @@ void ggml_cuda_op_mul_mat_q(
100100
case GGML_TYPE_IQ4_KS_R4:
101101
mul_mat_q_case<GGML_TYPE_IQ4_KS_R4>(ctx, args, stream);
102102
break;
103+
case GGML_TYPE_IQ4_KT:
104+
mul_mat_q_case<GGML_TYPE_IQ4_KT>(ctx, args, stream);
105+
break;
106+
case GGML_TYPE_IQ2_KT:
107+
mul_mat_q_case<GGML_TYPE_IQ2_KT>(ctx, args, stream);
108+
break;
109+
case GGML_TYPE_IQ3_KT:
110+
mul_mat_q_case<GGML_TYPE_IQ3_KT>(ctx, args, stream);
111+
break;
103112
case GGML_TYPE_IQ5_KS:
104113
mul_mat_q_case<GGML_TYPE_IQ5_KS>(ctx, args, stream);
105114
break;
@@ -172,6 +181,9 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
172181
case GGML_TYPE_IQ4_K:
173182
case GGML_TYPE_IQ5_K:
174183
case GGML_TYPE_IQ6_K:
184+
case GGML_TYPE_IQ2_KT:
185+
case GGML_TYPE_IQ3_KT:
186+
case GGML_TYPE_IQ4_KT:
175187
mmq_supported = true;
176188
break;
177189
default:

0 commit comments

Comments
 (0)