@@ -1198,9 +1198,14 @@ static void ggml_cuda_op_mul_mat_cublas(
11981198
11991199 const int cc = ggml_cuda_info ().devices [id].cc ;
12001200
1201+ const bool support_bf16 = GGML_CUDA_CC_IS_NVIDIA (cc) || GGML_CUDA_CC_IS_AMD (cc) ||
1202+ (GGML_CUDA_CC_IS_MTHREADS (cc) && cc >= GGML_CUDA_CC_QY2);
1203+
1204+ const bool support_fp16 = (GGML_CUDA_CC_IS_NVIDIA (cc) && cc >= GGML_CUDA_CC_VOLTA) ||
1205+ GGML_CUDA_CC_IS_AMD (cc) || (GGML_CUDA_CC_IS_MTHREADS (cc) && cc >= GGML_CUDA_CC_QY2);
12011206 const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized (src0->type )) && ggml_is_contiguous (src0) && row_diff == src0->ne [1 ] && dst->op_params [0 ] == GGML_PREC_DEFAULT;
12021207
1203- if (( GGML_CUDA_CC_IS_NVIDIA (cc) || GGML_CUDA_CC_IS_AMD (cc) || ( GGML_CUDA_CC_IS_MTHREADS (cc) && cc >= GGML_CUDA_CC_QY2)) && src0->type == GGML_TYPE_BF16 && ggml_is_contiguous (src0) && row_diff == src0->ne [1 ]) {
1208+ if (support_bf16 && src0->type == GGML_TYPE_BF16 && ggml_is_contiguous (src0) && row_diff == src0->ne [1 ]) {
12041209 ggml_cuda_pool_alloc<nv_bfloat16> src1_as_bf16 (ctx.pool (id));
12051210 if (src1->type != GGML_TYPE_BF16) {
12061211 const to_bf16_cuda_t to_bf16_cuda = ggml_get_to_bf16_cuda (src1->type );
@@ -1228,7 +1233,7 @@ static void ggml_cuda_op_mul_mat_cublas(
12281233
12291234 const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda (GGML_TYPE_BF16);
12301235 to_fp32_cuda (dst_bf16.get (), dst_dd_i, row_diff*src1_ncols, stream);
1231- } else if ((( GGML_CUDA_CC_IS_NVIDIA (cc) && cc >= GGML_CUDA_CC_VOLTA) || GGML_CUDA_CC_IS_AMD (cc) || ( GGML_CUDA_CC_IS_MTHREADS (cc) && cc >= GGML_CUDA_CC_QY2)) && use_fp16) {
1236+ } else if (support_fp16 && use_fp16) {
12321237 // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
12331238 ggml_cuda_pool_alloc<half> src0_as_f16 (ctx.pool (id));
12341239 if (src0->type != GGML_TYPE_F16) {
0 commit comments