 #include "ggml-cuda/fattn.cuh"
 #include "ggml-cuda/getrows.cuh"
 #include "ggml-cuda/im2col.cuh"
+#include "ggml-cuda/mmf.cuh"
 #include "ggml-cuda/mmq.cuh"
-#include "ggml-cuda/mmv.cuh"
+#include "ggml-cuda/mmvf.cuh"
 #include "ggml-cuda/mmvq.cuh"
 #include "ggml-cuda/norm.cuh"
 #include "ggml-cuda/opt-step-adamw.cuh"
@@ -2008,7 +2009,9 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     const bool bad_padding_clear = ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE
         && ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) && src0->view_src;

-    bool use_mul_mat_vec = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16)
+    bool use_mul_mat_vec_f = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16)
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
+    bool use_mul_mat_f = !ggml_is_quantized(src0->type)
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
     bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
@@ -2028,14 +2031,18 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
             }

             const int cc = ggml_cuda_info().devices[id].cc;
+            const int warp_size = ggml_cuda_info().devices[id].warp_size;
             use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
-            use_mul_mat_vec = use_mul_mat_vec && ggml_cuda_should_use_mmv(src0->type, cc, src0->ne, src1->ne[1]);
+            use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1]);
+            use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1]);
             any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
         }
     } else {
         const int cc = ggml_cuda_info().devices[ctx.device].cc;
+        const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;
         use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
-        use_mul_mat_vec = use_mul_mat_vec && ggml_cuda_should_use_mmv(src0->type, cc, src0->ne, src1->ne[1]);
+        use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1]);
+        use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1]);
         any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
     }

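Context note (not part of the patch): in the split-buffer branch above, each kernel flag is re-checked once per participating device, so a path such as use_mul_mat_f stays enabled only if every GPU holding a slice of src0 passes its ggml_cuda_should_use_* check. Below is a minimal, self-contained sketch of that pattern; the helper name and the cc >= 700 predicate are hypothetical stand-ins, not the real checks.

    // Illustrative only: AND a capability flag across all devices of a split tensor.
    #include <cstdio>
    #include <vector>

    static bool supports_new_kernel(int cc) { return cc >= 700; } // stand-in predicate, not the real check

    static bool flag_across_devices(const std::vector<int> & device_ccs, bool (*supported)(int)) {
        bool flag = true;                  // starts from the dtype checks alone
        for (int cc : device_ccs) {
            flag = flag && supported(cc);  // one unsupported device disables the kernel path
        }
        return flag;
    }

    int main() {
        std::printf("%d\n", flag_across_devices({700, 860}, supports_new_kernel)); // 1
        std::printf("%d\n", flag_across_devices({610, 860}, supports_new_kernel)); // 0
        return 0;
    }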
@@ -2048,15 +2055,17 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     // printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);

     // TODO update for generic tensor parallelism
-    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     bool use_batched_cublas_f16 = src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16);
     bool use_batched_cublas_bf16 = src0->type == GGML_TYPE_BF16 && bf16_mma_hardware_available(cc);
     bool use_batched_cublas_f32 = src0->type == GGML_TYPE_F32;

-    if (!split && use_mul_mat_vec) {
+    if (!split && use_mul_mat_vec_f) {
         // the custom F16 vector kernel can be used over batched cuBLAS GEMM
         // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
-        ggml_cuda_mul_mat_vec(ctx, src0, src1, nullptr, dst);
+        ggml_cuda_mul_mat_vec_f(ctx, src0, src1, nullptr, dst);
+    } else if (!split && use_mul_mat_f) {
+        ggml_cuda_mul_mat_f(ctx, src0, src1, nullptr, dst);
     } else if (!split && use_mul_mat_vec_q) {
         ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
     } else if (!split && use_mul_mat_q) {
@@ -2065,8 +2074,8 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
         && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // general KQ + KQV multi-batch without FlashAttention
         ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
-    } else if (use_mul_mat_vec) {
-        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec, nullptr);
+    } else if (use_mul_mat_vec_f) {
+        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_f, nullptr);
     } else if (use_mul_mat_vec_q) {
         ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda);
     } else if (use_mul_mat_q) {
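Context note (not part of the patch): on a single GPU the dispatch chain above now tries the float matrix-vector kernel (mmvf) first, then the new float matrix-matrix kernel (mmf), then the quantized matrix-vector (mmvq) and matrix-matrix (mmq) kernels, with the cuBLAS paths as the remaining fallback. The sketch below restates only that ordering; the enum and the pick_kernel function are hypothetical, while the flag names mirror the patch.

    // Illustrative only: the selection order from the diff, over already-computed flags.
    #include <cstdio>

    enum class mul_mat_kernel { mmvf, mmf, mmvq, mmq, cublas_fallback };

    static mul_mat_kernel pick_kernel(bool split, bool use_mul_mat_vec_f, bool use_mul_mat_f,
                                      bool use_mul_mat_vec_q, bool use_mul_mat_q) {
        if (!split && use_mul_mat_vec_f) { return mul_mat_kernel::mmvf; } // thin src0 / no tensor cores
        if (!split && use_mul_mat_f)     { return mul_mat_kernel::mmf;  } // non-quantized F32/F16/BF16 path
        if (!split && use_mul_mat_vec_q) { return mul_mat_kernel::mmvq; }
        if (!split && use_mul_mat_q)     { return mul_mat_kernel::mmq;  }
        return mul_mat_kernel::cublas_fallback;                           // batched or per-op cuBLAS paths
    }

    int main() {
        // Example: non-split F16 weights with a thin batch -> the mmvf path is chosen first.
        std::printf("%d\n", static_cast<int>(pick_kernel(false, true, true, false, false))); // 0 == mmvf
        return 0;
    }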
@@ -2094,7 +2103,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
         if (ggml_is_quantized(src0->type)) {
             ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
         } else {
-            ggml_cuda_mul_mat_vec(ctx, src0, src1, ids, dst);
+            ggml_cuda_mul_mat_vec_f(ctx, src0, src1, ids, dst);
         }
         return;
     }
@@ -3516,7 +3525,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
 #endif // FLASH_ATTN_AVAILABLE
             if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
                 const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
-                if (!new_mma_available(cc)) {
+                if (!turing_mma_available(cc)) {
                     return false;
                 }
                 const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2];