@@ -5313,7 +5313,7 @@ template <bool need_check> static __global__ void
53135313template <int ncols_y_template, int qk, int qi, typename block_q_t , int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
53145314static __global__ void mul_mat_vec_q (
53155315 const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
5316- const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y_par) {
5316+ const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y_par, const int nrows_dst ) {
53175317
53185318 const int ncols_y = ncols_y_template != 0 ? ncols_y_template : ncols_y_par;
53195319
@@ -5352,7 +5352,7 @@ static __global__ void mul_mat_vec_q(
53525352 tmp[j] = warp_reduce_sum (tmp[j]);
53535353
53545354 if (threadIdx .x == 0 ) {
5355- dst[j*nrows_x + row] = tmp[j];
5355+ dst[j*nrows_dst + row] = tmp[j];
53565356 }
53575357 }
53585358}
@@ -6828,7 +6828,7 @@ static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, floa
68286828template <int qk, int qi, typename block_q_t , int vdr, vec_dot_q_cuda_t vec_dot>
68296829static void mul_mat_vec_q_cuda (
68306830 const void * vx, const void * vy, float * dst,
6831- const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, cudaStream_t stream) {
6831+ const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
68326832
68336833 GGML_ASSERT (ncols_x % qk == 0 );
68346834 GGML_ASSERT (ncols_y <= 4 );
@@ -6839,40 +6839,40 @@ static void mul_mat_vec_q_cuda(
68396839 switch (ncols_y) {
68406840 case 1 :
68416841 mul_mat_vec_q<1 , qk, qi, block_q_t , vdr, vec_dot>
6842- <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
6842+ <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst );
68436843 break ;
68446844 case 2 :
68456845 mul_mat_vec_q<2 , qk, qi, block_q_t , vdr, vec_dot>
6846- <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
6846+ <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst );
68476847 break ;
68486848 case 3 :
68496849 mul_mat_vec_q<3 , qk, qi, block_q_t , vdr, vec_dot>
6850- <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
6850+ <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst );
68516851 break ;
68526852 case 4 :
68536853 mul_mat_vec_q<4 , qk, qi, block_q_t , vdr, vec_dot>
6854- <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
6854+ <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst );
68556855 break ;
68566856 // case 5:
68576857 // mul_mat_vec_q<5, qk, qi, block_q_t, vdr, vec_dot>
6858- // <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
6858+ // <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst );
68596859 // break;
68606860 // case 6:
68616861 // mul_mat_vec_q<6, qk, qi, block_q_t, vdr, vec_dot>
6862- // <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
6862+ // <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst );
68636863 // break;
68646864 // case 7:
68656865 // mul_mat_vec_q<7, qk, qi, block_q_t, vdr, vec_dot>
6866- // <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
6866+ // <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst );
68676867 // break;
68686868 // case 8:
68696869 // mul_mat_vec_q<8, qk, qi, block_q_t, vdr, vec_dot>
6870- // <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
6870+ // <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst );
68716871 // break;
68726872 default :
68736873 GGML_ASSERT (false );
68746874 // mul_mat_vec_q<0, qk, qi, block_q_t, vdr, vec_dot>
6875- // <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
6875+ // <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst );
68766876 break ;
68776877 }
68786878}
@@ -8391,7 +8391,7 @@ static void ggml_cuda_op_mul_mat_q(
83918391 CUDA_CHECK (cudaGetDevice (&id));
83928392
83938393 // the main device has a larger memory buffer to hold the results from all GPUs
8394- // nrows_dst == nrows of the matrix that the dequantize_mul_mat kernel writes into
8394+ // nrows_dst == nrows of the matrix that the kernel writes into
83958395 const int64_t nrows_dst = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : row_diff;
83968396
83978397 switch (src0->type ) {
@@ -8525,58 +8525,70 @@ static void ggml_cuda_op_mul_mat_vec_q(
85258525 const int64_t ne00 = src0->ne [0 ];
85268526 const int64_t row_diff = row_high - row_low;
85278527
8528+ const int64_t ne10 = src1->ne [0 ];
8529+ GGML_ASSERT (ne10 % QK8_1 == 0 );
8530+
8531+ const int64_t ne0 = dst->ne [0 ];
8532+
8533+ int id;
8534+ CUDA_CHECK (cudaGetDevice (&id));
8535+
8536+ // the main device has a larger memory buffer to hold the results from all GPUs
8537+ // nrows_dst == nrows of the matrix that the kernel writes into
8538+ const int64_t nrows_dst = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : row_diff;
8539+
85288540 switch (src0->type ) {
85298541 case GGML_TYPE_Q4_0:
85308542 mul_mat_vec_q_cuda<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
8531- (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
8543+ (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
85328544 break ;
85338545 case GGML_TYPE_Q4_1:
85348546 mul_mat_vec_q_cuda<QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
8535- (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
8547+ (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
85368548 break ;
85378549 case GGML_TYPE_Q5_0:
85388550 mul_mat_vec_q_cuda<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
8539- (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
8551+ (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
85408552 break ;
85418553 case GGML_TYPE_Q5_1:
85428554 mul_mat_vec_q_cuda<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
8543- (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
8555+ (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
85448556 break ;
85458557 case GGML_TYPE_Q8_0:
85468558 mul_mat_vec_q_cuda<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
8547- (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
8559+ (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
85488560 break ;
85498561 case GGML_TYPE_Q2_K:
85508562 mul_mat_vec_q_cuda<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
8551- (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
8563+ (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
85528564 break ;
85538565 case GGML_TYPE_Q3_K:
85548566 mul_mat_vec_q_cuda<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
8555- (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
8567+ (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
85568568 break ;
85578569 case GGML_TYPE_Q4_K:
85588570 mul_mat_vec_q_cuda<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
8559- (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
8571+ (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
85608572 break ;
85618573 case GGML_TYPE_Q5_K:
85628574 mul_mat_vec_q_cuda<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
8563- (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
8575+ (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
85648576 break ;
85658577 case GGML_TYPE_Q6_K:
85668578 mul_mat_vec_q_cuda<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
8567- (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
8579+ (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
85688580 break ;
85698581 case GGML_TYPE_IQ2_XXS:
85708582 mul_mat_vec_q_cuda<QK_K, QI2_XXS, block_iq2_xxs, 1 , vec_dot_iq2_xxs_q8_1>
8571- (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
8583+ (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
85728584 break ;
85738585 case GGML_TYPE_IQ2_XS:
85748586 mul_mat_vec_q_cuda<QK_K, QI2_XS, block_iq2_xs, 1 , vec_dot_iq2_xs_q8_1>
8575- (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
8587+ (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
85768588 break ;
85778589 case GGML_TYPE_IQ3_XXS:
85788590 mul_mat_vec_q_cuda<QK_K, QI3_XXS, block_iq3_xxs, 1 , vec_dot_iq3_xxs_q8_1>
8579- (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
8591+ (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
85808592 break ;
85818593 default :
85828594 GGML_ASSERT (false );
@@ -9909,7 +9921,7 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
99099921 ggml_cuda_op_mul_mat (src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false );
99109922 }
99119923 } else {
9912- if (src1->ne [1 ] <= 4 && min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized (src0->type )) {
9924+ if (src1->ne [1 ] <= 4 && min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized (src0->type ) && src1-> type == GGML_TYPE_F32 ) {
99139925 ggml_cuda_op_mul_mat (src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true );
99149926 } else if (use_mul_mat_q) {
99159927 ggml_cuda_op_mul_mat (src0, src1, dst, ggml_cuda_op_mul_mat_q, true );
0 commit comments