Skip to content

Commit 3ae60e7

Browse files
committed
HIP: enable vec fattn on RDNA4
1 parent f4e081c commit 3ae60e7

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

ggml/src/ggml-cuda/common.cuh

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -241,8 +241,18 @@ static bool fp16_mma_available(const int cc) {
241241
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
242242
return false;
243243
#else
244-
return (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ||
245-
GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
244+
if ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ||
245+
GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc)) {
246+
return true;
247+
} else if (GGML_CUDA_CC_IS_RDNA4(cc)) {
248+
#if defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_HIP_ROCWMMA_FATTN_GFX12)
249+
return true;
250+
#else
251+
return false;
252+
#endif // defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_HIP_ROCWMMA_FATTN_GFX12)
253+
} else {
254+
return false;
255+
}
246256
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
247257
}
248258

0 commit comments

Comments
 (0)