Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions ggml/src/ggml-cuda/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ3_S> {
static constexpr int qi = QI3_S;
};

// CUDA type traits for GGML_TYPE_IQ4_KT.
// qk is QK_K; qr and qi reuse the QR4_XS / QI4_XS constants instead of IQ4_KT-specific ones.
// NOTE(review): presumably valid because IQ4_KT stores the same number of bits per weight
// as IQ4_XS -- confirm against the block_iq4_kt definition.
template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ4_KT> {
static constexpr int qk = QK_K;
static constexpr int qr = QR4_XS;
static constexpr int qi = QI4_XS;
};

//////////////////////

struct ggml_cuda_device_info {
Expand Down
16 changes: 11 additions & 5 deletions ggml/src/ggml-cuda/convert.cu
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,13 @@ inline __device__ int nearest_int(float fval) {
return (i & 0x007fffff) - 0x00400000;
}

// Advances the trellis generator state `val` by one linear-congruential step
// (val <- 89226354*val + 64248484, modulo 2^32) and returns an integer sample derived
// from the new state.
//
// The sample is ggml_cuda_dp4a(masked, 0x01010101, -126), where `masked` keeps the low
// 6 bits of each of the four state bytes. Assuming ggml_cuda_dp4a has __dp4a semantics
// (4-way int8 dot product plus accumulator), this is the sum of the four 6-bit fields
// minus 126, i.e. a value in [-126, 126].
//
// `val` is updated in place so that consecutive calls walk the sequence.
int __device__ __forceinline__ trellis_next_int(uint32_t& val) {
    constexpr uint32_t lcg_mul    = 89226354;
    constexpr uint32_t lcg_add    = 64248484;
    constexpr uint32_t field_mask = 0x3f3f3f3f;

    val = lcg_mul*val + lcg_add;

    return ggml_cuda_dp4a(val & field_mask, 0x01010101, -126);
}

float __device__ __forceinline__ trellis_next(uint32_t& val) {
constexpr uint32_t ka = 89226354;
constexpr uint32_t kb = 64248484;
Expand Down Expand Up @@ -401,9 +408,8 @@ static __global__ void dequantize_block_iq4_kt(const void * __restrict__ vx, dst
int64_t ii = blockIdx.x;
int64_t row = (QK_K * ii) / n_per_row;
const float * dptr = (const float *)((const char *)vx + row * row_size);
float scale = dptr[0] * 31.75f * 1.01f;
float row_av = dptr[1];
const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2);
float scale = dptr[0] * 1.00f;
const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 1);
const int64_t i = ii - (row*n_per_row)/QK_K;

constexpr int kNumGroups = 64;
Expand All @@ -423,8 +429,8 @@ static __global__ void dequantize_block_iq4_kt(const void * __restrict__ vx, dst
int ls = ((shb[ib32] & 0xff) >> 1) - 64;
const float dl = scale * ls;
for (int j = 0; j < 4; ++j) {
y[j+0] = dl * trellis_next(idx1) + row_av;
y[j+4] = dl * trellis_next(idx2) + row_av;
y[j+0] = dl * trellis_next_int(idx1);
y[j+4] = dl * trellis_next_int(idx2);
}
}

Expand Down
46 changes: 46 additions & 0 deletions ggml/src/ggml-cuda/iqk_mmvq.cu
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,44 @@ __device__ __forceinline__ void vec_dot_iq4_ks_q8_1(
*result += dl * __low2float(bq8_1[ib32].ds) * sumi;
}

// Partial dot product between one 32-weight group of an IQ4_KT block and the matching
// Q8_1 block; the scaled result is accumulated into *result.
//
//   vbq    : points at a float scale that is immediately followed by the block_iq4_kt array
//   bq8_1  : Q8_1 blocks; entry [iqs/4] supplies the int8 activations and the scale (low half of ds)
//   kbx    : index of the block_iq4_kt to use
//   iqs    : sub-block selector; the original author notes the range 0...28, and iqs/4 picks the group
//   result : accumulator, incremented (not overwritten)
//
// Layout of block_iq4_kt::qs as used here: words [0..7] are per-group headers
// (bit 0 -> generator offset bit, bits 1..7 -> scale + 64, bits 8..31 -> eight 3-bit fields),
// followed by 64 "low" bytes and 32 "high" bytes that hold one nibble per index.
__device__ __forceinline__ void vec_dot_iq4_kt_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float * result) {

    // Trellis generator: state <- lcg_mul*state + lcg_add, keep 6 bits of every byte.
    constexpr uint32_t lcg_mul    = 89226354;
    constexpr uint32_t lcg_add    = 64248484;
    constexpr uint32_t field_mask = 0x3f3f3f3f;

    const float row_scale = *(const float *)vbq;
    const block_iq4_kt * blk = (const block_iq4_kt *)((const char *)vbq + sizeof(float)) + kbx;

    const int group = iqs/4;
    const auto header = blk->qs[group];             // loaded once, reused for scale, offset and 3-bit fields
    const int32_t * act = (const int *)bq8_1[group].qs;

    const int ls = (header & 0xff) >> 1;
    const float group_scale = row_scale * (ls - 64);
    const uint32_t seed_base = ((header & 1) << 15) + 4096;

    const uint8_t * lo_bytes = (const uint8_t *)(blk->qs + 8);
    const uint8_t * hi_bytes = lo_bytes + 64;
    lo_bytes += 8*group;
    hi_bytes += 8*(group%4);
    // Groups 0..3 take the low nibble of the high byte (shift 8), groups 4..7 the high nibble (shift 4).
    const int nibble_shift = 8 - 4*(group/4);

    int acc = 0;
    for (int j = 0; j < 8; ++j) {
        const uint32_t top3 = header >> (8 + 3*j);
        uint32_t state = lo_bytes[j] + ((hi_bytes[j] << nibble_shift) & 0xf00) + ((top3 & 7) << 12) + seed_base;

        // Generate four int8 weights from this seed and pack them into one 32-bit word.
        int quad = 0;
        for (int k = 0; k < 4; ++k) {
            state = lcg_mul*state + lcg_add;
            const int weight = ggml_cuda_dp4a(state & field_mask, 0x01010101, -126);
            quad |= (weight & 0xff) << 8*k;
        }

        acc = ggml_cuda_dp4a(quad, act[j], acc);
    }

    *result += group_scale * __low2float(bq8_1[group].ds) * acc;
}

#define VDR_IQ4_KSS_Q8_1_MMVQ 4
#define VDR_IQ4_KSS_Q8_1_MMQ 4

Expand Down Expand Up @@ -1171,6 +1209,14 @@ void mul_mat_vec_iq4_ks_q8_1_cuda(
iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ4_KS, VDR_IQ4_KS_Q8_1_MMVQ, vec_dot_iq4_ks_q8_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
}

// Host-side entry point for the IQ4_KT x Q8_1 matrix-vector product.
// Forwards every argument unchanged to iqk_mul_mat_vec_q_cuda, instantiated for
// GGML_TYPE_IQ4_KT with vec_dot_iq4_kt_q8_1 as the per-block dot product.
// NOTE(review): the VDR template argument is VDR_IQ4_KS_Q8_1_MMVQ (the IQ4_KS constant);
// no IQ4_KT-specific VDR is visible here -- confirm the reuse is intentional.
void mul_mat_vec_iq4_kt_q8_1_cuda(
const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) {

iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ4_KT, VDR_IQ4_KS_Q8_1_MMVQ, vec_dot_iq4_kt_q8_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
}

void mul_mat_vec_iq4_kss_q8_1_cuda(
const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
Expand Down
5 changes: 5 additions & 0 deletions ggml/src/ggml-cuda/iqk_mmvq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,8 @@ void mul_mat_vec_iq1_s_r4_q8_1_cuda(
const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream);

// IQ4_KT x Q8_1 matrix-vector product launcher; same parameter list as the other
// mul_mat_vec_*_q8_1_cuda declarations in this header. Work is issued on `stream`.
void mul_mat_vec_iq4_kt_q8_1_cuda(
const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream);
4 changes: 4 additions & 0 deletions ggml/src/ggml-cuda/mmq.cu
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ void ggml_cuda_op_mul_mat_q(
case GGML_TYPE_IQ4_KS_R4:
mul_mat_q_case<GGML_TYPE_IQ4_KS_R4>(ctx, args, stream);
break;
case GGML_TYPE_IQ4_KT:
mul_mat_q_case<GGML_TYPE_IQ4_KT>(ctx, args, stream);
break;
case GGML_TYPE_IQ5_KS:
mul_mat_q_case<GGML_TYPE_IQ5_KS>(ctx, args, stream);
break;
Expand Down Expand Up @@ -172,6 +175,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
case GGML_TYPE_IQ4_K:
case GGML_TYPE_IQ5_K:
case GGML_TYPE_IQ6_K:
case GGML_TYPE_IQ4_KT:
mmq_supported = true;
break;
default:
Expand Down
84 changes: 84 additions & 0 deletions ggml/src/ggml-cuda/mmq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
case GGML_TYPE_IQ5_KS:
case GGML_TYPE_IQ5_KS_R4:
case GGML_TYPE_IQ6_K:
case GGML_TYPE_IQ4_KT:
return MMQ_Q8_1_DS_LAYOUT_D4;
default:
GGML_ABORT("fatal error");
Expand Down Expand Up @@ -202,6 +203,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
case GGML_TYPE_IQ4_K : return MMQ_DP4A_TXS_Q8_0_16;
case GGML_TYPE_IQ5_K : return MMQ_DP4A_TXS_Q8_0_16;
case GGML_TYPE_IQ6_K : return MMQ_DP4A_TXS_Q8_0_16;
case GGML_TYPE_IQ4_KT : return MMQ_DP4A_TXS_Q8_0;
default : return tile_x_sizes{0, 0, 0};
}
}
Expand Down Expand Up @@ -250,6 +252,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
case GGML_TYPE_IQ4_K : return MMQ_MMA_TILE_X_K_Q3_K;
case GGML_TYPE_IQ5_K : return MMQ_MMA_TILE_X_K_Q3_K;
case GGML_TYPE_IQ6_K : return MMQ_MMA_TILE_X_K_Q3_K;
case GGML_TYPE_IQ4_KT : return MMQ_MMA_TILE_X_K_Q8_0;
default : return 0;
}
}
Expand Down Expand Up @@ -2790,6 +2793,79 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin

}

// Loads one IQ4_KT block per tile row into the MMQ x tile.
//
// Each row of `x` is `stride` bytes long and starts with one float (the row scale),
// followed by the block_iq4_kt array; `kbx0` selects the block within the row.
// The tile receives:
//   x_qs : the 256 weights of the block, generated from the trellis seeds and packed
//          as 64 ints holding four int8 values each;
//   x_df : eight floats per row, one per group of 32 weights (row scale * (ls - 64)).
// When need_check is set, row indices are clamped to i_max.
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_kt(
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {

// Trellis generator constants: val <- ka*val + kb, then keep 6 bits of every byte (km).
constexpr uint32_t ka = 89226354;
constexpr uint32_t kb = 64248484;
constexpr uint32_t km = 0x3f3f3f3f;

#ifdef INT8_MMA_AVAILABLE
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + WARP_SIZE*2);
#else
// NOTE(review): tile sizes are queried for GGML_TYPE_IQ4_XS, not GGML_TYPE_IQ4_KT.
// The dp4a tile-size table maps IQ4_KT to MMQ_DP4A_TXS_Q8_0; confirm IQ4_XS yields the same sizes.
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
#endif // INT8_MMA_AVAILABLE

const int kqsx = threadIdx.x;

// Pass 1: quantized values. threadIdx.y selects the row, threadIdx.x selects which
// two seeds of the block this thread expands.
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
int i = i0 + threadIdx.y;

if (need_check) {
i = min(i, i_max);
}

// Skip the leading float row scale to reach the blocks.
const block_iq4_kt * bxi = (const block_iq4_kt *)(x + i*stride + sizeof(float)) + kbx0;

// ib32: group of 32 weights (0..7 for threadIdx.x in 0..31); j: pair of seeds within the group (0..3).
int ib32 = kqsx/4;
int j = kqsx%4;
// qs words [0..7] are per-group headers; the low bytes (64) and high-nibble bytes follow.
const auto shb = bxi->qs;
const auto ql = (const uint8_t *)(shb + 8);
const auto qh = ql + 64;
// Header bits 8.. hold 3 bits per seed; this thread consumes 6 of them (two seeds).
const uint32_t sh = shb[ib32] >> (8 + 6*j);
// Header bit 0 contributes bit 15 of the seed; 4096 is a constant offset.
uint32_t offset = 4096 + ((shb[ib32] & 1) << 15);
// Seed = low byte | nibble from qh (low nibble for groups 0..3, high nibble for 4..7) << 8 | 3 header bits << 12.
uint32_t val1 = offset + ql[8*ib32+2*j+0] + ((qh[8*(ib32%4)+2*j+0] << (8 - 4*(ib32/4))) & 0xf00) + ((sh & 7) << 12);
uint32_t val2 = offset + ql[8*ib32+2*j+1] + ((qh[8*(ib32%4)+2*j+1] << (8 - 4*(ib32/4))) & 0xf00) + ((sh & 56) << 9);
int2 v = {0, 0};
// Four generator steps per seed; each step yields one int8 weight packed into byte k.
// ggml_cuda_dp4a(m, 0x01010101, -126) is presumably a 4-way int8 dot product plus
// accumulator (as __dp4a), i.e. the sum of the four 6-bit fields minus 126 -- confirm.
for (int k = 0; k < 4; ++k) {
val1 = ka*val1 + kb;
val2 = ka*val2 + kb;
v.x |= (ggml_cuda_dp4a(val1 & km, 0x01010101, -126) & 0xff) << 8*k;
v.y |= (ggml_cuda_dp4a(val2 & km, 0x01010101, -126) & 0xff) << 8*k;
}
// The two layouts differ only in the per-row stride of x_qs.
#ifdef INT8_MMA_AVAILABLE
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*ib32 + 2*j + 0] = v.x;
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*ib32 + 2*j + 1] = v.y;
#else
x_qs[i*(2*WARP_SIZE + 1) + 8*ib32 + 2*j + 0] = v.x;
x_qs[i*(2*WARP_SIZE + 1) + 8*ib32 + 2*j + 1] = v.y;
#endif // INT8_MMA_AVAILABLE
}

// Pass 2: scales. Each warp covers four rows (threadIdx.x / (WARP_SIZE/4)) and
// threadIdx.x % 8 selects the group whose scale is written.
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);

if (need_check) {
i = min(i, i_max);
}

// dptr[0] is the float row scale stored ahead of the blocks.
const float * dptr = (const float *)(x + i*stride);
const block_iq4_kt * bxi = (const block_iq4_kt *)(dptr + 1) + kbx0;
// Bits 1..7 of the group header hold the group scale biased by 64.
const int ls = (bxi->qs[threadIdx.x % 8] & 0xff) >> 1;

#ifdef INT8_MMA_AVAILABLE
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = dptr[0] * (ls - 64);
#else
x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = dptr[0] * (ls - 64);
#endif // INT8_MMA_AVAILABLE
}
}

template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq5_ks_r4(
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {

Expand Down Expand Up @@ -3382,6 +3458,13 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_KS_R4> {
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};

// MMQ kernel selection for GGML_TYPE_IQ4_KT: tiles are filled by load_tiles_iq4_kt
// (which expands the weights to packed int8), and the dot products reuse the
// q8_0 x q8_1 kernels, with the MMA variant using the MMQ_Q8_1_DS_LAYOUT_D4 layout.
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_KT> {
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_kt<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ5_KS> {
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq5_ks<mmq_y, nwarps, need_check>;
Expand Down Expand Up @@ -3843,6 +3926,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ5_K);
extern DECL_MMQ_CASE(GGML_TYPE_IQ5_KS);
extern DECL_MMQ_CASE(GGML_TYPE_IQ6_K);
extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S_R4);
extern DECL_MMQ_CASE(GGML_TYPE_IQ4_KT);

// -------------------------------------------------------------------------------------------------------------------------

Expand Down
4 changes: 4 additions & 0 deletions ggml/src/ggml-cuda/mmvq.cu
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,9 @@ static void ggml_cuda_op_mul_mat_vec_q_impl(ggml_backend_cuda_context & ctx, ggm
break;
// IQ4_KSS: must end with its own break so it does not fall through into the IQ4_KT launch.
case GGML_TYPE_IQ4_KSS:
    mul_mat_vec_iq4_kss_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
    break;
// IQ4_KT: trellis-quantized 4-bit type.
case GGML_TYPE_IQ4_KT:
    mul_mat_vec_iq4_kt_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
    break;
case GGML_TYPE_IQ2_KS:
mul_mat_vec_iq2_ks_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
Expand Down Expand Up @@ -683,6 +686,7 @@ bool ggml_cuda_mmvq_type_supported(ggml_type src0_type) {
case GGML_TYPE_IQ5_K_R4:
case GGML_TYPE_IQ5_KS_R4:
case GGML_TYPE_IQ1_S_R4:
case GGML_TYPE_IQ4_KT:
return true;
default:
return false;
Expand Down
5 changes: 5 additions & 0 deletions ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_kt.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.

#include "../mmq.cuh"

DECL_MMQ_CASE(GGML_TYPE_IQ4_KT);
44 changes: 31 additions & 13 deletions ggml/src/ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -6596,6 +6596,25 @@ void kernel_mul_mv_iq2_k_f32_impl(
}
}

// Integer trellis generator.
//
// The underlying sequence is the linear-congruential recurrence s <- ka*s + kb (mod 2^32).
// (ka1, kb1), (ka2, kb2) and (ka3, kb3) are the multiplier/increment pairs for two, three
// and four composed steps, so the first four states can be computed independently from
// the seed.
struct Trellis3 {
    constexpr constant static uint32_t kmask = 0x3f3f3f3f;
    constexpr constant static uint32_t ka = 89226354;
    constexpr constant static uint32_t kb = 64248484;
    constexpr constant static uint32_t ka1 = ka*ka;
    constexpr constant static uint32_t kb1 = kb*ka+kb;
    constexpr constant static uint32_t ka2 = ka1*ka;
    constexpr constant static uint32_t kb2 = kb1*ka+kb;
    constexpr constant static uint32_t ka3 = ka2*ka;
    constexpr constant static uint32_t kb3 = kb2*ka+kb;

    // Returns the four samples that follow seed `val`.
    // Each sample is the sum of the four 6-bit fields (one per byte) of the masked
    // state, shifted by -126, so it lies in [-126, 126].
    static inline char4 gen4(uint32_t val) {
        thread uint32_t states[4] = {
            (ka *val + kb ) & kmask,
            (ka1*val + kb1) & kmask,
            (ka2*val + kb2) & kmask,
            (ka3*val + kb3) & kmask,
        };
        char4 samples;
        for (int i = 0; i < 4; ++i) {
            const uint32_t s = states[i];
            // Every byte is already limited to 0..63 by kmask, so the byte sum can be
            // taken with shifts and masks.
            const int byte_sum = (int)(s & 0xff) + (int)((s >> 8) & 0xff) + (int)((s >> 16) & 0xff) + (int)(s >> 24);
            samples[i] = -126 + byte_sum;
        }
        return samples;
    }
};

struct Trellis {
constexpr constant static uint32_t kmask1 = 0x8fff8fff;
constexpr constant static uint32_t kmask2 = 0x3b603b60;
Expand Down Expand Up @@ -8586,20 +8605,20 @@ void dequantize_iq4_kt(device const block_iq4_kt * x, short il, float d, thread
device const uint32_t * shb = x->qs;
device const uint8_t * ql = (device const uint8_t *)(shb + 8);
device const uint8_t * qh = ql + 64;
float scale = d * (((shb[ib32] & 0xff) >> 1) - 64);
const int ls = (shb[ib32] & 0xff) >> 1;
const float scale = d * (ls - 64);
const uint32_t offset = 4096 + ((shb[ib32] & 1) << 15);

const int jj = ib32*8 + 4*(il%2);
ql += jj;
qh += jj%32;
ql += 8*ib32;
qh += 8*(ib32%4);

uint32_t sh = (shb[ib32] >> (8 + 12*(il%2))) << 12;
const int shift = 8 - 4*(jj/32);
const int shift = 8 - 4*(ib32/4);

for (int i = 0; i < 4; ++i) {
uint32_t idx = ql[i] + ((qh[i] << shift) & 0xf00) + ((sh >> 3*i) & 0x7000) + offset;
auto v = (float4)Trellis::gen4(idx);
reg[i] = v * scale;
auto c4 = Trellis3::gen4(idx);
reg[i] = {scale*c4[0], scale*c4[1], scale*c4[2], scale*c4[3]};
}
}

Expand Down Expand Up @@ -8931,18 +8950,17 @@ struct DequantizerKT4 {
using type4x4 = T4x4;
DequantizerKT4(device const char * cx, short il = 0) : il(il) {
device const float * dptr = (device const float *)cx;
d[0] = dptr[0] * 31.75f * 1.01f;
d[1] = dptr[1];
x = (device const Block *)(dptr + 2);
d = dptr[0] * 1.01f;
x = (device const Block *)(dptr + 1);
}
inline void convert(thread T4x4& t) const {
float4x4 tmp;
dequantize_iq4_kt(x, il, d[0], tmp);
dequantize_iq4_kt(x, il, d, tmp);
for (int i = 0; i < 4; ++i) for (int j = 0; j < 4; ++j) t[i][j] = tmp[i][j];
}
inline void convert(int64_t ind, thread T4x4& t) {
float4x4 tmp;
dequantize_iq4_kt(x + ind/nl, ind%nl, d[0], tmp);
dequantize_iq4_kt(x + ind/nl, ind%nl, d, tmp);
for (int i = 0; i < 4; ++i) for (int j = 0; j < 4; ++j) t[i][j] = tmp[i][j];
}
inline void next() {
Expand All @@ -8951,7 +8969,7 @@ struct DequantizerKT4 {
}
device const Block * x;
short il;
float d[2];
float d;
};

template <typename T4x4, typename Block, typename Scale, int nl, void (*dequantize)(half d, device const Block *, short, thread T4x4&), bool may_not_be_aligned = false>
Expand Down
8 changes: 4 additions & 4 deletions ggml/src/ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -1617,13 +1617,13 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_iq4_kt,
.from_float_ref = (ggml_from_float_t)quantize_row_iq4_kt_ref,
.vec_dot = vec_dot_iq4_kt_q8_k,
#ifdef __ARM_NEON
.vec_dot_type = GGML_TYPE_F16,
#if defined __AVX2__
.vec_dot_type = GGML_TYPE_Q8_2_X4,
#else
.vec_dot_type = GGML_TYPE_F32,
.vec_dot_type = GGML_TYPE_Q8_0_X4,
#endif
.nrows = 1,
.row_meta_size = 8,
.row_meta_size = 4,
},
[GGML_TYPE_IQ3_K] = {
.type_name = "iq3_k",
Expand Down
Loading