
Commit 8f1cd4a

CUDA: Enable FP16_MMA for RDNA3 with rocWMMA
1 parent 70392f1 commit 8f1cd4a

File tree

5 files changed: +36 −31 lines changed

    ggml/CMakeLists.txt
    ggml/src/ggml-cuda/common.cuh
    ggml/src/ggml-cuda/fattn-wmma-f16.cuh
    ggml/src/ggml-cuda/fattn.cu
    ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu

ggml/CMakeLists.txt

Lines changed: 1 addition & 1 deletion

@@ -177,7 +177,7 @@ set(CMAKE_C_STANDARD_REQUIRED true)
 if (GGML_SYCL)
     set(CMAKE_CXX_STANDARD 17)
 else()
-    set(CMAKE_CXX_STANDARD 11)
+    set(CMAKE_CXX_STANDARD 17)
 endif()
 set(CMAKE_CXX_STANDARD_REQUIRED true)
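
rocWMMA is a header-only C++17 library, which presumably motivates raising the default standard from 11 to 17. A minimal guard one could drop into a HIP translation unit to catch a misconfigured build (my addition for illustration, not part of the commit):

    // Fail fast if the build system did not actually apply C++17;
    // rocWMMA's headers will not compile below that standard.
    static_assert(__cplusplus >= 201703L,
                  "rocWMMA requires C++17; check CMAKE_CXX_STANDARD.");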

ggml/src/ggml-cuda/common.cuh

Lines changed: 4 additions & 9 deletions

@@ -131,6 +131,9 @@ typedef float2 dfloat2;
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
 #define FP16_MMA_AVAILABLE
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && defined(RDNA3)
+#define FP16_MMA_AVAILABLE
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && defined(RDNA3)

 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
 #define INT8_MMA_AVAILABLE

@@ -145,7 +148,7 @@ static constexpr bool fast_fp16_available(const int cc) {
 }

 static constexpr bool fp16_mma_available(const int cc) {
-    return cc < CC_OFFSET_AMD && cc >= CC_VOLTA;
+    return (cc < CC_OFFSET_AMD && cc >= CC_VOLTA) || cc >= CC_RDNA3;
 }

 static constexpr bool int8_mma_available(const int cc) {

@@ -242,8 +245,6 @@ static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b
 }

 static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const half2 b) {
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
-
 #if CUDART_VERSION >= CUDART_HMAX
     return __hmax2(a, b);
 #else

@@ -252,12 +253,6 @@ static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const hal
     reinterpret_cast<half&>(ret.y) = __float2half(fmaxf(__high2float(a), __high2float(b)));
     return ret;
 #endif // CUDART_VERSION >= CUDART_HMAX
-
-#else
-    GGML_UNUSED(a);
-    GGML_UNUSED(b);
-    NO_DEVICE_CODE;
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 }

 static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
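
Both the compile-time gate (FP16_MMA_AVAILABLE, per target architecture) and the runtime predicate (fp16_mma_available, per device compute capability) now admit RDNA3, and the HIP-only NO_DEVICE_CODE fallback in ggml_cuda_hmax2 is dropped so the real implementation compiles on AMD too. As a host-side sketch of the runtime predicate, with the constants filled in from ggml's usual scheme of encoding AMD architectures as CC_OFFSET_AMD plus the GFX version (the concrete values are assumptions for illustration):

    // Sketch of fp16_mma_available() after this commit; the constant values
    // are assumptions matching ggml's AMD offset-encoding convention.
    constexpr int CC_VOLTA      = 700;
    constexpr int CC_OFFSET_AMD = 1000000;
    constexpr int CC_RDNA3      = CC_OFFSET_AMD + 1100;   // e.g. gfx1100

    constexpr bool fp16_mma_available(const int cc) {
        // NVIDIA tensor cores from Volta on, or AMD WMMA from RDNA3 on.
        return (cc < CC_OFFSET_AMD && cc >= CC_VOLTA) || cc >= CC_RDNA3;
    }

    static_assert( fp16_mma_available(CC_VOLTA));              // Volta: yes
    static_assert(!fp16_mma_available(CC_OFFSET_AMD + 1030));  // RDNA2: no
    static_assert( fp16_mma_available(CC_RDNA3));              // RDNA3: yes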

ggml/src/ggml-cuda/fattn-wmma-f16.cuh

Lines changed: 22 additions & 16 deletions

@@ -2,7 +2,13 @@
 #include "fattn-common.cuh"

 #ifdef FP16_MMA_AVAILABLE
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#include <rocwmma/rocwmma.hpp>
+namespace wmma = ::rocwmma;
+#else
 #include <mma.h>
+namespace wmma = ::nvcuda::wmma;
+#endif
 #endif // FP16_MMA_AVAILABLE

 // D == head size, VKQ_stride == num VKQ rows calculated in parallel:

@@ -63,11 +69,11 @@ static __global__ void flash_attn_ext_f16(
     constexpr int frag_m = ncols == 8 ? 32 : 16;
     constexpr int frag_n = ncols == 8 ?  8 : 16;
     static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b,    frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t>                      frag_c_KQ;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half>                          frag_c_VKQ;
+    typedef wmma::fragment<wmma::matrix_a,    frag_m, frag_n, 16, half, wmma::row_major> frag_a_K;
+    typedef wmma::fragment<wmma::matrix_a,    frag_m, frag_n, 16, half, wmma::col_major> frag_a_V;
+    typedef wmma::fragment<wmma::matrix_b,    frag_m, frag_n, 16, half, wmma::col_major> frag_b;
+    typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t>              frag_c_KQ;
+    typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, half>                  frag_c_VKQ;

     constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel.
     constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.

@@ -157,7 +163,7 @@ static __global__ void flash_attn_ext_f16(
     for (int i0 = 0; i0 < D; i0 += 16) {
 #pragma unroll
         for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-            nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
+            wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
         }
     }

@@ -171,20 +177,20 @@ static __global__ void flash_attn_ext_f16(
         frag_c_KQ KQ_c[ncols/frag_n];
 #pragma unroll
         for (int j = 0; j < ncols/frag_n; ++j) {
-            nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
+            wmma::fill_fragment(KQ_c[j], KQ_acc_t(0.0f));
         }
 #pragma unroll
         for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
             frag_a_K K_a;
-            nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
+            wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
 #pragma unroll
             for (int j = 0; j < ncols/frag_n; ++j) {
-                nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
+                wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
             }
         }
 #pragma unroll
         for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-            nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major);
+            wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, wmma::mem_col_major);
         }
     }

@@ -303,7 +309,7 @@ static __global__ void flash_attn_ext_f16(
 #pragma unroll
         for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
             const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
-            nvcuda::wmma::load_matrix_sync(
+            wmma::load_matrix_sync(
                 KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
                 KQ + j0*(kqar*kqs_padded) + k,
                 kqar*kqs_padded);

@@ -315,18 +321,18 @@ static __global__ void flash_attn_ext_f16(
         for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
 #pragma unroll
             for (int j = 0; j < ncols/frag_n; ++j) {
-                nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
+                wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], half(0.0f));
             }

 #pragma unroll
             for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
                 const int k = k0 + (threadIdx.y % VKQ_ratio)*16;

                 frag_a_V v_a;
-                nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
+                wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
 #pragma unroll
                 for (int j = 0; j < ncols/frag_n; ++j) {
-                    nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
+                    wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
                 }
             }
         }

@@ -338,10 +344,10 @@ static __global__ void flash_attn_ext_f16(
         for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
 #pragma unroll
             for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-                nvcuda::wmma::store_matrix_sync(
+                wmma::store_matrix_sync(
                     KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
                     VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
-                    D_padded, nvcuda::wmma::mem_col_major);
+                    D_padded, wmma::mem_col_major);
             }
         }
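
Two details carry the port. First, a single namespace alias maps the kernel onto either backend, since rocWMMA deliberately mirrors the nvcuda::wmma vocabulary (fragment, fill_fragment, load_matrix_sync, mma_sync, store_matrix_sync). Second, the fill values gain explicit casts (KQ_acc_t(0.0f), half(0.0f)), presumably because rocWMMA's fill_fragment is stricter about matching the fragment's element type. A stripped-down sketch of the alias pattern, one 16x16x16 FP16 tile multiply that should compile under either nvcc or hipcc (kernel name and shapes are illustrative, not the commit's kernel):

    #if defined(__HIP_PLATFORM_AMD__)
    #include <hip/hip_fp16.h>
    #include <rocwmma/rocwmma.hpp>
    namespace wmma = ::rocwmma;
    #else
    #include <cuda_fp16.h>
    #include <mma.h>
    namespace wmma = ::nvcuda::wmma;
    #endif

    __global__ void mma_16x16x16(const half * a, const half * b, float * c) {
        wmma::fragment<wmma::matrix_a,    16, 16, 16, half, wmma::row_major> fa;
        wmma::fragment<wmma::matrix_b,    16, 16, 16, half, wmma::col_major> fb;
        wmma::fragment<wmma::accumulator, 16, 16, 16, float>                 fc;

        wmma::fill_fragment(fc, 0.0f);      // zero the accumulator tile
        wmma::load_matrix_sync(fa, a, 16);  // leading dimension = 16
        wmma::load_matrix_sync(fb, b, 16);
        wmma::mma_sync(fc, fa, fb, fc);     // fc += fa * fb
        wmma::store_matrix_sync(c, fc, 16, wmma::mem_row_major);
    }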

ggml/src/ggml-cuda/fattn.cu

Lines changed: 5 additions & 1 deletion

@@ -73,6 +73,7 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
         constexpr int cols_per_block = 8;
         switch (Q->ne[0]) {
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
             case 64:
                 ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
                 break;

@@ -85,6 +86,7 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst
             case 256:
                 ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
                 break;
+#endif
             default:
                 GGML_ABORT("fatal error");
                 break;

@@ -305,7 +307,9 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst

     // On AMD the tile kernels perform poorly, use the vec kernel instead:
     if (cc >= CC_OFFSET_AMD) {
-        if (precision == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
+        if (fp16_mma_available(cc) && (Q->ne[1] > 8 || Q->ne[0] % WARP_SIZE != 0)) {
+            ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
+        } else if (precision == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
         } else {
             ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
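
Note that the new shape guard is the exact complement of the ncols == 8 fast path in the first hunk: those head-size cases are compiled out on HIP, so RDNA3 only takes the WMMA kernel for shapes that use the 16x16 fragments. A condensed paraphrase of the resulting AMD dispatch (structure simplified for illustration; not the literal source):

    // AMD dispatch after this commit, condensed.
    if (cc >= CC_OFFSET_AMD) {
        // Complement of (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0):
        // avoids the cols_per_block == 8 path that is #if'd out on HIP.
        const bool wmma_ok_shape = Q->ne[1] > 8 || Q->ne[0] % WARP_SIZE != 0;
        if (fp16_mma_available(cc) && wmma_ok_shape) {
            ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);  // rocWMMA path (RDNA3)
        } else if (precision == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
            ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);   // FP16 vec kernel
        } else {
            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);   // FP32 vec kernel
        }
    }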

ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu

Lines changed: 4 additions & 4 deletions

@@ -2,7 +2,7 @@

 #include "../fattn-wmma-f16.cuh"

-DECL_FATTN_WMMA_F16_CASE(64, 8, half);
-DECL_FATTN_WMMA_F16_CASE(96, 8, half);
-DECL_FATTN_WMMA_F16_CASE(128, 8, half);
-DECL_FATTN_WMMA_F16_CASE(256, 8, half);
+//DECL_FATTN_WMMA_F16_CASE(64, 8, half);
+//DECL_FATTN_WMMA_F16_CASE(96, 8, half);
+//DECL_FATTN_WMMA_F16_CASE(128, 8, half);
+//DECL_FATTN_WMMA_F16_CASE(256, 8, half);
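
For context (my inference, not stated in the commit): these four instantiations are exactly the cols_per_block == 8 path, which selects 32x8x16 fragments in fattn-wmma-f16.cuh. That fragment shape exists in nvcuda::wmma but presumably has no rocWMMA counterpart on RDNA3, which would explain both disabling this file and the matching #if guard added around the switch cases in fattn.cu:

    // From fattn-wmma-f16.cuh: ncols == 8 selects the NVIDIA-only fragment shape.
    constexpr int frag_m = ncols == 8 ? 32 : 16;  // 32x8x16 assumed unsupported by rocWMMA
    constexpr int frag_n = ncols == 8 ?  8 : 16;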
