@@ -1083,9 +1083,7 @@ static void ggml_cuda_op_mul_mat_cublas(
 
     const int compute_capability = ggml_cuda_info().devices[id].cc;
 
-    const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT;
-
-    if (compute_capability >= GGML_CUDA_CC_VOLTA && use_fp16) {
+    if (compute_capability >= GGML_CUDA_CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
         // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
         ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
         if (src0->type != GGML_TYPE_F16) {
@@ -1106,38 +1104,28 @@ static void ggml_cuda_op_mul_mat_cublas(
             to_fp16_cuda(src1_ddf_i, src1_as_f16.get(), ne, stream);
         }
         const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16.get();
+        ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(id), row_diff*src1_ncols);
 
-        CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
+        const half alpha_f16 = 1.0f;
+        const half beta_f16  = 0.0f;
 
-        if (compute_capability == GGML_CUDA_CC_CDNA) {
-            const float alpha = 1.0f;
-            const float beta  = 0.0f;
-            CUBLAS_CHECK(
-                cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
-                        row_diff, src1_ncols, ne10,
-                        &alpha, src0_ptr, CUDA_R_16F, ne00,
-                                src1_ptr, CUDA_R_16F, ne10,
-                        &beta,  dst_dd_i, CUDA_R_32F, ldc,
-                        CUBLAS_COMPUTE_32F,
-                        CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-        } else {
-            ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(id), row_diff*src1_ncols);
-
-            const half alpha_f16 = 1.0f;
-            const half beta_f16  = 0.0f;
+        cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
+        if (ggml_cuda_info().devices[ctx.device].cc == GGML_CUDA_CC_CDNA) {
+            cu_compute_type = CUBLAS_COMPUTE_32F;
+        }
 
-            CUBLAS_CHECK(
-                cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
-                        row_diff, src1_ncols, ne10,
-                        &alpha_f16, src0_ptr, CUDA_R_16F, ne00,
-                                    src1_ptr, CUDA_R_16F, ne10,
-                        &beta_f16,  dst_dd_i, CUDA_R_16F, ldc,
-                        CUBLAS_COMPUTE_16F,
-                        CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+        CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
+        CUBLAS_CHECK(
+            cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
+                    row_diff, src1_ncols, ne10,
+                    &alpha_f16, src0_ptr,      CUDA_R_16F, ne00,
+                                src1_ptr,      CUDA_R_16F, ne10,
+                    &beta_f16,  dst_f16.get(), CUDA_R_16F, ldc,
+                    cu_compute_type,
+                    CUBLAS_GEMM_DEFAULT_TENSOR_OP));
 
-            const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
-            to_fp32_cuda(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
-        }
+        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
+        to_fp32_cuda(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
     } else {
         ggml_cuda_pool_alloc<float> src0_ddq_as_f32(ctx.pool(id));
         ggml_cuda_pool_alloc<float> src1_ddq_as_f32(ctx.pool(id));
@@ -1626,6 +1614,10 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
     cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
     cudaDataType_t      cu_data_type    = CUDA_R_16F;
 
+    if (ggml_cuda_info().devices[ctx.device].cc == GGML_CUDA_CC_CDNA) {
+        cu_compute_type = CUBLAS_COMPUTE_32F;
+    }
+
     // dst strides
     size_t nbd2 = dst->nb[2];
     size_t nbd3 = dst->nb[3];
@@ -1654,12 +1646,6 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
         beta  = &beta_f32;
     }
 
-    if (ggml_cuda_info().devices[ctx.device].cc == GGML_CUDA_CC_CDNA) {
-        cu_compute_type = CUBLAS_COMPUTE_32F;
-        alpha = &alpha_f32;
-        beta  = &beta_f32;
-    }
-
     GGML_ASSERT(ne12 % ne02 == 0);
     GGML_ASSERT(ne13 % ne03 == 0);
 