@@ -10,10 +10,66 @@
 
 template <int D, int ncols2>
 static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * Q = dst->src[0];
+    const ggml_tensor * KQV = dst;
+    const ggml_tensor * Q   = dst->src[0];
+
+    const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
 
-    if (Q->ne[1] <= 8/ncols2) {
-        ggml_cuda_flash_attn_ext_mma_f16_case<D, 8/ncols2, ncols2>(ctx, dst);
+    if (prec != GGML_PREC_DEFAULT) {
+        if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
+            constexpr int cols_per_block = 16;
+            switch (Q->ne[0]) {
+                case 64:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
+                    break;
+                case 80:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
+                    break;
+                case 96:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
+                    break;
+                case 112:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
+                    break;
+                case 128:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
+                    break;
+                case 192:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case<192, cols_per_block, float>(ctx, dst);
+                    break;
+                case 256:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
+                    break;
+                default:
+                    GGML_ABORT("fatal error");
+                    break;
+            }
+        } else {
+            constexpr int cols_per_block = 32;
+            switch (Q->ne[0]) {
+                case 64:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
+                    break;
+                case 80:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
+                    break;
+                case 96:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
+                    break;
+                case 112:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
+                    break;
+                case 128:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
+                    break;
+                // case 256:
+                //     ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
+                //     break;
+                default:
+                    GGML_ABORT("fatal error");
+                    break;
+            }
+        }
         return;
     }
 
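For context, the dispatch rule added in this hunk reduces to a small decision: when non-default (FP32) precision is requested, the WMMA path is taken, and `cols_per_block` is 16 when the batch is small (`Q->ne[1] <= 32`) or the head size is large (`Q->ne[0] > 128`), and 32 otherwise. A minimal standalone sketch of that rule follows; the function and parameter names are hypothetical, introduced here only for illustration.

```cpp
// Minimal sketch of the precision-based dispatch rule above (not part of
// the commit; choose_cols_per_block and its parameter names are hypothetical).
#include <cstdint>

// head_size corresponds to Q->ne[0], n_q_cols to Q->ne[1].
static int choose_cols_per_block(int64_t head_size, int64_t n_q_cols) {
    // Small batches or heads wider than 128 use the 16-column kernels;
    // everything else uses the wider 32-column configuration.
    if (n_q_cols <= 32 || head_size > 128) {
        return 16;
    }
    return 32;
}
```

Note that in the 32-column branch the `case 256` is commented out; that branch is unreachable for head size 256 anyway, since `Q->ne[0] > 128` already routes such heads to the 16-column branch.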
|
|