Skip to content

Commit f7c9a0f

Browse files
ikawrakow and Iwan Kawrakow authored
CUDA: MMQ for IQ4_KS (#374)
* WIP
* WIP: still getting illegal memory access
* CUDA: MMQ for iq4_ks now works ~25% faster than dequantize+cuBLAS, ~10% slower than Q4_0 MMQ.

---------

Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent 1328128 commit f7c9a0f

File tree

4 files changed

+133
-40
lines changed

4 files changed

+133
-40
lines changed

ggml/src/ggml-cuda/mmq.cu

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@ void ggml_cuda_op_mul_mat_q(
1414
const int64_t src1_padded_row_size, cudaStream_t stream) {
1515

1616
const int64_t ne00 = src0->ne[0];
17+
const int64_t nb01 = src0->nb[1];
1718

1819
const int64_t ne10 = src1->ne[0];
1920
const int64_t ne11 = src1->ne[1];
@@ -22,7 +23,6 @@ void ggml_cuda_op_mul_mat_q(
2223
const int64_t ne0 = dst->ne[0];
2324

2425
const int64_t row_diff = row_high - row_low;
25-
const int64_t stride00 = ne00 / ggml_blck_size(src0->type);
2626

2727
int id = ggml_cuda_get_device();
2828
const int compute_capability = ggml_cuda_info().devices[id].cc;
@@ -31,7 +31,7 @@ void ggml_cuda_op_mul_mat_q(
3131
// nrows_dst == nrows of the matrix that the kernel writes into
3232
const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;
3333

34-
const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, ne11, nrows_dst};
34+
const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, nb01, src1_padded_row_size, src1_ncols, ne11, nrows_dst};
3535

3636
switch (src0->type) {
3737
case GGML_TYPE_Q4_0:
@@ -91,6 +91,9 @@ void ggml_cuda_op_mul_mat_q(
9191
case GGML_TYPE_IQ4_NL:
9292
mul_mat_q_case<GGML_TYPE_IQ4_NL>(ctx, args, stream);
9393
break;
94+
case GGML_TYPE_IQ4_KS:
95+
mul_mat_q_case<GGML_TYPE_IQ4_KS>(ctx, args, stream);
96+
break;
9497
default:
9598
GGML_ABORT("fatal error");
9699
break;
@@ -128,6 +131,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
128131
case GGML_TYPE_IQ1_S:
129132
case GGML_TYPE_IQ4_XS:
130133
case GGML_TYPE_IQ4_NL:
134+
case GGML_TYPE_IQ4_KS:
131135
mmq_supported = true;
132136
break;
133137
default:

0 commit comments

Comments
 (0)