Skip to content

Commit 02369da

Browse files
committed
Add rocWMMA support
1 parent 206d22b commit 02369da

File tree

3 files changed

+78
-4
lines changed

3 files changed

+78
-4
lines changed

ggml/src/ggml-cuda/common.cuh

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,10 @@ typedef float2 dfloat2;
196196
#define FP16_MMA_AVAILABLE
197197
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
198198

199+
#if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3))
200+
#define FP16_MMA_AVAILABLE
201+
#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3))
202+
199203
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
200204
#define NEW_MMA_AVAILABLE
201205
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
@@ -223,12 +227,14 @@ static bool fast_fp16_hardware_available(const int cc) {
223227

224228
// Any FP16 tensor core instructions are available for ggml code.
225229
static bool fp16_mma_available(const int cc) {
226-
return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA;
230+
return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ||
231+
cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1 || cc >= GGML_CUDA_CC_RDNA3;
227232
}
228233

229234
// To be used for feature selection of external libraries, e.g. cuBLAS.
230235
static bool fp16_mma_hardware_available(const int cc) {
231-
return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA;
236+
return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA ||
237+
cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1 || cc >= GGML_CUDA_CC_RDNA3;
232238
}
233239

234240
// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.

ggml/src/ggml-cuda/fattn-wmma-f16.cu

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,12 @@
77
#include "fattn-wmma-f16.cuh"
88

99
#ifdef FP16_MMA_AVAILABLE
10+
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
1011
#include <mma.h>
12+
#elif defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)
13+
#undef HIP_ENABLE_WARP_SYNC_BUILTINS // conflicts with rocWMMA headers
14+
#include <rocwmma/rocwmma.hpp>
15+
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
1116
#endif // FP16_MMA_AVAILABLE
1217

1318
// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
@@ -51,7 +56,7 @@ static __global__ void flash_attn_ext_f16(
5156
const int ne1,
5257
const int ne2,
5358
const int ne3) {
54-
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
59+
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE))
5560
// Skip unused kernel variants for faster compilation:
5661
if (use_logit_softcap && !(D == 128 || D == 256)) {
5762
NO_DEVICE_CODE;
@@ -68,11 +73,19 @@ static __global__ void flash_attn_ext_f16(
6873
constexpr int frag_m = ncols == 8 ? 32 : 16;
6974
constexpr int frag_n = ncols == 8 ? 8 : 16;
7075
static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
76+
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
7177
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K;
7278
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V;
7379
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
7480
typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
7581
typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
82+
#else
83+
typedef rocwmma::fragment<rocwmma::matrix_a, frag_m, frag_n, 16, half, rocwmma::row_major> frag_a_K;
84+
typedef rocwmma::fragment<rocwmma::matrix_a, frag_m, frag_n, 16, half, rocwmma::col_major> frag_a_V;
85+
typedef rocwmma::fragment<rocwmma::matrix_b, frag_m, frag_n, 16, half, rocwmma::col_major> frag_b;
86+
typedef rocwmma::fragment<rocwmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
87+
typedef rocwmma::fragment<rocwmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
88+
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
7689

7790
constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel.
7891
constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
@@ -162,7 +175,11 @@ static __global__ void flash_attn_ext_f16(
162175
for (int i0 = 0; i0 < D; i0 += 16) {
163176
#pragma unroll
164177
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
178+
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
165179
nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
180+
#else
181+
rocwmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
182+
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
166183
}
167184
}
168185

@@ -176,20 +193,36 @@ static __global__ void flash_attn_ext_f16(
176193
frag_c_KQ KQ_c[ncols/frag_n];
177194
#pragma unroll
178195
for (int j = 0; j < ncols/frag_n; ++j) {
196+
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
179197
nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
198+
#else
199+
rocwmma::fill_fragment(KQ_c[j], static_cast<KQ_acc_t>(0.0f));
200+
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
180201
}
181202
#pragma unroll
182203
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
183204
frag_a_K K_a;
205+
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
184206
nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
207+
#else
208+
rocwmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
209+
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
185210
#pragma unroll
186211
for (int j = 0; j < ncols/frag_n; ++j) {
212+
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
187213
nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
214+
#else
215+
rocwmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
216+
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
188217
}
189218
}
190219
#pragma unroll
191220
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
221+
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
192222
nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major);
223+
#else
224+
rocwmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, rocwmma::mem_col_major);
225+
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
193226
}
194227
}
195228

@@ -308,10 +341,17 @@ static __global__ void flash_attn_ext_f16(
308341
#pragma unroll
309342
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
310343
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
344+
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
311345
nvcuda::wmma::load_matrix_sync(
312346
KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
313347
KQ + j0*(kqar*kqs_padded) + k,
314348
kqar*kqs_padded);
349+
#else
350+
rocwmma::load_matrix_sync(
351+
KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
352+
KQ + j0*(kqar*kqs_padded) + k,
353+
kqar*kqs_padded);
354+
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
315355
}
316356
}
317357

@@ -320,18 +360,30 @@ static __global__ void flash_attn_ext_f16(
320360
for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
321361
#pragma unroll
322362
for (int j = 0; j < ncols/frag_n; ++j) {
363+
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
323364
nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
365+
#else
366+
rocwmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], static_cast<half>(0.0f));
367+
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
324368
}
325369

326370
#pragma unroll
327371
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
328372
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
329373

330374
frag_a_V v_a;
375+
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
331376
nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
377+
#else
378+
rocwmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
379+
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
332380
#pragma unroll
333381
for (int j = 0; j < ncols/frag_n; ++j) {
382+
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
334383
nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
384+
#else
385+
rocwmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
386+
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
335387
}
336388
}
337389
}
@@ -343,10 +395,17 @@ static __global__ void flash_attn_ext_f16(
343395
for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
344396
#pragma unroll
345397
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
398+
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
346399
nvcuda::wmma::store_matrix_sync(
347400
KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
348401
VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
349402
D_padded, nvcuda::wmma::mem_col_major);
403+
#else
404+
rocwmma::store_matrix_sync(
405+
KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
406+
VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
407+
D_padded, rocwmma::mem_col_major);
408+
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
350409
}
351410
}
352411

@@ -425,7 +484,7 @@ static __global__ void flash_attn_ext_f16(
425484
}
426485
#else
427486
NO_DEVICE_CODE;
428-
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
487+
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE))
429488
}
430489

431490
constexpr int get_max_power_of_2(int x) {
@@ -574,6 +633,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten
574633
if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
575634
constexpr int cols_per_block = 8;
576635
switch (Q->ne[0]) {
636+
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
577637
case 64:
578638
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
579639
break;
@@ -586,6 +646,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten
586646
case 256:
587647
ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
588648
break;
649+
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
589650
default:
590651
GGML_ABORT("fatal error");
591652
break;

ggml/src/ggml-cuda/fattn.cu

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,13 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
254254

255255
// On AMD the tile kernels perform poorly, use the vec kernel instead:
256256
if (cc >= GGML_CUDA_CC_OFFSET_AMD) {
257+
#if defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)
258+
if (fp16_mma_available(cc) && dst->src[0]->ne[1] > 8) {
259+
ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
260+
return;
261+
}
262+
#endif // defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)
263+
257264
if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
258265
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
259266
} else {

0 commit comments

Comments (0)