Skip to content

Commit f2872aa

Browse files
committed
HIP: Avoid fp32->fp16->fp32 conversion on RDNA4
1 parent 2d7a1f9 commit f2872aa

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1214,7 +1214,7 @@ static void ggml_cuda_op_mul_mat_cublas(
12141214

12151215
CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
12161216

1217-
if (GGML_CUDA_CC_IS_CDNA(compute_capability)) {
1217+
if (GGML_CUDA_CC_IS_CDNA(compute_capability) || GGML_CUDA_CC_IS_RDNA4(compute_capability)) {
12181218
const float alpha = 1.0f;
12191219
const float beta = 0.0f;
12201220
CUBLAS_CHECK(
@@ -1757,10 +1757,16 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
17571757
beta = &beta_f32;
17581758
}
17591759

1760-
if (GGML_CUDA_CC_IS_CDNA(ggml_cuda_info().devices[ctx.device].cc)) {
1760+
const int compute_capability = ggml_cuda_info().devices[ctx.device].cc;
1761+
if (GGML_CUDA_CC_IS_CDNA(compute_capability)) {
17611762
cu_compute_type = CUBLAS_COMPUTE_32F;
17621763
alpha = &alpha_f32;
17631764
beta = &beta_f32;
1765+
1766+
if (GGML_CUDA_CC_IS_RDNA4(compute_capability)) {
1767+
dst_t = (char *) dst_ddf;
1768+
cu_data_type = CUDA_R_32F;
1769+
}
17641770
}
17651771

17661772
GGML_ASSERT(ne12 % ne02 == 0);
@@ -1834,7 +1840,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
18341840
}
18351841
#endif
18361842

1837-
if (dst->op_params[0] == GGML_PREC_DEFAULT) {
1843+
if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type == CUDA_R_16F) {
18381844
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
18391845
to_fp32_cuda(dst_f16.get(), dst_ddf, ne_dst, main_stream);
18401846
}

0 commit comments

Comments
 (0)