Commit e95fec6

HIP: Disable ROCWMMA fattn on CDNA when compiled against ROCWMMA 2.0.0 (ggml-org#16221)
* HIP: Disable ROCWMMA fattn on CDNA when compiled against ROCWMMA 2.0.0

  rocwmma 2.0.0 includes a bug in the code faking fp16 accumulation on CDNA.

* CUDA: Fix Volta condition in ggml_cuda_should_use_wmma_fattn
1 parent ded67b9 commit e95fec6
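
For context on the second bullet: relative to the removed fp16_mma_available() helper, the Volta comparison in the new ggml_cuda_should_use_wmma_fattn() tightens from >= to ==, because on NVIDIA the WMMA FlashAttention kernel is only intended for Volta, while Turing and newer go through the turing_mma_available() path checked alongside it in fattn.cu. The standalone sketch below is not code from this commit; it drops the ggml_cuda_highest_compiled_arch() indirection and assumes GGML_CUDA_CC_VOLTA = 700 and GGML_CUDA_CC_TURING = 750 (the usual compute-capability encodings):

// Standalone illustration only; constants and helper names are assumptions.
constexpr int CC_VOLTA  = 700;  // assumed value of GGML_CUDA_CC_VOLTA
constexpr int CC_TURING = 750;  // assumed value of GGML_CUDA_CC_TURING

constexpr bool old_volta_condition(int cc) { return cc >= CC_VOLTA; }  // as in the removed fp16_mma_available
constexpr bool new_volta_condition(int cc) { return cc == CC_VOLTA; }  // as in ggml_cuda_should_use_wmma_fattn

static_assert(old_volta_condition(CC_TURING) && !new_volta_condition(CC_TURING),
              "Turing and newer no longer select the WMMA FlashAttention kernel");
static_assert(old_volta_condition(CC_VOLTA) && new_volta_condition(CC_VOLTA),
              "Volta still does");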

8 files changed, +61 −50 lines changed


ggml/CMakeLists.txt

Lines changed: 0 additions & 1 deletion

@@ -209,7 +209,6 @@ option(GGML_HIP "ggml: use HIP"
 option(GGML_HIP_GRAPHS "ggml: use HIP graph, experimental, slow" OFF)
 option(GGML_HIP_NO_VMM "ggml: do not try to use HIP VMM" ON)
 option(GGML_HIP_ROCWMMA_FATTN "ggml: enable rocWMMA for FlashAttention" OFF)
-option(GGML_HIP_FORCE_ROCWMMA_FATTN_GFX12 "ggml: enable rocWMMA FlashAttention on GFX12" OFF)
 option(GGML_HIP_MMQ_MFMA "ggml: enable MFMA MMA for CDNA in MMQ" ON)
 option(GGML_HIP_EXPORT_METRICS "ggml: enable kernel perf metrics output" OFF)
 option(GGML_MUSA_GRAPHS "ggml: use MUSA graph, experimental, unstable" OFF)

ggml/src/ggml-cuda/common.cuh

Lines changed: 0 additions & 29 deletions

@@ -220,14 +220,6 @@ static const char * cu_get_error_str(CUresult err) {
 #define FAST_FP16_AVAILABLE
 #endif // defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
 
-#if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
-#define FP16_MMA_AVAILABLE
-#endif // (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
-
-#if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4)))
-#define FP16_MMA_AVAILABLE
-#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4)))
-
 #if defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)
 #define AMD_MFMA_AVAILABLE
 #endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)

@@ -262,27 +254,6 @@ static bool fast_fp16_hardware_available(const int cc) {
            (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
 }
 
-// Any FP16 tensor core instructions are available for ggml code.
-static bool fp16_mma_available(const int cc) {
-#if defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN)
-    return false;
-#else
-    if ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ||
-        GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) ||
-        GGML_CUDA_CC_IS_MTHREADS(cc)) {
-        return true;
-    } else if (GGML_CUDA_CC_IS_RDNA4(cc)) {
-#if defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_HIP_ROCWMMA_FATTN_GFX12)
-        return true;
-#else
-        return false;
-#endif // defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_HIP_ROCWMMA_FATTN_GFX12)
-    } else {
-        return false;
-    }
-#endif // defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN)
-}
-
 // To be used for feature selection of external libraries, e.g. cuBLAS.
 static bool fp16_mma_hardware_available(const int cc) {
     return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) ||

ggml/src/ggml-cuda/fattn-tile.cu

Lines changed: 3 additions & 2 deletions

@@ -1,6 +1,7 @@
 #include "common.cuh"
 #include "fattn-common.cuh"
 #include "fattn-tile.cuh"
+#include "fattn-wmma-f16.cuh"
 
 // kq_stride == number of KQ rows to process per iteration
 // kq_nbatch == number of K columns to load in parallel for KQ calculation

@@ -190,10 +191,10 @@ static __global__ void flash_attn_tile(
 #ifdef FLASH_ATTN_AVAILABLE
 
     // Skip unused kernel variants for faster compilation:
-#ifdef FP16_MMA_AVAILABLE
+#ifdef GGML_USE_WMMA_FATTN
     NO_DEVICE_CODE;
     return;
-#endif // FP16_MMA_AVAILABLE
+#endif // GGML_USE_WMMA_FATTN
 
     if (use_logit_softcap && !(D == 128 || D == 256)) {
         GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,

ggml/src/ggml-cuda/fattn-wmma-f16.cu

Lines changed: 6 additions & 6 deletions

@@ -6,19 +6,19 @@
 #include "fattn-common.cuh"
 #include "fattn-wmma-f16.cuh"
 
-#ifdef FP16_MMA_AVAILABLE
+#ifdef GGML_USE_WMMA_FATTN
 #if !defined(GGML_USE_HIP)
 #include <mma.h>
-#ifdef GGML_USE_MUSA
+#if defined(GGML_USE_MUSA)
 namespace wmma = mtmusa::wmma;
 #else // GGML_USE_MUSA
 namespace wmma = nvcuda::wmma;
 #endif // GGML_USE_MUSA
-#elif defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)
+#elif defined(GGML_USE_HIP)
 #include <rocwmma/rocwmma.hpp>
 namespace wmma = rocwmma;
 #endif // !defined(GGML_USE_HIP)
-#endif // FP16_MMA_AVAILABLE
+#endif // GGML_USE_WMMA_FATTN
 
 // D == head size, VKQ_stride == num VKQ rows calculated in parallel:
 template<int D, int ncols, int nwarps, int VKQ_stride, typename KQ_acc_t, bool use_logit_softcap>

@@ -45,7 +45,7 @@ static __global__ void flash_attn_ext_f16(
         const int32_t nb21, const int32_t nb22, const int64_t nb23,
         const int32_t ne31, const int32_t ne32, const int32_t ne33,
         const int32_t nb31, const int32_t nb32, const int64_t nb33) {
-#if defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
+#if defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN)))
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
         NO_DEVICE_CODE;

@@ -481,7 +481,7 @@ static __global__ void flash_attn_ext_f16(
         ne31, ne32, ne33,
         nb31, nb32, nb33);
     NO_DEVICE_CODE;
-#endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
+#endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN)))
 }
 
 constexpr int get_max_power_of_2(int x) {

ggml/src/ggml-cuda/fattn-wmma-f16.cuh

Lines changed: 46 additions & 0 deletions

@@ -1,3 +1,49 @@
 #include "common.cuh"
 
+#if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
+#define GGML_USE_WMMA_FATTN
+#endif // (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
+
+#if defined(GGML_HIP_ROCWMMA_FATTN)
+#if defined(CDNA) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0)
+#define GGML_USE_WMMA_FATTN
+#elif defined(CDNA)
+#warning "rocwmma fattn on CDNA is broken on rocwmma v2.0.0, expect degraded performance"
+#endif // defined(CDNA) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0)
+#if defined(RDNA3)
+#define GGML_USE_WMMA_FATTN
+#endif // defined(RDNA3)
+#if defined(RDNA4) && ROCWMMA_VERSION_MAJOR > 1
+#define GGML_USE_WMMA_FATTN
+#elif defined(RDNA4)
+#warning "rocwmma fattn is not suported on RDNA4 on rocwmma < v2.0.0, expect degraded performance"
+#endif // defined(RDNA4) && ROCWMMA_VERSION_MAJOR > 1
+#endif // defined(GGML_HIP_ROCWMMA_FATTN)
+
+// WMMA flash attention requires FP16 matrix instructions to be available for ggml code.
+static bool ggml_cuda_should_use_wmma_fattn(const int cc) {
+#if defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN)
+    return false;
+#else
+    if ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_VOLTA) ||
+        GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_MTHREADS(cc)) {
+        return true;
+    } else if (GGML_CUDA_CC_IS_CDNA(cc)){
+#if defined(GGML_HIP_ROCWMMA_FATTN) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0)
+        return true;
+#else
+        return false;
+#endif // defined(GGML_HIP_ROCWMMA_FATTN) (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0)
+    } else if (GGML_CUDA_CC_IS_RDNA4(cc)) {
+#if defined(GGML_HIP_ROCWMMA_FATTN) && ROCWMMA_VERSION_MAJOR > 1
+        return true;
+#else
+        return false;
+#endif // defined(GGML_HIP_ROCWMMA_FATTN) && ROCWMMA_VERSION_MAJOR > 1
+    } else {
+        return false;
+    }
+#endif // defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN)
+}
+
 void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
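
To unpack the version gate added above: the expression (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0) is false only for versions of the form x.0.0 with x >= 2, which today means exactly 2.0.0. So CDNA keeps the WMMA FlashAttention path on rocWMMA 1.x and on later 2.x point releases, and loses it only on 2.0.0 itself. A minimal standalone check of that logic (not part of the commit; the function name is made up and the version macros are replaced by plain parameters):

// Standalone sketch mirroring the preprocessor predicate with ordinary integers.
constexpr bool cdna_wmma_fattn_enabled(int major, int minor, int patch) {
    return major < 2 || minor > 0 || patch > 0;   // same shape as the #if above
}

static_assert( cdna_wmma_fattn_enabled(1, 7, 0), "rocWMMA 1.x keeps the WMMA path");
static_assert(!cdna_wmma_fattn_enabled(2, 0, 0), "exactly 2.0.0 disables it (fp16 accumulation bug from the commit message)");
static_assert( cdna_wmma_fattn_enabled(2, 0, 1), "2.0.1 would re-enable it");
static_assert( cdna_wmma_fattn_enabled(2, 1, 0), "2.1.0 would re-enable it");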

ggml/src/ggml-cuda/fattn.cu

Lines changed: 2 additions & 2 deletions

@@ -222,7 +222,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
             if (V->ne[0] != K->ne[0]) {
                 return BEST_FATTN_KERNEL_NONE;
             }
-            if (!fp16_mma_available(cc) && !turing_mma_available(cc)) {
+            if (!ggml_cuda_should_use_wmma_fattn(cc) && !turing_mma_available(cc)) {
                 return BEST_FATTN_KERNEL_NONE;
             }
             break;

@@ -300,7 +300,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
     }
 
     // For large batch sizes, use the WMMA kernel if possible:
-    if (fp16_mma_available(cc)) {
+    if (ggml_cuda_should_use_wmma_fattn(cc)) {
         return BEST_FATTN_KERNEL_WMMA_F16;
     }
 
ggml/src/ggml-cuda/vendors/hip.h

Lines changed: 4 additions & 0 deletions

@@ -6,6 +6,10 @@
 #include <hip/hip_fp16.h>
 #include <hip/hip_bf16.h>
 
+#if defined(GGML_HIP_ROCWMMA_FATTN)
+#include <rocwmma/rocwmma-version.hpp>
+#endif // defined(GGML_HIP_ROCWMMA_FATTN)
+
 #define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
 #define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
 #define CUBLAS_OP_N HIPBLAS_OP_N
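
The new include makes rocWMMA's ROCWMMA_VERSION_MAJOR/MINOR/PATCH macros visible wherever ggml's HIP vendor header is used, which is what the gate in fattn-wmma-f16.cuh above reads. A quick standalone way to see which rocWMMA version a toolchain resolves (not part of the commit; it assumes rocwmma/rocwmma-version.hpp can be included on its own and that the three macros expand to integers, as the diff's comparisons imply):

// Standalone sketch: print the rocWMMA version triple the compiler picked up.
#include <cstdio>
#include <rocwmma/rocwmma-version.hpp>

int main() {
    std::printf("rocWMMA %d.%d.%d\n",
                (int) ROCWMMA_VERSION_MAJOR,
                (int) ROCWMMA_VERSION_MINOR,
                (int) ROCWMMA_VERSION_PATCH);
    return 0;
}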

ggml/src/ggml-hip/CMakeLists.txt

Lines changed: 0 additions & 10 deletions

@@ -39,12 +39,6 @@ endif()
 find_package(hip REQUIRED)
 find_package(hipblas REQUIRED)
 find_package(rocblas REQUIRED)
-if (GGML_HIP_ROCWMMA_FATTN)
-    CHECK_INCLUDE_FILE_CXX("rocwmma/rocwmma.hpp" FOUND_ROCWMMA)
-    if (NOT ${FOUND_ROCWMMA})
-        message(FATAL_ERROR "rocwmma has not been found")
-    endif()
-endif()
 
 if (${hip_VERSION} VERSION_LESS 6.1)
     message(FATAL_ERROR "At least ROCM/HIP V6.1 is required")

@@ -117,10 +111,6 @@ if (NOT GGML_HIP_MMQ_MFMA)
     add_compile_definitions(GGML_HIP_NO_MMQ_MFMA)
 endif()
 
-if (GGML_HIP_FORCE_ROCWMMA_FATTN_GFX12 OR ${hip_VERSION} VERSION_GREATER_EQUAL 7.0)
-    add_compile_definitions(GGML_HIP_ROCWMMA_FATTN_GFX12)
-endif()
-
 if (GGML_HIP_EXPORT_METRICS)
     set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -Rpass-analysis=kernel-resource-usage --save-temps")
 endif()
