@@ -14,6 +14,7 @@ void ggml_cuda_op_mul_mat_q(
     const int64_t src1_padded_row_size, cudaStream_t stream) {
 
     const int64_t ne00 = src0->ne[0];
+    const int64_t nb01 = src0->nb[1];
 
     const int64_t ne10 = src1->ne[0];
     const int64_t ne11 = src1->ne[1];
@@ -22,7 +23,6 @@ void ggml_cuda_op_mul_mat_q(
     const int64_t ne0 = dst->ne[0];
 
     const int64_t row_diff = row_high - row_low;
-    const int64_t stride00 = ne00 / ggml_blck_size(src0->type);
 
     int id = ggml_cuda_get_device();
     const int compute_capability = ggml_cuda_info().devices[id].cc;
@@ -31,7 +31,7 @@ void ggml_cuda_op_mul_mat_q(
     // nrows_dst == nrows of the matrix that the kernel writes into
     const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;
 
-    const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, ne11, nrows_dst};
+    const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, nb01, src1_padded_row_size, src1_ncols, ne11, nrows_dst};
 
     switch (src0->type) {
         case GGML_TYPE_Q4_0:
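
Note on the stride change above: stride00 counted the row stride of src0 in quant blocks (ne00 divided by the block size), while nb01 is the row stride in bytes taken directly from the tensor. For a contiguous quantized tensor the two differ only by the per-block size, as in this minimal standalone sketch (assumes ggml.h is available; the values are illustrative, not taken from the diff):

    // stride_sketch.cpp - relationship between the old block-count stride
    // and the byte stride nb[1] for a contiguous quantized row.
    #include <cstdio>
    #include "ggml.h"

    int main() {
        const int64_t ne00 = 4096;                  // example row length in elements (assumption)
        const enum ggml_type type = GGML_TYPE_Q4_0; // any block-quantized type

        const int64_t stride00 = ne00 / ggml_blck_size(type);      // old: row stride in blocks
        const size_t  nb01     = stride00 * ggml_type_size(type);  // new: row stride in bytes (== src0->nb[1] when contiguous)

        printf("stride00 = %lld blocks, nb01 = %zu bytes\n", (long long) stride00, nb01);
        return 0;
    }

Passing the byte stride through mmq_args presumably also covers row layouts whose size cannot be derived from ne00 and the block size alone.
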
@@ -91,6 +91,9 @@ void ggml_cuda_op_mul_mat_q(
         case GGML_TYPE_IQ4_NL:
             mul_mat_q_case<GGML_TYPE_IQ4_NL>(ctx, args, stream);
             break;
+        case GGML_TYPE_IQ4_KS:
+            mul_mat_q_case<GGML_TYPE_IQ4_KS>(ctx, args, stream);
+            break;
         default:
             GGML_ABORT("fatal error");
             break;
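
In llama.cpp's CUDA backend, each type handled by this switch is normally backed by an explicit mul_mat_q_case template instantiation in its own compilation unit. If this fork follows the same layout, the new IQ4_KS case would be accompanied by something along these lines (hypothetical file name mirroring the existing iq4_nl instance; not shown in this diff):

    // template-instances/mmq-instance-iq4_ks.cu (assumed, by analogy with iq4_nl)
    #include "../mmq.cuh"

    DECL_MMQ_CASE(GGML_TYPE_IQ4_KS);
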
@@ -128,6 +131,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
         case GGML_TYPE_IQ1_S:
         case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_IQ4_KS:
             mmq_supported = true;
             break;
         default: