Commit cf25644

Author: Iwan Kawrakow (committed)
iq3_kt: use integer trellis + CUDA dequantize and MMVQ
1 parent: 6408b94

6 files changed: +88 additions, -18 deletions

ggml/src/ggml-cuda/common.cuh (14 additions, 7 deletions)

@@ -578,6 +578,20 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ2_KT> {
     static constexpr int qi = QI4_XS;
 };
 
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ3_KT> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR4_XS;
+    static constexpr int qi = QI4_XS;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ4_KT> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR4_XS;
+    static constexpr int qi = QI4_XS;
+};
+
 template<>
 struct ggml_cuda_type_traits<GGML_TYPE_IQ3_K> {
     static constexpr int qk = QK_K;
@@ -648,13 +662,6 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ3_S> {
     static constexpr int qi = QI3_S;
 };
 
-template<>
-struct ggml_cuda_type_traits<GGML_TYPE_IQ4_KT> {
-    static constexpr int qk = QK_K;
-    static constexpr int qr = QR4_XS;
-    static constexpr int qi = QI4_XS;
-};
-
 //////////////////////
 
 struct ggml_cuda_device_info {
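The new IQ3_KT specialization simply mirrors the existing IQ2_KT entry, reusing the QK_K/QR4_XS/QI4_XS parameters; the IQ4_KT specialization is moved up from further down the file rather than being new. Since the members are static constexpr, the templated CUDA code can read them at compile time. An illustrative usage (not from the commit):

// Illustrative only: how templated MMVQ/dequantize code can consume the traits.
constexpr int iq3_kt_qk = ggml_cuda_type_traits<GGML_TYPE_IQ3_KT>::qk; // values per super-block (QK_K)
constexpr int iq3_kt_qi = ggml_cuda_type_traits<GGML_TYPE_IQ3_KT>::qi; // QI4_XS, same tiling as IQ4_XS
static_assert(iq3_kt_qk == QK_K, "IQ3_KT packs QK_K values per super-block");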

ggml/src/ggml-cuda/convert.cu (2 additions, 2 deletions)

@@ -394,10 +394,10 @@ static __global__ void dequantize_block_iq3_kt(const void * __restrict__ vx, dst
     dst_t * y = yy + ii*QK_K + 8*ib;
     const uint16_t * ql = (const uint16_t *)x[i].ql;
     uint32_t idx = ql[ib] + 4096;
-    const float dl = scale * ((x[i].scales[(ib/4)%4] >> 4*(ib/16)) & 0xf) * 31.75f * 1.01f; //1.015f;
+    const float dl = scale * ((x[i].scales[(ib/4)%4] >> 4*(ib/16)) & 0xf) * 1.01f; //1.015f;
     uint8_t mask = 1 << (ib/4);
     for (int j = 0; j < 8; ++j) {
-        y[j] = dl * std::abs(trellis_next(idx)) * (x[i].qh[(8*ib+j)%32] & mask ? -1.f : 1.f);
+        y[j] = dl * std::abs(trellis_next_int(idx)) * (x[i].qh[(8*ib+j)%32] & mask ? -1.f : 1.f);
     }
 }
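Two things change in dequantize_block_iq3_kt: the per-block scale dl drops the 31.75f factor, and the magnitude now comes from trellis_next_int instead of the old trellis_next. The helper itself is defined outside this hunk; reconstructed from the constants the new vec_dot_iq3_kt_q8_1 kernel uses below (multiplier 0xCBAC1FED, a 6-bit-per-byte mask, an offset of -126), it presumably looks roughly like this sketch:

// Sketch only, not the committed helper: one step of the integer trellis as
// implied by the constants in vec_dot_iq3_kt_q8_1 further down in this commit.
__device__ __forceinline__ int trellis_next_int(uint32_t& val) {
    constexpr uint32_t ka = 0xCBAC1FED;   // multiplier used to advance the state
    constexpr uint32_t km = 0x3f3f3f3f;   // keep 6 bits of each byte
    val *= ka;                            // advance the trellis state in place
    const uint32_t s = val & km;
    // sum the four 6-bit bytes and center around zero: result lies in [-126, 126]
    return int(s & 0xff) + int((s >> 8) & 0xff) + int((s >> 16) & 0xff) + int((s >> 24) & 0xff) - 126;
}

With the generated magnitude already an integer in a fixed range, presumably the 31.75f factor is no longer needed in dl.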

ggml/src/ggml-cuda/iqk_mmvq.cu (50 additions, 0 deletions)

@@ -504,6 +504,48 @@ __device__ __forceinline__ void vec_dot_iq2_kt_q8_1(
     *result += dl * __low2float(bq8_1[ib32].ds) * sumi;
 }
 
+__device__ __forceinline__ void vec_dot_iq3_kt_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float * result) {
+
+    constexpr uint32_t ka = 0xCBAC1FED;
+    constexpr uint32_t km = 0x3f3f3f3f;
+
+    float scale = *(const float *)vbq;
+    const block_iq3_kt * bq3 = (const block_iq3_kt *)((const char *)vbq + sizeof(float)) + kbx;
+
+    // iqs is 0...28
+    const int ib32 = iqs/4;
+    const int32_t * q8 = (const int *)bq8_1[ib32].qs;
+    const int ls = (bq3->scales[ib32%4] >> 4*(ib32/4)) & 0xf;
+    const float dl = scale * ls * 1.015f;
+    auto ql = (const uint16_t *)bq3->ql;
+    uint32_t mask = 0x01010101 << ib32;
+    const uint32_t * qh = (const uint32_t *)bq3->qh;
+    int sumi = 0;
+    for (int j = 0; j < 4; ++j) {
+        uint32_t val = ql[4*ib32+j] + 4096;
+        int v4 = 0;
+        for (int k = 0; k < 4; ++k) {
+            val *= ka;
+            int8_t q = std::abs(ggml_cuda_dp4a(val & km, 0x01010101, -126));
+            v4 |= q << 8*k;
+        }
+        uint32_t signs = __vcmpne4(qh[2*j+0] & mask, 0);
+        v4 = __vsub4(v4 ^ signs, signs);
+        sumi = ggml_cuda_dp4a(v4, q8[2*j+0], sumi);
+        v4 = 0;
+        for (int k = 0; k < 4; ++k) {
+            val *= ka;
+            int8_t q = std::abs(ggml_cuda_dp4a(val & km, 0x01010101, -126));
+            v4 |= q << 8*k;
+        }
+        signs = __vcmpne4(qh[2*j+1] & mask, 0);
+        v4 = __vsub4(v4 ^ signs, signs);
+        sumi = ggml_cuda_dp4a(v4, q8[2*j+1], sumi);
+    }
+    *result += dl * __low2float(bq8_1[ib32].ds) * sumi;
+}
+
 #define VDR_IQ4_KSS_Q8_1_MMVQ 4
 #define VDR_IQ4_KSS_Q8_1_MMQ 4
 
@@ -1304,6 +1346,14 @@
     iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ2_KT, VDR_IQ4_KS_Q8_1_MMVQ, vec_dot_iq2_kt_q8_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
 }
 
+void mul_mat_vec_iq3_kt_q8_1_cuda(
+    const void * vx, const void * vy, float * dst, const char * ids_data,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+    const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) {
+
+    iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ3_KT, VDR_IQ4_KS_Q8_1_MMVQ, vec_dot_iq3_kt_q8_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
+}
+
 void mul_mat_vec_iq4_kss_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const char * ids_data,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
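A detail worth calling out in vec_dot_iq3_kt_q8_1 above: the IQ3_KT sign bits live in qh, and the kernel applies them without branches. __vcmpne4 expands each byte whose qh bit is set into an all-ones byte (0xFF), and __vsub4(v4 ^ signs, signs) negates exactly those bytes, since XOR with 0xFF is a bitwise NOT and subtracting 0xFF (i.e. -1 per byte) adds one, giving the two's-complement negation ~q + 1 = -q. An illustrative scalar equivalent (not part of the commit):

// Illustrative scalar version of the per-byte conditional negation that
// __vcmpne4/__vsub4 perform four lanes at a time in the kernel above.
static inline int8_t apply_sign(int8_t q, bool sign_bit_set) {
    const int8_t s = sign_bit_set ? int8_t(0xFF) : int8_t(0);  // per-lane mask as produced by __vcmpne4
    return int8_t((q ^ s) - s);  // s == 0xFF: ~q + 1 == -q;  s == 0: q unchanged
}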

ggml/src/ggml-cuda/iqk_mmvq.cuh (7 additions, 2 deletions)

@@ -101,12 +101,17 @@ void mul_mat_vec_iq1_m_r4_q8_1_cuda(
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
     const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream);
 
-void mul_mat_vec_iq4_kt_q8_1_cuda(
+void mul_mat_vec_iq2_kt_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const char * ids_data,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
     const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream);
 
-void mul_mat_vec_iq2_kt_q8_1_cuda(
+void mul_mat_vec_iq3_kt_q8_1_cuda(
+    const void * vx, const void * vy, float * dst, const char * ids_data,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+    const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream);
+
+void mul_mat_vec_iq4_kt_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const char * ids_data,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
     const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream);

ggml/src/ggml-cuda/mmvq.cu (8 additions, 4 deletions)

@@ -527,12 +527,15 @@ static void ggml_cuda_op_mul_mat_vec_q_impl(ggml_backend_cuda_context & ctx, ggm
         case GGML_TYPE_IQ4_KSS:
             mul_mat_vec_iq4_kss_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
             break;
-        case GGML_TYPE_IQ4_KT:
-            mul_mat_vec_iq4_kt_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
-            break;
         case GGML_TYPE_IQ2_KT:
             mul_mat_vec_iq2_kt_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
             break;
+        case GGML_TYPE_IQ3_KT:
+            mul_mat_vec_iq3_kt_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
+            break;
+        case GGML_TYPE_IQ4_KT:
+            mul_mat_vec_iq4_kt_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
+            break;
         case GGML_TYPE_IQ2_KS:
             mul_mat_vec_iq2_ks_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
             break;
@@ -693,8 +696,9 @@ bool ggml_cuda_mmvq_type_supported(ggml_type src0_type) {
         case GGML_TYPE_IQ5_KS_R4:
         case GGML_TYPE_IQ1_S_R4:
         case GGML_TYPE_IQ1_M_R4:
-        case GGML_TYPE_IQ4_KT:
         case GGML_TYPE_IQ2_KT:
+        case GGML_TYPE_IQ3_KT:
+        case GGML_TYPE_IQ4_KT:
             return true;
         default:
             return false;

ggml/src/iqk/iqk_quantize.cpp (7 additions, 3 deletions)

@@ -7417,7 +7417,11 @@ class QuantizerIQKT {
             for (int k = 0; k < kGroupSize; ++k) {
                 x = ka*x;
                 s = x & 0x3f3f3f3f;
-                result[k] = scale*(i8[0] + i8[1] + i8[2] + i8[3] - 126.f);
+                if constexpr (is_abs) {
+                    result[k] = scale*std::abs(i8[0] + i8[1] + i8[2] + i8[3] - 126.f);
+                } else {
+                    result[k] = scale*(i8[0] + i8[1] + i8[2] + i8[3] - 126.f);
+                }
             }
         } else {
             constexpr uint32_t ka = 89226354;
@@ -8279,7 +8283,7 @@ void vec_dot_iq2_kt_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx
 
 namespace {
 
-using QuantizerIQ3KT = QuantizerIQKT<32, 8, 16, true>;
+using QuantizerIQ3KT = QuantizerIQKT<32, 8, 16, true, true>;
 const QuantizerIQ3KT& iq3kt_quantizer() {
     static std::mutex mutex;
     std::lock_guard<std::mutex> lock(mutex);
@@ -8490,7 +8494,7 @@ size_t quantize_iq3_kt(const float * src, void * dst, int64_t nrows, int64_t n_p
 
 void dequantize_row_iq3_kt(const block_iq3_kt * x, float * y, int64_t k) {
 #ifdef __AVX2__
-    if (iqk_dequantize_ktquants(GGML_TYPE_IQ3_KT, k, x, 0, y, 0, 1)) return;
+    //if (iqk_dequantize_ktquants(GGML_TYPE_IQ3_KT, k, x, 0, y, 0, 1)) return;
 #endif
     using Q = QuantizerIQ3KT;
     constexpr int kNumGroups = Q::kSuperBlockSize/Q::kGroupSize;
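On the CPU side the same sign convention shows up in the quantizer: QuantizerIQKT gains what appears to be an is_abs compile-time flag (the new fifth true in the QuantizerIQ3KT alias), so for IQ3_KT the generator emits the magnitude only and leaves the sign to the separately stored qh bits, matching the CUDA kernels above; the AVX2 shortcut in dequantize_row_iq3_kt is disabled for now. A minimal standalone sketch of that generator under these assumptions (the real code lives inside QuantizerIQKT, and the multiplier value is taken from the CUDA path, not from this hunk):

#include <cmath>
#include <cstdint>

// Sketch only: trellis group generation with a compile-time is_abs switch,
// mirroring the if constexpr branch added to QuantizerIQKT in this commit.
template <bool is_abs>
static void gen_group(uint32_t x, float scale, float * result, int group_size) {
    constexpr uint32_t ka = 0xCBAC1FEDu;              // assumed: same multiplier as the CUDA kernels
    for (int k = 0; k < group_size; ++k) {
        x *= ka;                                      // advance the trellis state
        const uint32_t s = x & 0x3f3f3f3f;            // keep 6 bits per byte
        const uint8_t * i8 = reinterpret_cast<const uint8_t *>(&s);
        const float v = i8[0] + i8[1] + i8[2] + i8[3] - 126.f;
        result[k] = scale * (is_abs ? std::abs(v) : v);  // IQ3_KT: magnitude only, signs kept in qh
    }
}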
