@@ -500,7 +500,7 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons
500500}
501501
502502static void dequantize_mul_mat_vec_q4_0_cuda (const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
503- GGML_ASSERT (ncols % GGML_CUDA_DMMV_X == 0 );
503+ GGML_ASSERT (ncols % ( GGML_CUDA_DMMV_X* 2 ) == 0 );
504504 const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1 ) / GGML_CUDA_MMV_Y;
505505 // the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead
506506 const dim3 block_nums (block_num_y, 1 , 1 );
@@ -510,7 +510,7 @@ static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y,
510510}
511511
512512static void dequantize_mul_mat_vec_q4_1_cuda (const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
513- GGML_ASSERT (ncols % GGML_CUDA_DMMV_X == 0 );
513+ GGML_ASSERT (ncols % ( GGML_CUDA_DMMV_X* 2 ) == 0 );
514514 const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1 ) / GGML_CUDA_MMV_Y;
515515 const dim3 block_nums (block_num_y, 1 , 1 );
516516 const dim3 block_dims (WARP_SIZE, GGML_CUDA_MMV_Y, 1 );
@@ -519,7 +519,7 @@ static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y,
519519}
520520
521521static void dequantize_mul_mat_vec_q5_0_cuda (const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
522- GGML_ASSERT (ncols % GGML_CUDA_DMMV_X == 0 );
522+ GGML_ASSERT (ncols % ( GGML_CUDA_DMMV_X* 2 ) == 0 );
523523 const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1 ) / GGML_CUDA_MMV_Y;
524524 const dim3 block_nums (block_num_y, 1 , 1 );
525525 const dim3 block_dims (WARP_SIZE, GGML_CUDA_MMV_Y, 1 );
@@ -528,7 +528,7 @@ static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y,
528528}
529529
530530static void dequantize_mul_mat_vec_q5_1_cuda (const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
531- GGML_ASSERT (ncols % GGML_CUDA_DMMV_X == 0 );
531+ GGML_ASSERT (ncols % ( GGML_CUDA_DMMV_X* 2 ) == 0 );
532532 const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1 ) / GGML_CUDA_MMV_Y;
533533 const dim3 block_nums (block_num_y, 1 , 1 );
534534 const dim3 block_dims (WARP_SIZE, GGML_CUDA_MMV_Y, 1 );
@@ -537,7 +537,7 @@ static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y,
537537}
538538
539539static void dequantize_mul_mat_vec_q8_0_cuda (const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
540- GGML_ASSERT (ncols % GGML_CUDA_DMMV_X == 0 );
540+ GGML_ASSERT (ncols % ( GGML_CUDA_DMMV_X* 2 ) == 0 );
541541 const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1 ) / GGML_CUDA_MMV_Y;
542542 const dim3 block_nums (block_num_y, 1 , 1 );
543543 const dim3 block_dims (WARP_SIZE, GGML_CUDA_MMV_Y, 1 );
@@ -588,7 +588,7 @@ static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, f
588588}
589589
590590static void convert_mul_mat_vec_f16_cuda (const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
591- GGML_ASSERT (ncols % GGML_CUDA_DMMV_X == 0 );
591+ GGML_ASSERT (ncols % ( GGML_CUDA_DMMV_X* 2 ) == 0 );
592592 const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1 ) / GGML_CUDA_MMV_Y;
593593 const dim3 block_nums (block_num_y, 1 , 1 );
594594 const dim3 block_dims (WARP_SIZE, GGML_CUDA_MMV_Y, 1 );
@@ -672,3 +672,12 @@ void ggml_cuda_op_dequantize_mul_mat_vec(
672672 GGML_UNUSED (src1_ncols);
673673 GGML_UNUSED (src1_padded_row_size);
674674}
675+
676+ bool ggml_cuda_dmmv_type_supported (ggml_type src0_type) {
677+ return src0_type == GGML_TYPE_Q4_0 || src0_type == GGML_TYPE_Q4_1 ||
678+ src0_type == GGML_TYPE_Q5_0 || src0_type == GGML_TYPE_Q5_1 ||
679+ src0_type == GGML_TYPE_Q8_0 || src0_type == GGML_TYPE_Q2_K ||
680+ src0_type == GGML_TYPE_Q3_K || src0_type == GGML_TYPE_Q4_K ||
681+ src0_type == GGML_TYPE_Q5_K || src0_type == GGML_TYPE_Q6_K ||
682+ src0_type == GGML_TYPE_F16;
683+ }
0 commit comments