@@ -18,6 +18,7 @@ bool g_mul_mat_q = false;
 #include "ggml-cuda/cpy.cuh"
 #include "ggml-cuda/cross-entropy-loss.cuh"
 #include "ggml-cuda/diagmask.cuh"
+#include "ggml-cuda/dmmv.cuh"
 #include "ggml-cuda/fattn.cuh"
 #include "ggml-cuda/getrows.cuh"
 #include "ggml-cuda/im2col.cuh"
@@ -1025,6 +1026,114 @@ typedef void (*ggml_cuda_op_mul_mat_t)(
 
 #define MUL_MAT_SRC1_COL_STRIDE 128
 
+static __global__ void mul_mat_p021_f16_f32(
+    const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst,
+    const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y) {
+
+    const half * x = (const half *) vx;
+
+    const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
+    const int channel = blockDim.z*blockIdx.z + threadIdx.z;
+    const int channel_x = channel / (nchannels_y / nchannels_x);
+
+    const int nrows_y = ncols_x;
+    const int nrows_dst = nrows_x;
+    const int row_dst = row_x;
+
+    float tmp = 0.0f;
+
+    for (int col_x0 = 0; col_x0 < ncols_x; col_x0 += blockDim.x) {
+        const int col_x = col_x0 + threadIdx.x;
+
+        if (col_x >= ncols_x) {
+            break;
+        }
+
+        // x is transposed and permuted
+        const int ix = row_x*nchannels_x*ncols_x + channel_x*ncols_x + col_x;
+        const float xi = __half2float(x[ix]);
+
+        const int row_y = col_x;
+
+        // y is not transposed but permuted
+        const int iy = channel*nrows_y + row_y;
+
+        tmp += xi * y[iy];
+    }
+
+    // dst is not transposed and not permuted
+    const int idst = channel*nrows_dst + row_dst;
+
+    // sum up partial sums and write back result
+    tmp = warp_reduce_sum(tmp);
+
+    if (threadIdx.x == 0) {
+        dst[idst] = tmp;
+    }
+}
+
+static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
+    const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x,
+    const int row_stride_x, const int channel_stride_x, const int channel_x_divisor) {
+
+    const half * x = (const half *) vx;
+
+    const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
+    const int channel = blockDim.z*blockIdx.z + threadIdx.z;
+    const int channel_x = channel / channel_x_divisor;
+
+    const int nrows_y = ncols_x;
+    const int nrows_dst = nrows_x;
+    const int row_dst = row_x;
+
+    const int idst = channel*nrows_dst + row_dst;
+
+    float tmp = 0.0f;
+
+    for (int col_x0 = 0; col_x0 < ncols_x; col_x0 += blockDim.x) {
+        const int col_x = col_x0 + threadIdx.x;
+
+        if (col_x >= ncols_x) {
+            break;
+        }
+
+        const int row_y = col_x;
+
+        const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
+        const int iy = channel*nrows_y + row_y;
+
+        const float xi = __half2float(x[ix]);
+
+        tmp += xi * y[iy];
+    }
+
+    // sum up partial sums and write back result
+    tmp = warp_reduce_sum(tmp);
+
+    if (threadIdx.x == 0) {
+        dst[idst] = tmp;
+    }
+}
+
+static void ggml_mul_mat_p021_f16_f32_cuda(
+    const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x,
+    const int nchannels_x, const int nchannels_y, cudaStream_t stream) {
+
+    const dim3 block_nums(1, nrows_x, nchannels_y);
+    const dim3 block_dims(WARP_SIZE, 1, 1);
+    mul_mat_p021_f16_f32<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols_x, nrows_x, nchannels_x, nchannels_y);
+}
+
+static void ggml_mul_mat_vec_nc_f16_f32_cuda(
+    const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int row_stride_x,
+    const int nchannels_x, const int nchannels_y, const int channel_stride_x, cudaStream_t stream) {
+
+    const dim3 block_nums(1, nrows_x, nchannels_y);
+    const dim3 block_dims(WARP_SIZE, 1, 1);
+    mul_mat_vec_nc_f16_f32<<<block_nums, block_dims, 0, stream>>>
+        (vx, y, dst, ncols_x, nrows_x, row_stride_x, channel_stride_x, nchannels_y/nchannels_x);
+}
+
 static cudaError_t ggml_cuda_cpy_tensor_2d(
     void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) {
 
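Both kernels above rely on warp_reduce_sum to combine the per-lane partial dot products before the threadIdx.x == 0 store. That helper is defined elsewhere in the backend headers, not in this patch; the following is only a minimal sketch of such a shuffle-based reduction, assuming 32-lane warps with full-warp participation, and the real definition may differ in details:

// Sketch only: butterfly-reduces one float per lane across a 32-lane warp.
// After the loop every lane holds the warp-wide sum, so the kernels above can
// safely write it back from lane 0.
static __device__ __forceinline__ float warp_reduce_sum_sketch(float x) {
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        // add the value held by the lane 'offset' positions away (XOR pattern)
        x += __shfl_xor_sync(0xffffffff, x, offset, 32);
    }
    return x;
}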
@@ -1593,6 +1702,58 @@ static void ggml_cuda_op_mul_mat(
     }
 }
 
+static void ggml_cuda_mul_mat_vec_p021(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
+    GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer));
+    GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation
+    GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // 0213 permutation
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+
+    const int64_t ne12 = src1->ne[2];
+
+    cudaStream_t main_stream = ctx.stream();
+
+    void  * src0_ddq = src0->data;
+    float * src1_ddf = (float *) src1->data;
+    float * dst_ddf  = (float *) dst->data;
+
+    ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream);
+}
+
+static void ggml_cuda_mul_mat_vec_nc(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(!ggml_is_transposed(src0));
+    GGML_ASSERT(!ggml_is_transposed(src1));
+    GGML_ASSERT(!ggml_is_permuted(src0));
+    GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer));
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+
+    const int64_t nb01 = src0->nb[1];
+    const int64_t nb02 = src0->nb[2];
+
+    const int64_t ne12 = src1->ne[2];
+
+    cudaStream_t main_stream = ctx.stream();
+
+    void  * src0_ddq = src0->data;
+    float * src1_ddf = (float *) src1->data;
+    float * dst_ddf  = (float *) dst->data;
+
+    const int64_t row_stride_x = nb01 / sizeof(half);
+    const int64_t channel_stride_x = nb02 / sizeof(half);
+
+    ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
+}
+
 static __global__ void k_compute_batched_ptrs(
         const half * src0_as_f16, const half * src1_as_f16, char * dst,
         const void ** ptrs_src, void ** ptrs_dst,
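The nc wrapper above converts ggml's byte strides (src0->nb[1], src0->nb[2]) into element strides, and the kernel then addresses element (col, row, channel) as channel*channel_stride_x + row*row_stride_x + col. Below is a small host-side illustration of that index math; it is not part of the patch, the sizes are made up, and a plain 2-byte integer stands in for half:

// Illustration only: deriving row/channel element strides from hypothetical
// byte strides the same way ggml_cuda_mul_mat_vec_nc does, then computing the
// flat index used inside mul_mat_vec_nc_f16_f32.
#include <cstdint>
#include <cstdio>

typedef uint16_t half_t; // 2-byte stand-in for CUDA's half on the host

int main() {
    // hypothetical non-contiguous F16 view: 128 columns per row, but rows are
    // spaced 160 elements apart in memory and channels 160*32 elements apart
    const int64_t nb1 = 160*sizeof(half_t);    // byte stride between rows
    const int64_t nb2 = 160*32*sizeof(half_t); // byte stride between channels

    const int64_t row_stride_x     = nb1 / sizeof(half_t);
    const int64_t channel_stride_x = nb2 / sizeof(half_t);

    // element (col, row, channel) maps to this flat index inside the kernel
    const int64_t col = 3, row = 5, channel = 2;
    const int64_t ix = channel*channel_stride_x + row*row_stride_x + col;
    printf("row_stride=%lld channel_stride=%lld ix=%lld\n",
           (long long) row_stride_x, (long long) channel_stride_x, (long long) ix);
    return 0;
}

With contiguous storage nb[1] would simply equal ne[0]*sizeof(half); the extra row padding in this example is what makes the view non-contiguous and sends it down this code path rather than the generic one.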
@@ -1770,18 +1931,27 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
 static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft);
 
+    bool use_dequantize_mul_mat_vec = ggml_cuda_dmmv_type_supported(src0->type)
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
+        && src0->ne[0] % (GGML_CUDA_DMMV_X*2) == 0 && src1->ne[1] == 1;
+
     bool use_mul_mat_vec = src0->type == GGML_TYPE_F16
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
         && src0->ne[0] % 2 == 0 && src1->ne[1] == 1;
-    bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
-        // && src0->ne[0] % (GGML_CUDA_DMMV_X*2) == 0 && src1->ne[1] == 1;
-    // bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
-    //     && ggml_cuda_mmvq_type_supported(src0->type)
+
+    bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
+        && ggml_cuda_mmvq_type_supported(src0->type)
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
         && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
+
     bool use_mul_mat_q = ggml_is_quantized(src0->type)
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
 
+    // if mmvq is available it's a better choice than dmmv:
+#ifndef GGML_CUDA_FORCE_DMMV
+    use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
+#endif // GGML_CUDA_FORCE_DMMV
+
     bool any_gpus_with_slow_fp16 = false;
     bool any_gpus_without_fp16_mma = false;
 
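For context, not part of the patch: GGML_CUDA_DMMV_X is the dmmv tile width along ne[0] (typically 32), so the dequantize path above only qualifies when src0's row length is a multiple of 64 and src1 has a single column, and unless GGML_CUDA_FORCE_DMMV is defined at compile time a usable mmvq kernel takes precedence. A condensed reader-side restatement of that selection, assuming those defaults and written as it would sit inside ggml-cuda.cu:

// Sketch only: the batch-size-1 dmmv eligibility check as selected above,
// assuming the usual GGML_CUDA_DMMV_X of 32 (ne[0] divisible by 64).
static bool dmmv_would_be_used(const ggml_tensor * src0, const ggml_tensor * src1,
                               const ggml_tensor * dst, bool mmvq_usable) {
    bool ok = ggml_cuda_dmmv_type_supported(src0->type)
        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
        && src0->ne[0] % (GGML_CUDA_DMMV_X*2) == 0 // row length multiple of 64
        && src1->ne[1] == 1;                       // single src1 column
#ifndef GGML_CUDA_FORCE_DMMV
    ok = ok && !mmvq_usable;                       // prefer mmvq when available
#endif
    return ok;
}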
@@ -1814,14 +1984,28 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     // printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
     // printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
 
-    if (!split && use_mul_mat_vec && dst->ne[3] == 1 && (src0->ne[1] < MMV_MAX_ROWS || any_gpus_without_fp16_mma)) {
+    if (!split && any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
+        // FP32 precision KQ single-batch for batch size 1 without FlashAttention
+        ggml_cuda_mul_mat_vec_p021(ctx, src0, src1, dst);
+
+    } else if (!split && any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
+        // FP32 precision KQV single-batch for batch size 1 without FlashAttention
+        ggml_cuda_mul_mat_vec_nc(ctx, src0, src1, dst);
+
+    } else if (!split && use_mul_mat_vec && dst->ne[3] == 1 && (src0->ne[1] < MMV_MAX_ROWS || any_gpus_without_fp16_mma)) {
         // the custom F16 vector kernel can be used over batched cuBLAS GEMM
         // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
         ggml_cuda_mul_mat_vec(ctx, src0, src1, dst);
+
+    } else if (!split && src0->type == GGML_TYPE_F16 && src1->ne[1] == 1 && dst->ne[3] == 1 && (src0->ne[1] < MMV_MAX_ROWS || any_gpus_without_fp16_mma)) {
+        ggml_cuda_mul_mat_vec(ctx, src0, src1, dst);
+
     } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16)
         && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // general KQ + KQV multi-batch without FlashAttention
         ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
+    } else if (use_dequantize_mul_mat_vec) {
+        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, nullptr);
     } else if (use_mul_mat_vec) {
         ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec, nullptr);
     } else if (use_mul_mat_vec_q) {