Commit a3c9d1d

HIP/WMMA: retune WMMA FlashAttention on RDNA3

- Increase block residency on HIP via __launch_bounds__ (min 2 blocks/SM)
- Adaptive KQ stride on HIP: 128 for D <= 128 to reduce LDS footprint
- Update loops and launch to use the adaptive stride; bump nwarps for small D
- No behavior change on CUDA; improves prefill perf on RDNA3
1 parent 1c1409e commit a3c9d1d
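For context on the two knobs named in the commit message, here is a minimal, self-contained CUDA sketch, not the fattn kernel from this diff: the second argument of __launch_bounds__ asks the compiler to budget registers for at least two resident blocks per SM/CU, and a constexpr stride pick shrinks the shared-memory (LDS) scratch for small head sizes so that the second block can actually fit. The names pick_kq_stride, toy_attention_tile and TOY_DEFAULT_STRIDE, the 32-wide warp, and the __HIP_PLATFORM_AMD__ branch are illustrative assumptions only; the real change is in the diff below.

// Minimal illustration only (NOT the ggml fattn kernel): how a __launch_bounds__
// residency hint and a compile-time stride choice interact.
#include <cstdio>
#include <cuda_runtime.h>

#define TOY_DEFAULT_STRIDE 256   // stand-in for FATTN_KQ_STRIDE (assumption)

// Smaller stride for small head sizes => smaller shared-memory (LDS) tile,
// which is what allows a second block to be resident per SM/CU.
template <int D>
constexpr int pick_kq_stride() {
#if defined(__HIP_PLATFORM_AMD__)       // illustrative HIP branch
    return D <= 128 ? 128 : TOY_DEFAULT_STRIDE;
#else
    return TOY_DEFAULT_STRIDE;
#endif
}

// The second __launch_bounds__ argument asks the compiler to cap register use
// so that at least 2 blocks can be resident per SM/CU (32-wide warp assumed).
template <int D, int nwarps>
__launch_bounds__(nwarps*32, 2)
static __global__ void toy_attention_tile(float * out) {
    constexpr int stride = pick_kq_stride<D>();
    __shared__ float kq[stride];          // scratch whose size tracks the stride

    for (int k = threadIdx.x; k < stride; k += blockDim.x) {
        kq[k] = static_cast<float>(k);    // dummy work standing in for the KQ tile
    }
    __syncthreads();

    if (threadIdx.x == 0) {
        out[blockIdx.x] = kq[stride - 1];
    }
}

int main() {
    constexpr int D = 128, nwarps = 4;
    float * out = nullptr;
    cudaMalloc(&out, sizeof(float));
    toy_attention_tile<D, nwarps><<<1, nwarps*32>>>(out);
    float host = 0.0f;
    cudaMemcpy(&host, out, sizeof(float), cudaMemcpyDeviceToHost);
    printf("stride for D=%d: %d, last element: %.0f\n", D, pick_kq_stride<D>(), host);
    cudaFree(out);
    return 0;
}

Built with nvcc this simply takes the non-HIP branch; the sketch only shows how the residency hint and the tile size interact with occupancy.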

File tree

1 file changed (+39 / -17 lines)

ggml/src/ggml-cuda/fattn-wmma-f16.cu

Lines changed: 39 additions & 17 deletions
@@ -20,9 +20,24 @@ namespace wmma = rocwmma;
 #endif // !defined(GGML_USE_HIP)
 #endif // GGML_USE_WMMA_FATTN
 
+#if defined(GGML_USE_HIP) && defined(GGML_HIP_ROCWMMA_FATTN)
+static constexpr int GGML_ROCWMMA_FATTN_MIN_BLOCKS_PER_SM = 2;
+#else
+static constexpr int GGML_ROCWMMA_FATTN_MIN_BLOCKS_PER_SM = 1;
+#endif
+
+template <int D>
+constexpr int ggml_wmma_fattn_kq_stride() {
+#if defined(GGML_USE_HIP) && defined(GGML_HIP_ROCWMMA_FATTN)
+    return D <= 128 ? 128 : FATTN_KQ_STRIDE;
+#else
+    return FATTN_KQ_STRIDE;
+#endif
+}
+
 // D == head size, VKQ_stride == num VKQ rows calculated in parallel:
 template<int D, int ncols, int nwarps, int VKQ_stride, typename KQ_acc_t, bool use_logit_softcap>
-__launch_bounds__(nwarps*ggml_cuda_get_physical_warp_size(), 1)
+__launch_bounds__(nwarps*ggml_cuda_get_physical_warp_size(), GGML_ROCWMMA_FATTN_MIN_BLOCKS_PER_SM)
 static __global__ void flash_attn_ext_f16(
         const char * __restrict__ Q,
         const char * __restrict__ K,
@@ -55,10 +70,12 @@ static __global__ void flash_attn_ext_f16(
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+    constexpr int fattn_kq_stride = ggml_wmma_fattn_kq_stride<D>();
+
+    static_assert(D <= fattn_kq_stride, "D must be <= fattn_kq_stride.");
 
     const int ic0 = ncols*blockIdx.x; // Index of the first Q/QKV column to work on.
 
-    static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
     static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
     constexpr int frag_m = ncols == 8 ? 32 : 16;
     constexpr int frag_n = ncols == 8 ? 8 : 16;
@@ -75,7 +92,7 @@ static __global__ void flash_attn_ext_f16(
 
     // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts:
     constexpr int D_padded = D + 8;
-    constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
+    constexpr int kqs_padded = fattn_kq_stride + 8;
     constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
 
     const int sequence = blockIdx.z / ne02;
@@ -168,10 +185,10 @@ static __global__ void flash_attn_ext_f16(
 
     // Iterate over ne11 == previous tokens:
     const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
-    for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE) {
+    for (int k_VKQ_0 = blockIdx.y*fattn_kq_stride; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*fattn_kq_stride) {
         // Calculate tile of KQ:
 #pragma unroll
-        for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
+        for (int i_KQ_0 = 0; i_KQ_0 < fattn_kq_stride; i_KQ_0 += KQ_stride_tc) {
             frag_c_KQ KQ_c[ncols/frag_n];
 #pragma unroll
             for (int j = 0; j < ncols/frag_n; ++j) {
@@ -201,9 +218,9 @@ static __global__ void flash_attn_ext_f16(
             const int j = j0 + threadIdx.y;
 
             if (std::is_same<KQ_acc_t, float>::value) {
-                float KQ_f_tmp[FATTN_KQ_STRIDE / warp_size];
+                float KQ_f_tmp[fattn_kq_stride / warp_size];
 #pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
+                for (int k0 = 0; k0 < fattn_kq_stride; k0 += warp_size) {
                     const int k = k0 + threadIdx.x;
 
                     KQ_f_tmp[k0/warp_size] = KQ_f[j*kqs_padded + k];
@@ -215,7 +232,7 @@ static __global__ void flash_attn_ext_f16(
 
                 float KQ_max_new = KQ_max_f[j0/nwarps];
 #pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
+                for (int k0 = 0; k0 < fattn_kq_stride; k0 += warp_size) {
                     const int k = k0 + threadIdx.x;
 
                     KQ_f_tmp[k0/warp_size] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
@@ -232,7 +249,7 @@ static __global__ void flash_attn_ext_f16(
 
                 float KQ_rowsum_add = 0.0f;
 #pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
+                for (int k0 = 0; k0 < fattn_kq_stride; k0 += warp_size) {
                     const int k = k0 + threadIdx.x;
 
                     const float diff = KQ_f_tmp[k0/warp_size] - KQ_max_f[j0/nwarps];
@@ -248,9 +265,9 @@ static __global__ void flash_attn_ext_f16(
                 // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
                 KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add;
             } else {
-                half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*warp_size)];
+                half2 KQ2_tmp[fattn_kq_stride/(2*warp_size)];
 #pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
+                for (int k0 = 0; k0 < fattn_kq_stride/2; k0 += warp_size) {
                     const int k = k0 + threadIdx.x;
 
                     KQ2_tmp[k0/warp_size] = KQ2[j*(kqs_padded/2) + k];
@@ -267,7 +284,7 @@ static __global__ void flash_attn_ext_f16(
 
                 half2 KQ_max_new = KQ_max_h2[j0/nwarps];
 #pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
+                for (int k0 = 0; k0 < fattn_kq_stride/2; k0 += warp_size) {
                     const int k = k0 + threadIdx.x;
 
                     KQ2_tmp[k0/warp_size] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
@@ -282,7 +299,7 @@ static __global__ void flash_attn_ext_f16(
 
                 half2 KQ_rowsum_add = make_half2(0.0f, 0.0f);
 #pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
+                for (int k0 = 0; k0 < fattn_kq_stride/2; k0 += warp_size) {
                     const int k = k0 + threadIdx.x;
 
                     const half2 diff = KQ2_tmp[k0/warp_size] - KQ_max_h2[j0/nwarps];
@@ -301,11 +318,11 @@ static __global__ void flash_attn_ext_f16(
 
         __syncthreads();
 
-        frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n];
+        frag_b KQ_b[fattn_kq_stride/(VKQ_ratio*16)][ncols/frag_n];
 #pragma unroll
         for (int j0 = 0; j0 < ncols; j0 += frag_n) {
 #pragma unroll
-            for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
+            for (int k0 = 0; k0 < fattn_kq_stride; k0 += VKQ_ratio*16) {
                 const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
                 wmma::load_matrix_sync(
                     KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
@@ -323,7 +340,7 @@ static __global__ void flash_attn_ext_f16(
         }
 
 #pragma unroll
-        for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
+        for (int k0 = 0; k0 < fattn_kq_stride; k0 += VKQ_ratio*16) {
            const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
 
            frag_a_V v_a;
@@ -512,7 +529,12 @@ template <int D, int cols_per_block, typename KQ_acc_t>
 void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * KQV = dst;
 
+#if defined(GGML_USE_HIP) && defined(GGML_HIP_ROCWMMA_FATTN)
+    constexpr int nwarps = D <= 96 ? 8 : 4;
+#else
     constexpr int nwarps = 4;
+#endif
+    constexpr int fattn_kq_stride = ggml_wmma_fattn_kq_stride<D>();
 
     constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
     const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
@@ -530,7 +552,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
         fattn_kernel = flash_attn_ext_f16<
             D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), KQ_acc_t, use_logit_softcap>;
     }
-    launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, 0, FATTN_KQ_STRIDE, true, true, false, warp_size);
+    launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, 0, fattn_kq_stride, true, true, false, warp_size);
 }
 
 void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {

0 commit comments
