@@ -1082,7 +1082,9 @@ static void ggml_cuda_op_mul_mat_cublas(
 
     const int compute_capability = ggml_cuda_info().devices[id].cc;
 
-    if (compute_capability >= GGML_CUDA_CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
+    bool try_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT;
+
+    if (compute_capability >= GGML_CUDA_CC_VOLTA && try_fp16) {
         // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
         ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
         if (src0->type != GGML_TYPE_F16) {
@@ -1103,28 +1105,38 @@ static void ggml_cuda_op_mul_mat_cublas(
             to_fp16_cuda(src1_ddf_i, src1_as_f16.get(), ne, stream);
         }
         const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16.get();
-        ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(id), row_diff*src1_ncols);
 
-        const half alpha_f16 = 1.0f;
-        const half beta_f16 = 0.0f;
+        CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
 
-        cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
-        if (ggml_cuda_info().devices[ctx.device].cc == GGML_CUDA_CC_CDNA) {
-            cu_compute_type = CUBLAS_COMPUTE_32F;
-        }
+        if (compute_capability == GGML_CUDA_CC_CDNA) {
+            const float alpha = 1.0f;
+            const float beta = 0.0f;
+            CUBLAS_CHECK(
+                cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
+                        row_diff, src1_ncols, ne10,
+                        &alpha, src0_ptr,  CUDA_R_16F, ne00,
+                                src1_ptr,  CUDA_R_16F, ne10,
+                        &beta,   dst_dd_i, CUDA_R_32F, ldc,
+                        CUBLAS_COMPUTE_32F,
+                        CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+        } else {
+            ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(id), row_diff*src1_ncols);
 
-        CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
-        CUBLAS_CHECK(
-            cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
-                    row_diff, src1_ncols, ne10,
-                    &alpha_f16, src0_ptr,      CUDA_R_16F, ne00,
-                                src1_ptr,      CUDA_R_16F, ne10,
-                    &beta_f16,  dst_f16.get(), CUDA_R_16F, ldc,
-                    cu_compute_type,
-                    CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+            const half alpha_f16 = 1.0f;
+            const half beta_f16 = 0.0f;
 
-        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
-        to_fp32_cuda(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
+            CUBLAS_CHECK(
+                cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
+                        row_diff, src1_ncols, ne10,
+                        &alpha_f16, src0_ptr,      CUDA_R_16F, ne00,
+                                    src1_ptr,      CUDA_R_16F, ne10,
+                        &beta_f16,  dst_f16.get(), CUDA_R_16F, ldc,
+                        CUBLAS_COMPUTE_16F,
+                        CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+
+            const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
+            to_fp32_cuda(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
+        }
     } else {
         ggml_cuda_pool_alloc<float> src0_ddq_as_f32(ctx.pool(id));
         ggml_cuda_pool_alloc<float> src1_ddq_as_f32(ctx.pool(id));
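
Aside (not part of the patch): the restructuring above leans on a cuBLAS property worth spelling out. cublasGemmEx accepts FP16 input matrices together with an FP32 compute type and an FP32 output buffer, which is what lets the CDNA branch write the result directly into dst_dd_i and drop both the dst_f16 staging buffer and the trailing to_fp32_cuda pass. A minimal standalone sketch of the two call shapes, assuming a column-major m x n output C = A^T * B; the helper names are hypothetical, not ggml code:

    #include <cublas_v2.h>
    #include <cuda_fp16.h>

    // FP16 inputs, FP32 accumulation, FP32 output -- no conversion pass needed.
    // A is stored k x m (transposed by the op), B is k x n, C is m x n.
    static cublasStatus_t gemm_f16in_f32out(cublasHandle_t handle,
                                            const half * A, const half * B, float * C,
                                            int m, int n, int k) {
        const float alpha = 1.0f; // float scalars: the compute type is 32F
        const float beta  = 0.0f;
        return cublasGemmEx(handle, CUBLAS_OP_T, CUBLAS_OP_N,
                            m, n, k,
                            &alpha, A, CUDA_R_16F, k,
                                    B, CUDA_R_16F, k,
                            &beta,  C, CUDA_R_32F, m,
                            CUBLAS_COMPUTE_32F,
                            CUBLAS_GEMM_DEFAULT_TENSOR_OP);
    }

    // Pure FP16 compute into an FP16 buffer; the caller must still run a
    // separate half -> float conversion kernel, as the non-CDNA branch does.
    static cublasStatus_t gemm_f16in_f16out(cublasHandle_t handle,
                                            const half * A, const half * B, half * C,
                                            int m, int n, int k) {
        const half alpha = 1.0f;  // half scalars: the compute type is 16F
        const half beta  = 0.0f;
        return cublasGemmEx(handle, CUBLAS_OP_T, CUBLAS_OP_N,
                            m, n, k,
                            &alpha, A, CUDA_R_16F, k,
                                    B, CUDA_R_16F, k,
                            &beta,  C, CUDA_R_16F, m,
                            CUBLAS_COMPUTE_16F,
                            CUBLAS_GEMM_DEFAULT_TENSOR_OP);
    }

On ROCm builds these cuBLAS symbols are mapped to hipBLAS equivalents by ggml's compatibility defines, and the FP32 path is presumably the natural one for CDNA since its matrix cores accumulate in FP32 anyway, so the FP16-output variant would only add a staging allocation and an extra kernel launch.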
@@ -1613,10 +1625,6 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
     cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
     cudaDataType_t      cu_data_type    = CUDA_R_16F;
 
-    if (ggml_cuda_info().devices[ctx.device].cc == GGML_CUDA_CC_CDNA) {
-        cu_compute_type = CUBLAS_COMPUTE_32F;
-    }
-
     // dst strides
     size_t nbd2 = dst->nb[2];
     size_t nbd3 = dst->nb[3];
@@ -1645,6 +1653,12 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
         beta  = &beta_f32;
     }
 
+    if (ggml_cuda_info().devices[ctx.device].cc == GGML_CUDA_CC_CDNA) {
+        cu_compute_type = CUBLAS_COMPUTE_32F;
+        alpha = &alpha_f32;
+        beta  = &beta_f32;
+    }
+
     GGML_ASSERT(ne12 % ne02 == 0);
     GGML_ASSERT(ne13 % ne03 == 0);
 
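
The two batched-path hunks relocate the CDNA override rather than add one, and the reason is a type rule in the cublasGemmEx/cublasGemmBatchedEx API: alpha and beta are passed as void pointers, and cuBLAS dereferences them as half when computeType is CUBLAS_COMPUTE_16F but as float when it is CUBLAS_COMPUTE_32F. Overriding the compute type before alpha/beta were chosen, as the old position did, left them pointing at half values that were then reinterpreted as float. A self-contained sketch of the corrected selection order, with a bool parameter standing in for the GGML_CUDA_CC_CDNA device check:

    #include <cublas_v2.h>
    #include <cuda_fp16.h>

    // Scalar names mirror the locals in ggml_cuda_mul_mat_batched_cublas.
    static void choose_gemm_scalars(bool is_cdna,
                                    const void ** alpha, const void ** beta,
                                    cublasComputeType_t * cu_compute_type) {
        static const half  alpha_f16 = 1.0f;
        static const half  beta_f16  = 0.0f;
        static const float alpha_f32 = 1.0f;
        static const float beta_f32  = 0.0f;

        *cu_compute_type = CUBLAS_COMPUTE_16F;
        *alpha = &alpha_f16; // half scalars match CUBLAS_COMPUTE_16F
        *beta  = &beta_f16;

        if (is_cdna) {
            // The old code flipped only the compute type, so cuBLAS read the
            // 2-byte halves as 4-byte floats. Compute type and scalar
            // pointers have to change together:
            *cu_compute_type = CUBLAS_COMPUTE_32F;
            *alpha = &alpha_f32; // float scalars match CUBLAS_COMPUTE_32F
            *beta  = &beta_f32;
        }
    }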