Commit dfdbd58

JohannesGaessler authored and qnixsynapse committed
CUDA: refactor FA support/selection code (ggml-org#15454)
1 parent 2b76bf5 commit dfdbd58

4 files changed: +161 -107 lines changed


ggml/src/ggml-cuda/fattn-common.cuh

Lines changed: 0 additions & 22 deletions
@@ -704,28 +704,6 @@ static __global__ void flash_attn_combine_results(
     dst[tid] = VKQ_numerator / VKQ_denominator;
 }
 
-[[noreturn]]
-static void on_no_fattn_vec_case(const int D) {
-    if (D == 64) {
-        fprintf(stderr, "Unsupported KV type combination for head_size 64.\n");
-        fprintf(stderr, "By default only f16 KV cache is supported.\n");
-        fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for V cache quantization support.\n");
-        GGML_ABORT("fatal error");
-    } else if (D == 128) {
-        fprintf(stderr, "Unsupported KV type combination for head_size 128.\n");
-        fprintf(stderr, "Supported combinations:\n");
-        fprintf(stderr, "  - K == q4_0, V == q4_0,  4.50 BPV\n");
-        fprintf(stderr, "  - K == q8_0, V == q8_0,  8.50 BPV\n");
-        fprintf(stderr, "  - K == f16,  V == f16,  16.00 BPV\n");
-        fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, q5_0, q5_1, q8_0, and f16.\n");
-        GGML_ABORT("fatal error");
-    } else {
-        fprintf(stderr, "Unsupported KV type combination for head_size %d.\n", D);
-        fprintf(stderr, "Only f16 is supported.\n");
-        GGML_ABORT("fatal error");
-    }
-}
-
 
 template <int DV, int ncols1, int ncols2>
 void launch_fattn(
         ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,

ggml/src/ggml-cuda/fattn.cu

Lines changed: 157 additions & 47 deletions
@@ -190,7 +190,7 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg
     FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
 #endif // GGML_CUDA_FA_ALL_QUANTS
 
-    on_no_fattn_vec_case(Q->ne[0]);
+    GGML_ABORT("fatal error");
 }
 
 #define FATTN_VEC_F32_CASE(D, type_K, type_V) \
@@ -265,74 +265,184 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
     FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
 #endif // GGML_CUDA_FA_ALL_QUANTS
 
-    on_no_fattn_vec_case(Q->ne[0]);
+    GGML_ABORT("fatal error");
 }
 
-void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+// Best FlashAttention kernel for a specific GPU:
+enum best_fattn_kernel {
+    BEST_FATTN_KERNEL_NONE     =   0,
+    BEST_FATTN_KERNEL_TILE_F32 = 200,
+    BEST_FATTN_KERNEL_TILE_F16 = 210,
+    BEST_FATTN_KERNEL_VEC_F32  = 100,
+    BEST_FATTN_KERNEL_VEC_F16  = 110,
+    BEST_FATTN_KERNEL_WMMA_F16 = 300,
+    BEST_FATTN_KERNEL_MMA_F16  = 400,
+};
+
+static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const ggml_tensor * dst) {
+#ifndef FLASH_ATTN_AVAILABLE
+    GGML_UNUSED(device); GGML_UNUSED(dst);
+    return BEST_FATTN_KERNEL_NONE;
+#endif// FLASH_ATTN_AVAILABLE
+
     const ggml_tensor * KQV  = dst;
     const ggml_tensor * Q    = dst->src[0];
     const ggml_tensor * K    = dst->src[1];
     const ggml_tensor * V    = dst->src[2];
     const ggml_tensor * mask = dst->src[3];
 
-    ggml_cuda_set_device(ctx.device);
-    const int cc        = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
-    const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
+    const int gqa_ratio = Q->ne[2] / K->ne[2];
+    GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
+
+    const int cc        = ggml_cuda_info().devices[device].cc;
+    const int warp_size = ggml_cuda_info().devices[device].warp_size;
     const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
 
-#if defined(GGML_HIP_ROCWMMA_FATTN)
-    if (GGML_CUDA_CC_IS_AMD(cc) && fp16_mma_available(cc)) {
-        ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
-        return;
+    switch (K->ne[0]) {
+        case  64:
+        case 128:
+        case 256:
+            if (V->ne[0] != K->ne[0]) {
+                return BEST_FATTN_KERNEL_NONE;
+            }
+            break;
+        case  80:
+        case  96:
+        case 112:
+            if (V->ne[0] != K->ne[0]) {
+                return BEST_FATTN_KERNEL_NONE;
+            }
+            if (!fp16_mma_available(cc) && !turing_mma_available(cc)) {
+                return BEST_FATTN_KERNEL_NONE;
+            }
+            break;
+        case 576:
+            if (V->ne[0] != 512) {
+                return BEST_FATTN_KERNEL_NONE;
+            }
+            if (!turing_mma_available(cc) || gqa_ratio % 16 != 0) {
+                return BEST_FATTN_KERNEL_NONE;
+            }
+            break;
+        default:
+            return BEST_FATTN_KERNEL_NONE;
     }
-#endif // defined(GGML_HIP_ROCWMMA_FATTN)
 
-    if (!fast_fp16_available(cc)) {
-        if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
-            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
-        } else {
-            ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
-        }
-        return;
+#ifndef GGML_CUDA_FA_ALL_QUANTS
+    if (K->type != V->type) {
+        return BEST_FATTN_KERNEL_NONE;
     }
+#endif // GGML_CUDA_FA_ALL_QUANTS
 
-    if (!fp16_mma_available(cc)) {
-        if (prec == GGML_PREC_DEFAULT) {
-            if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
-                ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
-            } else {
-                ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
+    switch (K->type) {
+        case GGML_TYPE_F16:
+            break;
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+#ifndef GGML_CUDA_FA_ALL_QUANTS
+            return BEST_FATTN_KERNEL_NONE;
+#endif // GGML_CUDA_FA_ALL_QUANTS
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q8_0:
+#ifdef GGML_CUDA_FA_ALL_QUANTS
+            if (K->ne[0] != 128 && K->ne[0] != 64) {
+                return BEST_FATTN_KERNEL_NONE;
             }
-        } else {
-            if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
-                ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
-            } else {
-                ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
+#else
+            if (K->ne[0] != 128) {
+                return BEST_FATTN_KERNEL_NONE;
             }
-        }
-        return;
+#endif // GGML_CUDA_FA_ALL_QUANTS
+            break;
+        default:
+            return BEST_FATTN_KERNEL_NONE;
+    }
+
+    switch (V->type) {
+        case GGML_TYPE_F16:
+            break;
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q8_0:
+            if (K->ne[0] != 128) {
+                return BEST_FATTN_KERNEL_NONE;
+            }
+            break;
+        default:
+            return BEST_FATTN_KERNEL_NONE;
+    }
+
+    if (mask && mask->ne[2] != 1) {
+        return BEST_FATTN_KERNEL_NONE;
     }
 
-    const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
-    const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
-    const bool mma_faster_for_rtx4000 = Q->ne[3] > 1 || (Q->ne[2] > 4*K->ne[2] && K->ne[1] >= 8192);
-    const bool mma_faster_for_bs1 = turing_mma_available(cc) && gqa_opt_applies && !mma_needs_data_conversion &&
-        (cc < GGML_CUDA_CC_ADA_LOVELACE || mma_faster_for_rtx4000);
     const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*warp_size) == 0;
-    if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
-        if (prec == GGML_PREC_DEFAULT) {
-            ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
-        } else {
-            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
+
+    // If Turing tensor cores available, use them except for some cases with batch size 1:
+    if (turing_mma_available(cc)) {
+        const bool gqa_opt_applies = gqa_ratio % 2 == 0 && mask; // The mma-based kernels have GQA-specific optimizations
+        const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
+        const bool mma_faster_for_rtx4000 = Q->ne[3] > 1 || (gqa_ratio > 4 && K->ne[1] >= 8192);
+        const bool mma_faster_for_bs1 = gqa_opt_applies && !mma_needs_data_conversion &&
+            (cc < GGML_CUDA_CC_ADA_LOVELACE || mma_faster_for_rtx4000);
+        if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
+            if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
+                return BEST_FATTN_KERNEL_VEC_F16;
+            }
+            return BEST_FATTN_KERNEL_VEC_F32;
         }
-        return;
+        return BEST_FATTN_KERNEL_MMA_F16;
     }
 
-    // The MMA implementation needs Turing or newer, use the old WMMA code for Volta:
-    if (fp16_mma_available(cc) && !turing_mma_available(cc)) {
-        ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
-        return;
+    // Use kernels specializes for small batch sizes if possible:
+    if (Q->ne[1] <= 8 && can_use_vector_kernel) {
+        if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
+            return BEST_FATTN_KERNEL_VEC_F16;
+        }
+        return BEST_FATTN_KERNEL_VEC_F32;
+    }
+
+    // For large batch sizes, use the WMMA kernel if possible:
+    if (fp16_mma_available(cc)) {
+        return BEST_FATTN_KERNEL_WMMA_F16;
+    }
+
+    // If there is no suitable kernel for tensor cores or small batch sizes, use the generic kernel for large batch sizes:
+    if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
+        return BEST_FATTN_KERNEL_TILE_F16;
     }
+    return BEST_FATTN_KERNEL_TILE_F32;
+}
+
+void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_set_device(ctx.device);
+    switch (ggml_cuda_get_best_fattn_kernel(ggml_cuda_get_device(), dst)) {
+        case BEST_FATTN_KERNEL_NONE:
+            GGML_ABORT("fatal error");
+        case BEST_FATTN_KERNEL_TILE_F32:
+            ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
+            break;
+        case BEST_FATTN_KERNEL_TILE_F16:
+            ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
+            break;
+        case BEST_FATTN_KERNEL_VEC_F32:
+            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
+            break;
+        case BEST_FATTN_KERNEL_VEC_F16:
+            ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
+            break;
+        case BEST_FATTN_KERNEL_WMMA_F16:
+            ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
+            break;
+        case BEST_FATTN_KERNEL_MMA_F16:
+            ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
+            break;
+    }
+}
 
-    ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
+bool ggml_cuda_flash_attn_ext_supported(int device, const ggml_tensor * dst) {
+    return ggml_cuda_get_best_fattn_kernel(device, dst) != BEST_FATTN_KERNEL_NONE;
 }

ggml/src/ggml-cuda/fattn.cuh

Lines changed: 2 additions & 0 deletions
@@ -1,3 +1,5 @@
 #include "common.cuh"
 
 void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+bool ggml_cuda_flash_attn_ext_supported(int device, const ggml_tensor * dst);

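For orientation, below is a minimal sketch of how a caller might combine the two entry points this header now exposes. The wrapper function is hypothetical (not part of this commit), but it mirrors the pattern adopted in ggml-cuda.cu further down: the support query and the dispatcher both go through the same ggml_cuda_get_best_fattn_kernel selection, so a node accepted by the first call cannot hit the abort path in the second.

// Hypothetical helper, illustration only (assumes a valid ggml_backend_cuda_context
// and a GGML_OP_FLASH_ATTN_EXT node `dst`):
#include "fattn.cuh"

static bool try_run_flash_attn(ggml_backend_cuda_context & ctx, int device, ggml_tensor * dst) {
    if (!ggml_cuda_flash_attn_ext_supported(device, dst)) {
        return false; // caller falls back to a non-FlashAttention path
    }
    ggml_cuda_flash_attn_ext(ctx, dst); // internally picks the best TILE/VEC/WMMA/MMA kernel
    return true;
}
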
ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 2 additions & 38 deletions
@@ -3499,44 +3499,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_GATED_LINEAR_ATTN:
         case GGML_OP_RWKV_WKV7:
             return true;
-        case GGML_OP_FLASH_ATTN_EXT: {
-#ifndef FLASH_ATTN_AVAILABLE
-            return false;
-#endif // FLASH_ATTN_AVAILABLE
-            if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
-                const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
-                if (!turing_mma_available(cc)) {
-                    return false;
-                }
-                const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2];
-                return op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512 && op->src[3] && gqa_ratio % 16 == 0;
-            }
-            // TODO: more general-purpose attention sink support [TAG_ATTN_SINKS]
-            if (op->src[4] && !fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc)
-                && op->src[0]->ne[0] != 64 && op->src[0]->ne[0] != 128) {
-                return false;
-            }
-            if (op->src[0]->ne[0] == 192) {
-                return false;
-            }
-            if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
-                return false;
-            }
-            if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) {
-                return true;
-            }
-            if (op->src[0]->ne[0] == 128) {
-                return true;
-            }
-            if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
-                return true;
-            }
-            if (op->src[3] && op->src[3]->ne[2] != 1) {
-                return false;
-            }
-            return fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc) &&
-                op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
-        }
+        case GGML_OP_FLASH_ATTN_EXT:
+            return ggml_cuda_flash_attn_ext_supported(dev_ctx->device, op);
         case GGML_OP_CROSS_ENTROPY_LOSS:
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
         case GGML_OP_OPT_STEP_ADAMW:
