
Commit 8b38ef8

lhlhjc4869 authored and committed
Massively Improved ROCm/HIP rocWMMA Performance (pp and tg) (ggml-org#16827)
* HIP/WMMA: retune WMMA FlashAttention on RDNA3

- Increase block residency on HIP via __launch_bounds__ (min 2 blocks/SM)
- Adaptive KQ stride on HIP: 128 for D<=128 to reduce LDS footprint
- Update loops and launch to use the adaptive stride; bump nwarps for small D
- No behavior change on CUDA; improves prefill perf on RDNA3

* HIP: use WMMA for prefill only; fix decode regression by enabling TILE and adding a safe fallback

- Do not select WMMA for decode on HIP; fall through to VEC/TILE
- Remove WMMA TILE pruning on HIP to avoid device traps; keep for CUDA WMMA
- Add decode-time guard: if predicted TILE split has no config, select VEC
- Remove ad-hoc env overrides and debug prints
1 parent 9b2601e commit 8b38ef8
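The first group of bullets can be illustrated in isolation. Below is a minimal, self-contained CUDA/HIP-style sketch, not the ggml code: `TOY_FATTN_KQ_STRIDE`, `pick_kq_stride` and `toy_attention_tile` are made-up names. It shows how a compile-time stride choice shrinks the shared-memory tile for small head sizes D, and how the second argument of `__launch_bounds__` asks the compiler to budget registers for at least two resident blocks per SM/CU (CUDA semantics; HIP treats the second argument as a related occupancy hint).

#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical default stride, standing in for FATTN_KQ_STRIDE.
constexpr int TOY_FATTN_KQ_STRIDE = 256;

// Compile-time stride pick: a smaller KV stride for small head sizes D,
// which shrinks the shared-memory (LDS) tile and helps residency.
template <int D>
constexpr int pick_kq_stride() {
    return D <= 128 ? 128 : TOY_FATTN_KQ_STRIDE;
}

// __launch_bounds__(max_threads_per_block, occupancy_hint):
// the second argument is the residency hint this commit raises to 2 on HIP.
template <int D, int nwarps>
__launch_bounds__(nwarps * 32, 2)
__global__ void toy_attention_tile(float * out) {
    constexpr int kq_stride = pick_kq_stride<D>();
    __shared__ float kq_tile[kq_stride];     // footprint scales with the stride

    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (threadIdx.x < kq_stride) {
        kq_tile[threadIdx.x] = float(threadIdx.x);
    }
    __syncthreads();
    out[tid] = kq_tile[threadIdx.x % kq_stride];
}

int main() {
    constexpr int D = 128, nwarps = 4, threads = nwarps * 32;
    float * out = nullptr;
    cudaMalloc(&out, 4 * threads * sizeof(float));
    toy_attention_tile<D, nwarps><<<4, threads>>>(out);
    cudaDeviceSynchronize();
    printf("stride for D=%d: %d\n", D, pick_kq_stride<D>());
    cudaFree(out);
    return 0;
}

The sketch builds with either nvcc or hipcc; only the stride selection and the occupancy hint are meant to mirror the change below.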

File tree

3 files changed: +98, -19 lines


ggml/src/ggml-cuda/fattn-tile.cuh

Lines changed: 5 additions & 1 deletion

@@ -747,9 +747,12 @@ static __global__ void flash_attn_tile(

     // Skip unused kernel variants for faster compilation:

+    // Optionally disable pruning to keep all TILE variants for testing.
+#if !defined(GGML_USE_HIP)
     if (
 #ifdef GGML_USE_WMMA_FATTN
-        (ncols2 != 1 && DV != 40 && DV != 72 && DV != 512) ||
+        // On CUDA WMMA builds, prune some TILE variants to reduce compile time/binary size.
+        (ncols2 != 1 && DV != 40 && DV != 64 && DV != 72 && DV != 128 && DV != 256 && DV != 512) ||
 #endif // GGML_USE_WMMA_FATTN
         (use_logit_softcap && !(DV == 128 || DV == 256))
     ) {
@@ -765,6 +768,7 @@ static __global__ void flash_attn_tile(
         NO_DEVICE_CODE;
         return;
     }
+#endif // !defined(GGML_USE_HIP)

     static_assert(ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols1*ncols2) != 0, "kernel config not defined");

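The guarded block above is an instance of the variant-pruning pattern used throughout these kernels: template instantiations that the host dispatcher never launches are reduced to empty stubs at compile time. A standalone sketch of the pattern with toy names (`toy_tile_kernel`, `TOY_NO_DEVICE_CODE`, `TOY_KEEP_ALL_VARIANTS` are hypothetical, not the ggml macros):

#include <cuda_runtime.h>

// Toy stand-in for the NO_DEVICE_CODE early-out used in ggml's kernels.
#define TOY_NO_DEVICE_CODE do { } while (0)

template <int DV, bool use_logit_softcap>
__global__ void toy_tile_kernel(float * dst) {
#if !defined(TOY_KEEP_ALL_VARIANTS)
    // Prune variants the host-side dispatcher never selects:
    // they compile to an empty stub instead of a full kernel body.
    if (use_logit_softcap && !(DV == 128 || DV == 256)) {
        TOY_NO_DEVICE_CODE;
        return;
    }
#endif
    dst[threadIdx.x] = float(DV);   // real work would go here
}

int main() {
    float * dst = nullptr;
    cudaMalloc(&dst, 64 * sizeof(float));
    toy_tile_kernel<128, true><<<1, 64>>>(dst);   // kept: DV == 128
    toy_tile_kernel< 80, true><<<1, 64>>>(dst);   // pruned to a stub
    cudaDeviceSynchronize();
    cudaFree(dst);
    return 0;
}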
ggml/src/ggml-cuda/fattn-wmma-f16.cu

Lines changed: 39 additions & 17 deletions

@@ -26,9 +26,24 @@ namespace wmma = rocwmma;
 #endif // !defined(GGML_USE_HIP)
 #endif // GGML_USE_WMMA_FATTN

+#if defined(GGML_USE_HIP) && defined(GGML_HIP_ROCWMMA_FATTN)
+static constexpr int GGML_ROCWMMA_FATTN_MIN_BLOCKS_PER_SM = 2;
+#else
+static constexpr int GGML_ROCWMMA_FATTN_MIN_BLOCKS_PER_SM = 1;
+#endif
+
+template <int D>
+constexpr int ggml_wmma_fattn_kq_stride() {
+#if defined(GGML_USE_HIP) && defined(GGML_HIP_ROCWMMA_FATTN)
+    return D <= 128 ? 128 : FATTN_KQ_STRIDE;
+#else
+    return FATTN_KQ_STRIDE;
+#endif
+}
+
 // D == head size, VKQ_stride == num VKQ rows calculated in parallel:
 template<int D, int ncols, int nwarps, int VKQ_stride, typename KQ_acc_t, bool use_logit_softcap>
-__launch_bounds__(nwarps*ggml_cuda_get_physical_warp_size(), 1)
+__launch_bounds__(nwarps*ggml_cuda_get_physical_warp_size(), GGML_ROCWMMA_FATTN_MIN_BLOCKS_PER_SM)
 static __global__ void flash_attn_ext_f16(
         const char * __restrict__ Q,
         const char * __restrict__ K,
@@ -61,10 +76,12 @@ static __global__ void flash_attn_ext_f16(
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.

     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+    constexpr int fattn_kq_stride = ggml_wmma_fattn_kq_stride<D>();
+
+    static_assert(D <= fattn_kq_stride, "D must be <= fattn_kq_stride.");

     const int ic0 = ncols*blockIdx.x; // Index of the first Q/QKV column to work on.

-    static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
     static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
     constexpr int frag_m = ncols == 8 ? 32 : 16;
     constexpr int frag_n = ncols == 8 ? 8 : 16;
@@ -81,7 +98,7 @@ static __global__ void flash_attn_ext_f16(

     // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts:
     constexpr int D_padded = D + 8;
-    constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
+    constexpr int kqs_padded = fattn_kq_stride + 8;
     constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);

     const int sequence = blockIdx.z / ne02;
@@ -174,10 +191,10 @@ static __global__ void flash_attn_ext_f16(

     // Iterate over ne11 == previous tokens:
     const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
-    for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE) {
+    for (int k_VKQ_0 = blockIdx.y*fattn_kq_stride; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*fattn_kq_stride) {
         // Calculate tile of KQ:
 #pragma unroll
-        for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
+        for (int i_KQ_0 = 0; i_KQ_0 < fattn_kq_stride; i_KQ_0 += KQ_stride_tc) {
             frag_c_KQ KQ_c[ncols/frag_n];
 #pragma unroll
             for (int j = 0; j < ncols/frag_n; ++j) {
@@ -207,9 +224,9 @@ static __global__ void flash_attn_ext_f16(
         const int j = j0 + threadIdx.y;

         if (std::is_same<KQ_acc_t, float>::value) {
-            float KQ_f_tmp[FATTN_KQ_STRIDE / warp_size];
+            float KQ_f_tmp[fattn_kq_stride / warp_size];
 #pragma unroll
-            for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
+            for (int k0 = 0; k0 < fattn_kq_stride; k0 += warp_size) {
                 const int k = k0 + threadIdx.x;

                 KQ_f_tmp[k0/warp_size] = KQ_f[j*kqs_padded + k];
@@ -221,7 +238,7 @@ static __global__ void flash_attn_ext_f16(

             float KQ_max_new = KQ_max_f[j0/nwarps];
 #pragma unroll
-            for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
+            for (int k0 = 0; k0 < fattn_kq_stride; k0 += warp_size) {
                 const int k = k0 + threadIdx.x;

                 KQ_f_tmp[k0/warp_size] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
@@ -238,7 +255,7 @@ static __global__ void flash_attn_ext_f16(

             float KQ_rowsum_add = 0.0f;
 #pragma unroll
-            for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
+            for (int k0 = 0; k0 < fattn_kq_stride; k0 += warp_size) {
                 const int k = k0 + threadIdx.x;

                 const float diff = KQ_f_tmp[k0/warp_size] - KQ_max_f[j0/nwarps];
@@ -254,9 +271,9 @@ static __global__ void flash_attn_ext_f16(
             // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
             KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add;
         } else {
-            half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*warp_size)];
+            half2 KQ2_tmp[fattn_kq_stride/(2*warp_size)];
 #pragma unroll
-            for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
+            for (int k0 = 0; k0 < fattn_kq_stride/2; k0 += warp_size) {
                 const int k = k0 + threadIdx.x;

                 KQ2_tmp[k0/warp_size] = KQ2[j*(kqs_padded/2) + k];
@@ -273,7 +290,7 @@ static __global__ void flash_attn_ext_f16(

             half2 KQ_max_new = KQ_max_h2[j0/nwarps];
 #pragma unroll
-            for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
+            for (int k0 = 0; k0 < fattn_kq_stride/2; k0 += warp_size) {
                 const int k = k0 + threadIdx.x;

                 KQ2_tmp[k0/warp_size] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
@@ -288,7 +305,7 @@ static __global__ void flash_attn_ext_f16(

             half2 KQ_rowsum_add = make_half2(0.0f, 0.0f);
 #pragma unroll
-            for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
+            for (int k0 = 0; k0 < fattn_kq_stride/2; k0 += warp_size) {
                 const int k = k0 + threadIdx.x;

                 const half2 diff = KQ2_tmp[k0/warp_size] - KQ_max_h2[j0/nwarps];
@@ -307,11 +324,11 @@ static __global__ void flash_attn_ext_f16(

         __syncthreads();

-        frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n];
+        frag_b KQ_b[fattn_kq_stride/(VKQ_ratio*16)][ncols/frag_n];
 #pragma unroll
         for (int j0 = 0; j0 < ncols; j0 += frag_n) {
 #pragma unroll
-            for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
+            for (int k0 = 0; k0 < fattn_kq_stride; k0 += VKQ_ratio*16) {
                 const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
                 wmma::load_matrix_sync(
                     KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
@@ -329,7 +346,7 @@ static __global__ void flash_attn_ext_f16(
         }

 #pragma unroll
-        for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
+        for (int k0 = 0; k0 < fattn_kq_stride; k0 += VKQ_ratio*16) {
            const int k = k0 + (threadIdx.y % VKQ_ratio)*16;

            frag_a_V v_a;
@@ -518,7 +535,12 @@ template <int D, int cols_per_block, typename KQ_acc_t>
 void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * KQV = dst;

+#if defined(GGML_USE_HIP) && defined(GGML_HIP_ROCWMMA_FATTN)
+    constexpr int nwarps = D <= 96 ? 8 : 4;
+#else
     constexpr int nwarps = 4;
+#endif
+    constexpr int fattn_kq_stride = ggml_wmma_fattn_kq_stride<D>();

     constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
     const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
@@ -536,7 +558,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
         fattn_kernel = flash_attn_ext_f16<
             D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), KQ_acc_t, use_logit_softcap>;
     }
-    launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, 0, FATTN_KQ_STRIDE, true, true, false, warp_size);
+    launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, 0, fattn_kq_stride, true, true, false, warp_size);
 }

 void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
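Besides the adaptive stride, the HIP path of the launcher above bumps nwarps to 8 for head sizes up to 96. A small host-only sketch of how that choice translates into threads per block; `toy_nwarps` is a toy copy of the mapping, and the 32-wide wave size is an assumption for RDNA3 (CDNA parts use 64-wide waves):

#include <cstdio>
#include <initializer_list>

// Toy copy of the HIP-side nwarps choice in ggml_cuda_flash_attn_ext_wmma_f16_case:
// more warps per block for small head sizes, fewer for large ones.
constexpr int toy_nwarps(int D) {
    return D <= 96 ? 8 : 4;
}

int main() {
    const int warp_size = 32;   // assumption: RDNA3 wave size
    for (int D : {64, 80, 96, 128, 256}) {
        printf("D=%3d -> nwarps=%d -> %3d threads per block\n",
               D, toy_nwarps(D), toy_nwarps(D) * warp_size);
    }
    return 0;
}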

ggml/src/ggml-cuda/fattn.cu

Lines changed: 54 additions & 1 deletion

@@ -302,13 +302,66 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
     }

     // Use the WMMA kernel if possible:
-    if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 576) {
+#if defined(GGML_USE_HIP) && defined(GGML_HIP_ROCWMMA_FATTN)
+    const bool hip_wmma_decode = Q->ne[1] == 1;
+#else
+    const bool hip_wmma_decode = false;
+#endif
+    if (!hip_wmma_decode && ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 576) {
         if (can_use_vector_kernel && Q->ne[1] <= 2) {
             return BEST_FATTN_KERNEL_VEC;
         }
         return BEST_FATTN_KERNEL_WMMA_F16;
     }

+    // HIP decode path (Q->ne[1] == 1): fall through to generic HIP selection below (VEC/TILE),
+    // with a guard to avoid selecting a TILE shape that has no config.
+    if (hip_wmma_decode) {
+#if defined(GGML_USE_HIP) && defined(GGML_HIP_ROCWMMA_FATTN)
+        // Mirror the ncols2 selection from launch_fattn_tile_switch_ncols2 to predict if
+        // a multi-column TILE kernel (ncols2 != 1) would be chosen.
+        const bool nvidia_arch = GGML_CUDA_CC_IS_NVIDIA(cc);
+        const int gqa_limit = (nvidia_arch && gqa_ratio <= 4) ? 16 : INT_MAX;
+        const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0;
+
+        int predicted_ncols2 = 1;
+        if (V->ne[0] == 512) {
+            if (use_gqa_opt && gqa_ratio % 16 == 0) predicted_ncols2 = 16;
+        } else if (V->ne[0] <= 256) {
+            if (use_gqa_opt && gqa_ratio % 8 == 0) predicted_ncols2 = 8;
+            else if (use_gqa_opt && gqa_ratio % 4 == 0) predicted_ncols2 = 4;
+            else if (use_gqa_opt && gqa_ratio % 2 == 0) predicted_ncols2 = 2;
+        }
+
+        // Predict cols_per_block like launch_fattn_tile_switch_ncols1 does (HIP path):
+        int predicted_cols_per_block = 2;
+        if (predicted_ncols2 <= 2) {
+            predicted_cols_per_block = 2;
+        }
+        if (predicted_ncols2 <= 4 && Q->ne[1] > 2/predicted_ncols2) {
+            predicted_cols_per_block = 4;
+        }
+        if (predicted_ncols2 <= 8 && Q->ne[1] > 4/predicted_ncols2) {
+            predicted_cols_per_block = 8;
+        }
+        if (Q->ne[1] > 8/predicted_ncols2) {
+            predicted_cols_per_block = 16;
+        }
+        if (Q->ne[1] > 16/predicted_ncols2) {
+            predicted_cols_per_block = 32;
+        }
+        if (V->ne[0] <= 128 && Q->ne[1] > 32/predicted_ncols2) {
+            predicted_cols_per_block = 64;
+        }
+
+        const uint32_t cfg = ggml_cuda_fattn_tile_get_config((int)Q->ne[0], (int)V->ne[0], predicted_cols_per_block, cc);
+        if (predicted_ncols2 != 1 && cfg == 0) {
+            return BEST_FATTN_KERNEL_VEC;
+        }
+#endif // defined(GGML_USE_HIP) && defined(GGML_HIP_ROCWMMA_FATTN)
+        // Otherwise, fall through.
+    }
+
     // If there are no tensor cores available, use the generic tile kernel:
     if (can_use_vector_kernel) {
         if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
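The guard above re-derives, on the host, the ncols2 and cols_per_block that the TILE launcher would pick and falls back to VEC whenever that multi-column split has no registered config. A condensed host-side sketch of the same decision: `toy_decode_tile_is_safe` and `has_tile_config` are hypothetical names, and the NVIDIA-specific gqa_limit check is folded into the `use_gqa_opt` flag.

#include <cstdio>
#include <functional>

// Condensed model of the decode-time guard added to ggml_cuda_get_best_fattn_kernel.
// Returns true if the TILE path is safe to take, false if the caller should pick VEC.
bool toy_decode_tile_is_safe(int n_q_cols,       // Q->ne[1] (1 during decode)
                             int head_size_v,    // V->ne[0]
                             int gqa_ratio,
                             bool use_gqa_opt,   // mask present, max_bias == 0, KV aligned
                             const std::function<bool(int /*cols_per_block*/)> & has_tile_config) {
    // Predict ncols2 the way launch_fattn_tile_switch_ncols2 would.
    int ncols2 = 1;
    if (head_size_v == 512) {
        if (use_gqa_opt && gqa_ratio % 16 == 0) ncols2 = 16;
    } else if (head_size_v <= 256) {
        if      (use_gqa_opt && gqa_ratio % 8 == 0) ncols2 = 8;
        else if (use_gqa_opt && gqa_ratio % 4 == 0) ncols2 = 4;
        else if (use_gqa_opt && gqa_ratio % 2 == 0) ncols2 = 2;
    }

    // Predict cols_per_block the way launch_fattn_tile_switch_ncols1 would (HIP path).
    int cols_per_block = 2;
    if (ncols2 <= 4 && n_q_cols >  2/ncols2) cols_per_block = 4;
    if (ncols2 <= 8 && n_q_cols >  4/ncols2) cols_per_block = 8;
    if (n_q_cols >  8/ncols2)                cols_per_block = 16;
    if (n_q_cols > 16/ncols2)                cols_per_block = 32;
    if (head_size_v <= 128 && n_q_cols > 32/ncols2) cols_per_block = 64;

    // A multi-column TILE split without a registered config would trap on device: use VEC instead.
    return ncols2 == 1 || has_tile_config(cols_per_block);
}

int main() {
    // Pretend only cols_per_block == 2 has a config for this (DKQ, DV) pair.
    auto has_cfg = [](int cols_per_block) { return cols_per_block == 2; };
    printf("gqa_ratio=8: %s\n", toy_decode_tile_is_safe(1, 128, 8, true, has_cfg) ? "TILE" : "VEC");
    printf("gqa_ratio=3: %s\n", toy_decode_tile_is_safe(1, 128, 3, true, has_cfg) ? "TILE" : "VEC");
    return 0;
}

With a GQA ratio of 8 the predicted multi-column split has no config, so the sketch reports VEC; with a ratio of 3 the split degenerates to ncols2 == 1 and TILE remains safe.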
