@@ -6831,7 +6831,7 @@ static void mul_mat_vec_q_cuda(
68316831 const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, cudaStream_t stream) {
68326832
68336833 GGML_ASSERT (ncols_x % qk == 0 );
6834- GGML_ASSERT (ncols_y <= 8 );
6834+ GGML_ASSERT (ncols_y <= 4 );
68356835
68366836 const int block_num_y = (nrows_x + GGML_CUDA_MMV_Y - 1 ) / GGML_CUDA_MMV_Y;
68376837 const dim3 block_nums (block_num_y, 1 , 1 );
@@ -6853,22 +6853,22 @@ static void mul_mat_vec_q_cuda(
68536853 mul_mat_vec_q<4 , qk, qi, block_q_t , vdr, vec_dot>
68546854 <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
68556855 break ;
6856- case 5 :
6857- mul_mat_vec_q<5 , qk, qi, block_q_t , vdr, vec_dot>
6858- <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
6859- break ;
6860- case 6 :
6861- mul_mat_vec_q<6 , qk, qi, block_q_t , vdr, vec_dot>
6862- <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
6863- break ;
6864- case 7 :
6865- mul_mat_vec_q<7 , qk, qi, block_q_t , vdr, vec_dot>
6866- <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
6867- break ;
6868- case 8 :
6869- mul_mat_vec_q<8 , qk, qi, block_q_t , vdr, vec_dot>
6870- <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
6871- break ;
6856+ // case 5:
6857+ // mul_mat_vec_q<5, qk, qi, block_q_t, vdr, vec_dot>
6858+ // <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
6859+ // break;
6860+ // case 6:
6861+ // mul_mat_vec_q<6, qk, qi, block_q_t, vdr, vec_dot>
6862+ // <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
6863+ // break;
6864+ // case 7:
6865+ // mul_mat_vec_q<7, qk, qi, block_q_t, vdr, vec_dot>
6866+ // <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
6867+ // break;
6868+ // case 8:
6869+ // mul_mat_vec_q<8, qk, qi, block_q_t, vdr, vec_dot>
6870+ // <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
6871+ // break;
68726872 default :
68736873 GGML_ASSERT (false );
68746874 // mul_mat_vec_q<0, qk, qi, block_q_t, vdr, vec_dot>
@@ -9909,7 +9909,7 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
99099909 ggml_cuda_op_mul_mat (src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false );
99109910 }
99119911 } else {
9912- if (src1->ne [1 ] <= 8 && min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized (src0->type )) {
9912+ if (src1->ne [1 ] <= 4 && min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized (src0->type )) {
99139913 ggml_cuda_op_mul_mat (src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true );
99149914 } else if (use_mul_mat_q) {
99159915 ggml_cuda_op_mul_mat (src0, src1, dst, ggml_cuda_op_mul_mat_q, true );
0 commit comments