
Commit a313b71

ikawrakow and Iwan Kawrakow authored
DeepSeek FA optimizations (#929)
* Use new-new-mma also for MLA=3, and use mask bounds

  This gives us ~25% better PP at 32k tokens compared to main

* This seems better

---------

Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent 7747000 commit a313b71

File tree

2 files changed (+214, -43 lines)


ggml/src/ggml-cuda/fattn-new-mma.cu

Lines changed: 211 additions & 40 deletions
@@ -14,6 +14,46 @@
 
 using namespace ggml_cuda_mma;
 
+typedef void (* fattn_new_mma_kernel_t)(
+        const char * __restrict__ Q,
+        const char * __restrict__ K,
+        const char * __restrict__ V,
+        const char * __restrict__ mask,
+        const char * __restrict__ sinks,
+        const int * __restrict__ KV_max,
+        float * __restrict__ dst,
+        float2 * __restrict__ dst_meta,
+        const float scale,
+        const float max_bias,
+        const float m0,
+        const float m1,
+        const float softcap,
+        const uint32_t n_head_log2,
+        const int ne00,
+        const int ne01,
+        const int ne02,
+        const int ne03,
+        const int ne10,
+        const int ne11,
+        const int ne12,
+        const int ne13,
+        const int ne31,
+        const int nb31,
+        const int nb01,
+        const int nb02,
+        const int nb03,
+        const int nb11,
+        const int nb12,
+        const int nb13,
+        const int nb21,
+        const int nb22,
+        const int nb23,
+        const int ne0,
+        const int ne1,
+        const int ne2,
+        const int ne3);
+
+
 typedef tile<16, 8, half2> tile_A;
 typedef tile< 8, 8, half2> tile_B;
 typedef tile<16, 8, half2> tile_B_16;
@@ -43,37 +83,37 @@ struct fattn_mma_f16_config;
 // Perhaps the 256 head size needs a closer look
 // to see if this implementation is better.
 //
-template <>
-struct fattn_mma_f16_config< 64, 64> {
-    static constexpr int nbatch_fa = 64;
-    static constexpr int nwarps_max = 4;
-    static constexpr bool Q_in_reg = true;
-    static constexpr int nstages_target = 2;
-
-    static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
-        return 32;
-    }
-
-    static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
-        return 32;
-    }
-
-    static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
-        return 32;
-    }
-
-    static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
-        return 32;
-    }
-
-    static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
-        return 32;
-    }
-
-    static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
-        return 32;
-    }
-};
+//template <>
+//struct fattn_mma_f16_config< 64, 64> {
+//    static constexpr int nbatch_fa = 64;
+//    static constexpr int nwarps_max = 4;
+//    static constexpr bool Q_in_reg = true;
+//    static constexpr int nstages_target = 2;
+//
+//    static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
+//        return 32;
+//    }
+//
+//    static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
+//        return 32;
+//    }
+//
+//    static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
+//        return 32;
+//    }
+//
+//    static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
+//        return 32;
+//    }
+//
+//    static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
+//        return 32;
+//    }
+//
+//    static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
+//        return 32;
+//    }
+//};
 //
 //template <>
 //struct fattn_mma_f16_config< 80, 80> {
@@ -243,6 +283,38 @@ struct fattn_mma_f16_config< 64, 64> {
 //    }
 //};
 
+template <>
+struct fattn_mma_f16_config<192, 128> {
+    static constexpr int nbatch_fa = 64;
+    static constexpr int nwarps_max = 4;
+    static constexpr bool Q_in_reg = true;
+    static constexpr int nstages_target = 1;
+
+    static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
+        return 64;
+    }
+
+    static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
+        return 64;
+    }
+
+    static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
+        return 64;
+    }
+
+    static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
+        return 64;
+    }
+
+    static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
+        return 64;
+    }
+
+    static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
+        return 64;
+    }
+};
+
 template <>
 struct fattn_mma_f16_config<576, 512> {
     static constexpr int nbatch_fa = 32;
@@ -1287,6 +1359,7 @@ static __global__ void flash_attn_ext_f16(
         const char * __restrict__ V,
         const char * __restrict__ mask,
         const char * __restrict__ sinks,
+        const int * __restrict__ KV_max,
         float * __restrict__ dst,
         float2 * __restrict__ dst_meta,
         const float scale,
@@ -1377,8 +1450,11 @@ static __global__ void flash_attn_ext_f16(
 
         const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
 
-        const int kb0_start_kernel = kb0_start * kb_niter;
-        const int kb0_stop_kernel  = kb0_stop  * kb_niter;
+        int kb0_start_kernel = kb0_start * kb_niter;
+        int kb0_stop_kernel  = kb0_stop  * kb_niter;
+        if (KV_max) {
+            kb0_stop_kernel = min(kb0_stop_kernel, KV_max[jt] / c::nbatch_fa);
+        }
 
         constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
         if (kb0_start == 0) {
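The clamp above is the heart of the mask-bounds optimization: KV_max[jt] holds, for query tile jt, an upper bound (a multiple of FATTN_KQ_STRIDE) on the KV positions that are not fully masked out, so iterations beyond it are skipped entirely. A minimal, hypothetical host-side sketch of the same arithmetic, with kv_max and nbatch_fa standing in for KV_max[jt] and c::nbatch_fa:

```cpp
// Hypothetical sketch of the clamping arithmetic, not code from the patch.
#include <algorithm>
#include <cstdio>

int clamped_stop(int kb0_stop_kernel, int kv_max, int nbatch_fa) {
    // kv_max is the per-tile bound (in tokens) computed from the mask; nbatch_fa is
    // the number of KV positions consumed per kernel iteration.
    return std::min(kb0_stop_kernel, kv_max / nbatch_fa);
}

int main() {
    // Example: a 32k-token KV cache with nbatch_fa = 64 means 512 iterations without
    // the bound. If the mask bound for this query tile is 8192 tokens, only 128 run.
    printf("%d\n", clamped_stop(32768 / 64, 8192, 64)); // prints 128
    return 0;
}
```

For a long causal prompt, early query tiles stop after a small fraction of the KV blocks, which is consistent with the ~25% PP improvement at 32k tokens reported in the commit message.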
@@ -1417,8 +1493,11 @@ static __global__ void flash_attn_ext_f16(
 
     const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
 
-    const int kb0_start_kernel = kb0_start * kb_niter;
-    const int kb0_stop_kernel  = kb0_stop  * kb_niter;
+    int kb0_start_kernel = kb0_start * kb_niter;
+    int kb0_stop_kernel  = kb0_stop  * kb_niter;
+    if (KV_max) {
+        kb0_stop_kernel = min(kb0_stop_kernel, KV_max[jt] / c::nbatch_fa);
+    }
 
     constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
     constexpr bool needs_fixup = false;
@@ -1574,9 +1653,68 @@ static __global__ void flash_attn_combine_results_new(
     dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator;
 }
 
+template<int width = WARP_SIZE>
+static __device__ __forceinline__ int warp_reduce_all(int x) {
+    if constexpr (width == WARP_SIZE) { //ggml_cuda_get_physical_warp_size()) {
+        return __all_sync(0xffffffff, x);
+    } else {
+#pragma unroll
+        for (int offset = width/2; offset > 0; offset >>= 1) {
+            x = __shfl_xor_sync(0xffffffff, x, offset, width) && x;
+        }
+        return x;
+    }
+}
+
+template <int ncols1>
+__launch_bounds__(FATTN_KQ_STRIDE/2, 1)
+static __global__ void flash_attn_mask_to_KV_max(
+        const half2 * __restrict__ mask, int * __restrict__ KV_min_max, const int ne30, const int s31, const int s33) {
+    const int ne31     = gridDim.x;
+    const int tid      = threadIdx.x;
+    const int sequence = blockIdx.y;
+    const int jt       = blockIdx.x;
+
+    mask += sequence*s33 + jt*ncols1*s31;
+
+    __shared__ int buf_iw[WARP_SIZE];
+    if (tid < WARP_SIZE) {
+        buf_iw[tid] = 1;
+    }
+    __syncthreads();
+
+    int KV_max_sj = (ne30 - 1) * FATTN_KQ_STRIDE;
+    for (; KV_max_sj >= 0; KV_max_sj -= FATTN_KQ_STRIDE) {
+        int all_inf = 1;
+
+#pragma unroll
+        for (int j = 0; j < ncols1; ++j) {
+            const float2 tmp = __half22float2(mask[j*s31 + KV_max_sj/2 + tid]);
+            all_inf = all_inf && int(isinf(tmp.x)) && int(isinf(tmp.y));
+        }
+
+        all_inf = warp_reduce_all(all_inf);
+        if (tid % WARP_SIZE == 0) {
+            buf_iw[tid / WARP_SIZE] = all_inf;
+        }
+        __syncthreads();
+        all_inf = buf_iw[tid % WARP_SIZE];
+        __syncthreads();
+        all_inf = warp_reduce_all(all_inf);
+
+        if (!all_inf) {
+            break;
+        }
+    }
+
+    if (threadIdx.x == 0) {
+        KV_min_max[sequence*ne31 + jt] = KV_max_sj + FATTN_KQ_STRIDE;
+    }
+}
+
 template <int DV, int ncols1, int ncols2>
 static void launch_fattn_new_mma(
-        ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
+        ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_new_mma_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
         const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
 ) {
     constexpr int ncols = ncols1 * ncols2;
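flash_attn_mask_to_KV_max computes those bounds: one thread block per (query tile, sequence) pair scans the mask backwards in FATTN_KQ_STRIDE-wide KV blocks and records the end of the first block that still contains a finite entry. Below is a simplified single-threaded CPU sketch of that logic, written as an illustration only: it assumes a plain float mask with -INFINITY for masked positions, whereas the kernel reads half2 pairs and combines the per-thread all-inf flags with warp_reduce_all and shared memory.

```cpp
// Simplified CPU reference of the mask scan; an illustrative sketch, not the kernel.
#include <cmath>
#include <cstdio>
#include <vector>

constexpr int FATTN_KQ_STRIDE_REF = 256; // plays the role of FATTN_KQ_STRIDE

// mask is [n_q_rows][n_kv]; -INFINITY marks KV positions a query row must not attend to.
// Returns the iteration bound stored in KV_max for the tile starting at row_start.
// Assumes n_kv is a multiple of FATTN_KQ_STRIDE_REF, matching the launch-side check.
static int kv_max_for_tile(const std::vector<std::vector<float>> & mask,
                           int row_start, int ncols1, int n_kv) {
    for (int sj = (n_kv/FATTN_KQ_STRIDE_REF - 1)*FATTN_KQ_STRIDE_REF; sj >= 0; sj -= FATTN_KQ_STRIDE_REF) {
        bool all_inf = true;
        for (int j = 0; j < ncols1 && all_inf; ++j) {
            for (int k = 0; k < FATTN_KQ_STRIDE_REF && all_inf; ++k) {
                all_inf = all_inf && std::isinf(mask[row_start + j][sj + k]);
            }
        }
        if (!all_inf) {
            return sj + FATTN_KQ_STRIDE_REF; // end of the last block with a finite entry
        }
    }
    return 0; // the whole tile is masked out
}

int main() {
    // Toy mask: 2 query rows, 512 KV positions, everything past position 200 masked out.
    std::vector<std::vector<float>> mask(2, std::vector<float>(512, 0.0f));
    for (int j = 0; j < 2; ++j) {
        for (int k = 200; k < 512; ++k) mask[j][k] = -INFINITY;
    }
    printf("KV_max = %d\n", kv_max_for_tile(mask, 0, 2, 512)); // prints 256
    return 0;
}
```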
@@ -1605,10 +1743,15 @@ static void launch_fattn_new_mma(
     cudaStream_t main_stream = ctx.stream();
     const int id = ggml_cuda_get_device();
     const int cc = ggml_cuda_info().devices[id].cc;
-    const int nsm = ggml_cuda_info().devices[id].nsm;
+    const int nsm_actual = ggml_cuda_info().devices[id].nsm;
+    int nsm = 1; while (nsm*2 <= nsm_actual) nsm *= 2;
+
+    if (Q->ne[1] == 1 && K->ne[1] <= 4096 && nsm > 32) nsm /= 2;
+    if (Q->ne[1] >= 32 && K->ne[1] >= 4096) nsm *= 2;
 
     ggml_cuda_pool_alloc<half> K_f16(pool);
     ggml_cuda_pool_alloc<half> V_f16(pool);
+    ggml_cuda_pool_alloc<int> KV_max(pool);
     ggml_cuda_pool_alloc<float> dst_tmp(pool);
     ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
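The launcher now bases its grid sizing on the physical SM count rounded down to a power of two, halved for single-token queries against short KV caches on large GPUs and doubled for large batches against long caches. A small standalone illustration of that heuristic (the SM counts passed in main are invented examples, not measured values):

```cpp
// Standalone illustration of the SM-count heuristic above; thresholds copied from the diff.
#include <cstdio>

int effective_nsm(int nsm_actual, int n_q, int n_kv) {
    int nsm = 1;
    while (nsm * 2 <= nsm_actual) nsm *= 2;               // largest power of two <= nsm_actual

    if (n_q == 1  && n_kv <= 4096 && nsm > 32) nsm /= 2;  // single-token query, short KV cache
    if (n_q >= 32 && n_kv >= 4096)             nsm *= 2;  // large batch, long KV cache
    return nsm;
}

int main() {
    printf("%d\n", effective_nsm(132, 1,   2048));   // 132 -> 128 -> 64
    printf("%d\n", effective_nsm(132, 512, 32768));  // 132 -> 128 -> 256
    return 0;
}
```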

@@ -1675,6 +1818,25 @@ static void launch_fattn_new_mma(
     const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
     const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
 
+    // Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.
+    // Only worth the overhead if there is at least one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
+    // multiple sequences of possibly different lengths.
+    if (mask && K->ne[1] % FATTN_KQ_STRIDE == 0 && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {
+        const int s31 = mask->nb[1] / sizeof(half2);
+        const int s33 = mask->nb[3] / sizeof(half2);
+
+        const dim3 blocks_num_KV_max(ntiles_x, Q->ne[3], 1);
+        const dim3 block_dim_KV_max(FATTN_KQ_STRIDE/2, 1, 1);
+
+        const int ne_KV_max = blocks_num_KV_max.x*blocks_num_KV_max.y;
+        const int iter_k = K->ne[1] / FATTN_KQ_STRIDE;
+
+        KV_max.alloc(ne_KV_max);
+        flash_attn_mask_to_KV_max<ncols1><<<blocks_num_KV_max, block_dim_KV_max, 0, main_stream>>>
+            ((const half2 *) mask->data, KV_max.ptr, iter_k, s31, s33);
+        CUDA_CHECK(cudaGetLastError());
+    }
+
     const dim3 block_dim(warp_size, nwarps, 1);
     int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
     CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
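To make the launch geometry concrete, here is a worked example under assumed shapes (ncols1 = 64, a 32k-token prompt, a 32k-entry KV cache, one sequence); the scan kernel writes one KV_max entry per (query tile, sequence) pair and inspects iter_k KV blocks per tile:

```cpp
// Worked example with assumed sizes; ncols1 and the tensor shapes are not from the patch.
#include <cstdio>

int main() {
    const int FATTN_KQ_STRIDE = 256;
    const int ncols1 = 64;     // query columns per tile (assumed)
    const int n_q    = 32768;  // Q->ne[1]: prompt tokens (assumed)
    const int n_kv   = 32768;  // K->ne[1]: KV cache length (assumed)
    const int n_seq  = 1;      // Q->ne[3]

    const int ntiles_x  = (n_q + ncols1 - 1) / ncols1;   // grid.x = 512
    const int ne_KV_max = ntiles_x * n_seq;              // 512 KV_max entries
    const int iter_k    = n_kv / FATTN_KQ_STRIDE;        // 128 KV blocks scanned per tile

    printf("grid = (%d, %d), block = %d threads, KV_max entries = %d, KV blocks per tile = %d\n",
           ntiles_x, n_seq, FATTN_KQ_STRIDE / 2, ne_KV_max, iter_k);
    return 0;
}
```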
@@ -1761,6 +1923,7 @@ static void launch_fattn_new_mma(
         V_data,
         mask ? ((const char *) mask->data) : nullptr,
         sinks ? ((const char *)sinks->data) : nullptr,
+        KV_max.get(),
         !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
         scale, max_bias, m0, m1, logit_softcap, n_head_log2,
         Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
@@ -1837,7 +2000,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ct
     float logit_softcap;
     memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
 
-    fattn_kernel_t fattn_kernel;
+    fattn_new_mma_kernel_t fattn_kernel;
     if (logit_softcap == 0.0f) {
         constexpr bool use_logit_softcap = false;
         fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla>;
@@ -1944,8 +2107,16 @@ void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tens
     GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
     const int gqa_ratio = Q->ne[2] / K->ne[2];
 
-    if (K->ne[0] == 64 && V->ne[0] == 64) {
-        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<64, 64>(ctx, dst);
+    //if (K->ne[0] == 64 && V->ne[0] == 64) {
+    //    ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<64, 64>(ctx, dst);
+    //    return;
+    //}
+    if (K->ne[0] == 192 && V->ne[0] == 128) {
+        GGML_ASSERT(Q->ne[0] == 192);
+        GGML_ASSERT(gqa_ratio == 1);
+        //ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<192, 128>(ctx, dst);
+        // Reduce compile time
+        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<192, 128, 1>(ctx, dst);
         return;
     }
     GGML_ASSERT(Q->ne[0] == 576 && K->ne[0] == 576 && V->ne[0] == 512);

ggml/src/ggml-cuda/fattn.cu

Lines changed: 3 additions & 3 deletions
@@ -102,7 +102,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     // Hence, we use it only for DeepSeek with MLA enabled, where head sizes are 576, 512,
     // so no other implementation works.
     //
-    if (new_mma_available(cc) && Q->ne[0] == 576) {
+    if (new_mma_available(cc) && ((K->ne[0] == 576 && V->ne[0] == 512) || (K->ne[0] == 192 && V->ne[0] == 128))) {
         ggml_cuda_flash_attn_ext_mma_new(ctx, dst);
         return;
     }
@@ -172,8 +172,8 @@ bool ggml_cuda_fattn_is_supported(ggml_backend_cuda_context & ctx, const ggml_te
         return ggml_cuda_fattn_tile_f32_is_supported(ctx, dst);
     }
 
-    if (new_mma_available(cc) && Q->ne[0] == 576) {
-        return V->ne[0] == 512;
+    if (new_mma_available(cc) && (Q->ne[0] == 576 || (K->ne[0] == 192 && V->ne[0] == 128))) {
+        return true;
     }
 
     if (!new_mma_available(cc) || K->ne[0] != V->ne[0]) {
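After this change the new MMA path handles two head-size combinations instead of one. An illustrative summary of the routing, not the actual dispatcher (the head-dimension breakdown in the comments reflects DeepSeek's published configuration):

```cpp
// Illustrative summary of which (K, V) head sizes take the new MMA path after this commit.
#include <cstdio>

bool uses_new_mma_path(bool new_mma_available, int dk, int dv) {
    if (!new_mma_available) return false;
    return (dk == 576 && dv == 512)   // MLA token generation: kv_lora_rank 512 + 64 RoPE dims
        || (dk == 192 && dv == 128);  // DeepSeek prompt processing: 128 nope + 64 RoPE dims, V head size 128
}

int main() {
    printf("%d %d %d\n",
           uses_new_mma_path(true, 576, 512),   // 1
           uses_new_mma_path(true, 192, 128),   // 1
           uses_new_mma_path(true, 128, 128));  // 0: falls through to the other kernels
    return 0;
}
```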
