@@ -1107,14 +1107,19 @@ static void ggml_cuda_op_mul_mat_cublas(
         const half alpha_f16 = 1.0f;
         const half beta_f16 = 0.0f;

+        cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
+        if (ggml_cuda_info().devices[ctx.device].cc == CC_CDNA) {
+            cu_compute_type = CUBLAS_COMPUTE_32F;
+        }
+
         CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
         CUBLAS_CHECK(
             cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
                     row_diff, src1_ncols, ne10,
                     &alpha_f16, src0_ptr,      CUDA_R_16F, ne00,
                                 src1_ptr,      CUDA_R_16F, ne10,
                     &beta_f16,  dst_f16.get(), CUDA_R_16F, ldc,
-                    CUBLAS_COMPUTE_16F,
+                    cu_compute_type,
                     CUBLAS_GEMM_DEFAULT_TENSOR_OP));

         const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
@@ -1607,6 +1612,10 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
     cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
     cudaDataType_t cu_data_type = CUDA_R_16F;

+    if (ggml_cuda_info().devices[ctx.device].cc == CC_CDNA) {
+        cu_compute_type = CUBLAS_COMPUTE_32F;
+    }
+
     // dst strides
     size_t nbd2 = dst->nb[2];
     size_t nbd3 = dst->nb[3];
0 commit comments