Commit bbfbbc3

Use ElementMask and split mask copy/reduce
Standardizes the mask dtype to an explicit element type (ElementMask) in global and shared memory, fixing type mismatches and ensuring alignment. Aligns the shared mask buffer via a placeholder tensor and updates the layout to avoid misaligned accesses. Replaces the fused mask copy+reduce with a generic copy, an explicit barrier, and a separate OR-reduction, making the synchronization point visible at each call site. Unifies bias handling onto the generic copy path (copy_MN).
1 parent 69df087 commit bbfbbc3
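
The core of the change: the fused copy_mask_with_or_reduce is split into a plain gmem-to-smem copy (copy_MN), an explicit __syncthreads() barrier, and a separate block-wide OR-reduction (mask_or_reduce) that decides whether the tile has any active position. Below is a minimal, self-contained CUDA sketch of that sequence; the flat tiling, uint8_t mask type, and kernel shape are illustrative stand-ins, not the CuTe-based helpers the diff actually calls.

#include <cstdint>

constexpr int kBlockN = 128;

// Sketch: gate per-tile work on an OR-reduction over a mask tile.
// Assumes blockDim.x is a multiple of 32.
__global__ void mask_gate_sketch(const uint8_t* __restrict__ gmem_mask,
                                 int* __restrict__ out_active_tiles,
                                 int n_tiles) {
    __shared__ uint8_t smask[kBlockN];
    __shared__ int any_active_s;

    for (int tile = blockIdx.x; tile < n_tiles; tile += gridDim.x) {
        // 1. Generic gmem -> smem copy (the role copy_MN plays in the diff).
        for (int i = threadIdx.x; i < kBlockN; i += blockDim.x) {
            smask[i] = gmem_mask[tile * kBlockN + i];
        }
        if (threadIdx.x == 0) any_active_s = 0;
        // 2. Barrier: the tile must be fully written before anyone reduces over it.
        __syncthreads();
        // 3. Explicit OR-reduce (the role mask_or_reduce plays in the diff).
        bool local = false;
        for (int i = threadIdx.x; i < kBlockN; i += blockDim.x) {
            local |= (smask[i] != 0);
        }
        if (__any_sync(0xffffffffu, local) && (threadIdx.x % 32 == 0)) {
            atomicOr(&any_active_s, 1);
        }
        __syncthreads();
        // 4. Skip the tile's K/V work entirely when nothing is active.
        if (any_active_s != 0 && threadIdx.x == 0) {
            atomicAdd(out_active_tiles, 1);  // stand-in for the real tile compute
        }
        __syncthreads();  // smask and the flag are reused next iteration
    }
}

The old fused helper hid the barrier inside itself (the removed comments below say it "is already or_syncthreads"); splitting it makes the synchronization explicit and lets the mask share the same generic copy path as K, V, and bias.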

File tree

1 file changed: +76 -97 lines

csrc/flash_dmattn/src/flash_fwd_kernel.h

Lines changed: 76 additions & 97 deletions
@@ -54,6 +54,7 @@ template<typename Kernel_traits, bool Is_causal, bool Has_mask, bool Has_bias, b
 inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {
 
     using Element = typename Kernel_traits::Element;
+    using ElementMask = typename Kernel_traits::ElementMask;
     using ElementAccum = typename Kernel_traits::ElementAccum;
     using index_t = typename Kernel_traits::index_t;

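The diff assumes Kernel_traits now exposes an ElementMask alias alongside Element; its definition lives in the kernel-traits header, which this commit does not touch. A hypothetical sketch of the shape of that trait, for orientation only:

#include <cstdint>

// Hypothetical: the real aliases are defined in flash_dmattn's kernel-traits
// header, not shown in this commit.
template <typename elem_type>
struct Kernel_traits_sketch {
    using Element = elem_type;     // compute dtype, e.g. half or bfloat16
    using ElementMask = uint8_t;   // assumed explicit mask dtype in gmem/smem
    using ElementAccum = float;    // accumulator dtype
    using index_t = int64_t;       // offset dtype
};
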
@@ -169,7 +170,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         make_coord(_, 0)
     ); // (kBlockN, kHeadDim, nblocksN)
     Tensor mMask = make_tensor(
-        make_gmem_ptr(reinterpret_cast<const bool*>(params.mask_ptr) + binfo.mask_offset(params.mask_batch_stride, params.mask_row_stride, bidb)),
+        make_gmem_ptr(reinterpret_cast<ElementMask*>(params.mask_ptr) + binfo.mask_offset(params.mask_batch_stride, params.mask_row_stride, bidb)),
         make_shape(params.h_mask, binfo.actual_seqlen_q, binfo.actual_seqlen_k),
         make_stride(params.mask_head_stride, params.mask_row_stride, _1{})
     );
@@ -216,13 +217,17 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         sV.data().get(),
         typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}
     );
-    Tensor sMask = make_tensor(
+    Tensor sMaskPlace = make_tensor(
         Has_mask ? sV.data() + size(sV) : sV.data(),
-        typename Kernel_traits::SmemLayoutAtomPS{}
+        typename Kernel_traits::SmemLayoutPS{}
+    ); // For pointer alignment only
+    Tensor sMask = make_tensor(
+        make_smem_ptr(reinterpret_cast<ElementMask*>(sMaskPlace.data().get())),
+        typename Kernel_traits::SmemLayoutPS{}
     );
     Tensor sBias = make_tensor(
-        Has_bias ? (Has_mask ? sMask.data() + size(sMask) : sV.data() + size(sV)) : sV.data(),
-        typename Kernel_traits::SmemLayoutAtomPS{}
+        Has_bias ? (Has_mask ? sMaskPlace.data() + size(sMaskPlace) : sV.data() + size(sV)) : sV.data(),
+        typename Kernel_traits::SmemLayoutPS{}
     );
 
     // Global to Shared Memory operation
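
sMaskPlace exists only to carve the mask region out of shared memory in units of the wider Element type, so the region's base address inherits Element alignment; sMask then views the same bytes as ElementMask. Sizing sBias from sMaskPlace rather than sMask keeps the bias pointer Element-aligned as well, which is why the diff switches it to sMaskPlace.data() + size(sMaskPlace). A freestanding sketch of the idea, with illustrative types in place of the CuTe tensors:

#include <cuda_fp16.h>
#include <cstdint>

extern __shared__ char smem_[];

// Illustrative types; the real Element/ElementMask come from Kernel_traits.
using Element = __half;
using ElementMask = uint8_t;

__device__ ElementMask* smem_mask_view(int v_elems) {
    Element* sV = reinterpret_cast<Element*>(smem_);
    // Placeholder: advance in units of the wider Element type so the mask
    // region starts at an Element-aligned address ...
    Element* sMaskPlace = sV + v_elems;
    // ... then reinterpret the same bytes as the narrower mask type.
    return reinterpret_cast<ElementMask*>(sMaskPlace);
}
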
@@ -364,25 +369,19 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     }
 
     if constexpr (Has_mask) {
-        // FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
-        //     gmem_tiled_copy_Mask,
-        //     tMaskgMask(_, _, _, n_block), tMasksMask,
-        //     tMaskcMask, tMaskpMask,
-        //     binfo.actual_seqlen_q - m_block * kBlockM
-        // );
-        // cute::cp_async_fence();
-        // FLASH_NAMESPACE::cp_async_wait<0>();
-        // // Do OR-reduce on the mask to see if any active threads
-
-
-        FLASH_NAMESPACE::copy_mask_with_or_reduce<Is_even_MN, /*Clear_OOB_MN=*/true, /*To_type=*/Element>(
+        FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
             gmem_tiled_copy_Mask,
             tMaskgMask(_, _, _, n_block), tMasksMask,
-            any_active,
             tMaskcMask, tMaskpMask,
             binfo.actual_seqlen_q - m_block * kBlockM
         );
-        // We don't need to syncthreads here because copy_mask is already or_syncthreads.
+        __syncthreads();
+        // Do OR-reduce on the mask to see if any active threads for current iteration.
+        FLASH_NAMESPACE::mask_or_reduce(
+            tMasksMask,
+            any_active,
+            smem_thr_copy_Mask
+        );
     }
 
     // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
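
Note the ordering contract made explicit here: mask_or_reduce reads tMasksMask, a tile that every thread may have partially written during copy_MN, so the __syncthreads() between the two calls is required for correctness, not a stylistic choice. The same copy/barrier/reduce sequence recurs at every mask-load site below, including the split-KV kernel.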
@@ -394,7 +393,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         binfo.actual_seqlen_k - n_block * kBlockN
     );
     if constexpr (Has_bias) {
-        FLASH_NAMESPACE::copy_bias<Is_even_MN, /*Clear_OOB_MN=*/true>(
+        FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
             gmem_tiled_copy_Bias,
             tBiasgBias(_, _, _, n_block), tBiassBias,
             tBiascBias, tBiaspBias,
@@ -524,24 +523,19 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
 
         if (n_block > n_block_min) {
             if constexpr (Has_mask) {
-                // FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
-                //     gmem_tiled_copy_Mask,
-                //     tMaskgMask(_, _, _, n_block - 1), tMasksMask,
-                //     tMaskcMask, tMaskpMask,
-                //     binfo.actual_seqlen_q - m_block * kBlockM
-                // );
-                // cute::cp_async_fence();
-                // FLASH_NAMESPACE::cp_async_wait<0>();
-                // // Do OR-reduce on the mask to see if any active threads for next iteration.
-
-                FLASH_NAMESPACE::copy_mask_with_or_reduce<Is_even_MN, /*Clear_OOB_MN=*/true, /*To_type=*/Element>(
+                FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
                     gmem_tiled_copy_Mask,
                     tMaskgMask(_, _, _, n_block - 1), tMasksMask,
-                    any_active_next,
                     tMaskcMask, tMaskpMask,
                     binfo.actual_seqlen_q - m_block * kBlockM
                 );
-                // We don't need to syncthreads here because copy_mask is already or_syncthreads.
+                __syncthreads();
+                // Do OR-reduce on the mask to see if any active threads for next iteration.
+                FLASH_NAMESPACE::mask_or_reduce(
+                    tMasksMask,
+                    any_active_next,
+                    smem_thr_copy_Mask
+                );
             }
 
             if (any_active_next) {
@@ -551,7 +545,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
                     tKVcKV, tKVpKV
                 );
                 if constexpr (Has_bias) {
-                    FLASH_NAMESPACE::copy_bias<Is_even_MN, /*Clear_OOB_MN=*/true>(
+                    FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
                         gmem_tiled_copy_Bias,
                         tBiasgBias(_, _, _, n_block - 1), tBiassBias,
                         tBiascBias, tBiaspBias,
@@ -684,24 +678,19 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
 
         if (n_block > n_block_min) {
             if constexpr (Has_mask) {
-                // FLASH_NAMESPACE::copy_MN</*Is_even_MN=*/true, /*Clear_OOB_MN=*/false>(
-                //     gmem_tiled_copy_Mask,
-                //     tMaskgMask(_, _, _, n_block - 1), tMasksMask,
-                //     tMaskcMask, tMaskpMask,
-                //     binfo.actual_seqlen_q - m_block * kBlockM
-                // );
-                // cute::cp_async_fence();
-                // FLASH_NAMESPACE::cp_async_wait<0>();
-                // // Do OR-reduce on the mask to see if any active threads for next iteration.
-
-                FLASH_NAMESPACE::copy_mask_with_or_reduce</*Is_even_MN=*/true, /*Clear_OOB_MN=*/false, /*To_type=*/Element>(
+                FLASH_NAMESPACE::copy_MN</*Is_even_MN=*/true, /*Clear_OOB_MN=*/false>(
                     gmem_tiled_copy_Mask,
                     tMaskgMask(_, _, _, n_block - 1), tMasksMask,
-                    any_active_next,
                     tMaskcMask, tMaskpMask,
                     binfo.actual_seqlen_q - m_block * kBlockM
                 );
-                // We don't need to syncthreads here because copy_mask is already or_syncthreads
+                __syncthreads();
+                // Do OR-reduce on the mask to see if any active threads for next iteration.
+                FLASH_NAMESPACE::mask_or_reduce(
+                    tMasksMask,
+                    any_active_next,
+                    smem_thr_copy_Mask
+                );
             }
 
             if (any_active_next) {
@@ -711,7 +700,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
                     tKVcKV, tKVpKV
                 );
                 if constexpr (Has_bias) {
-                    FLASH_NAMESPACE::copy_bias</*Is_even_MN=*/true, /*Clear_OOB_MN=*/false>(
+                    FLASH_NAMESPACE::copy_MN</*Is_even_MN=*/true, /*Clear_OOB_MN=*/false>(
                         gmem_tiled_copy_Bias,
                         tBiasgBias(_, _, _, n_block - 1), tBiassBias,
                         tBiascBias, tBiaspBias,
@@ -834,6 +823,7 @@ template<typename Kernel_traits, bool Is_causal, bool Has_mask, bool Has_bias, b
 inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) {
 
     using Element = typename Kernel_traits::Element;
+    using ElementMask = typename Kernel_traits::ElementMask;
     using ElementAccum = typename Kernel_traits::ElementAccum;
     using index_t = typename Kernel_traits::index_t;

@@ -970,7 +960,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         make_stride(params.v_row_stride, _1{})
     );
     Tensor gMask = make_tensor(
-        make_gmem_ptr(reinterpret_cast<const bool *>(params.mask_ptr) + col_offset_mask),
+        make_gmem_ptr(reinterpret_cast<ElementMask *>(params.mask_ptr) + col_offset_mask),
         Shape<Int<kBlockM>, Int<kBlockN>>{},
         make_stride(params.mask_row_stride, _1{})
     );
@@ -1001,13 +991,17 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         sV.data().get(),
         typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}
     );
-    Tensor sMask = make_tensor(
+    Tensor sMaskPlace = make_tensor(
         Has_mask ? sV.data() + size(sV) : sV.data(),
-        typename Kernel_traits::SmemLayoutAtomPS{}
+        typename Kernel_traits::SmemLayoutPS{}
+    ); // For pointer alignment only
+    Tensor sMask = make_tensor(
+        make_smem_ptr(reinterpret_cast<ElementMask*>(sMaskPlace.data().get())),
+        typename Kernel_traits::SmemLayoutPS{}
     );
     Tensor sBias = make_tensor(
-        Has_bias ? (Has_mask ? sMask.data() + size(sMask) : sV.data() + size(sV)) : sV.data(),
-        typename Kernel_traits::SmemLayoutAtomPS{}
+        Has_bias ? (Has_mask ? sMaskPlace.data() + size(sMaskPlace) : sV.data() + size(sV)) : sV.data(),
+        typename Kernel_traits::SmemLayoutPS{}
     );
 
     // Global to Shared Memory operation
@@ -1115,24 +1109,19 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     );
 
     if constexpr (Has_mask) {
-        // FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
-        //     gmem_tiled_copy_Mask,
-        //     tMaskgMask, tMasksMask,
-        //     tMaskcMask, tMaskpMask,
-        //     binfo.actual_seqlen_q - m_block * kBlockM
-        // );
-        // cute::cp_async_fence();
-        // FLASH_NAMESPACE::cp_async_wait<0>();
-        // // Do OR-reduce on the mask to see if any active threads
-
-        FLASH_NAMESPACE::copy_mask_with_or_reduce<Is_even_MN, /*Clear_OOB_MN=*/true, /*To_type=*/Element>(
+        FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
             gmem_tiled_copy_Mask,
             tMaskgMask, tMasksMask,
-            any_active,
             tMaskcMask, tMaskpMask,
             binfo.actual_seqlen_q - m_block * kBlockM
         );
-        // We don't need to syncthreads here because copy_mask is already or_syncthreads.
+        __syncthreads();
+        // Do OR-reduce on the mask to see if any active threads for current iteration.
+        FLASH_NAMESPACE::mask_or_reduce(
+            tMasksMask,
+            any_active,
+            smem_thr_copy_Mask
+        );
     }
 
     // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
@@ -1144,7 +1133,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         binfo.actual_seqlen_k - n_block * kBlockN
     );
     if constexpr (Has_bias) {
-        FLASH_NAMESPACE::copy_bias<Is_even_MN, /*Clear_OOB_MN=*/true>(
+        FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
             gmem_tiled_copy_Bias,
             tBiasgBias, tBiassBias,
             tBiascBias, tBiaspBias,
@@ -1305,24 +1294,19 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         }
 
         if constexpr (Has_mask) {
-            // FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
-            //     gmem_tiled_copy_Mask,
-            //     tMaskgMask, tMasksMask,
-            //     tMaskcMask, tMaskpMask,
-            //     binfo.actual_seqlen_q - m_block * kBlockM
-            // );
-            // cute::cp_async_fence();
-            // FLASH_NAMESPACE::cp_async_wait<0>();
-            // // Do OR-reduce on the mask to see if any active threads for next iteration.
-
-            FLASH_NAMESPACE::copy_mask_with_or_reduce<Is_even_MN, /*Clear_OOB_MN=*/true, /*To_type=*/Element>(
+            FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
                 gmem_tiled_copy_Mask,
                 tMaskgMask, tMasksMask,
-                any_active_next,
                 tMaskcMask, tMaskpMask,
                 binfo.actual_seqlen_q - m_block * kBlockM
             );
-            // We don't need to syncthreads here because copy_mask is already or_syncthreads.
+            __syncthreads();
+            // Do OR-reduce on the mask to see if any active threads for next iteration.
+            FLASH_NAMESPACE::mask_or_reduce(
+                tMasksMask,
+                any_active_next,
+                smem_thr_copy_Mask
+            );
         }
 
         if (any_active_next) {
@@ -1332,9 +1316,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
             tKVcKV, tKVpKV
         );
         if constexpr (Has_bias) {
-            FLASH_NAMESPACE::copy_bias<Is_even_MN, /*Clear_OOB_MN=*/true>(
+            FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
                 gmem_tiled_copy_Bias,
-                tBiasgBias, tBiassBias,
+                tBiasgBias, tBiassBias,
                 tBiascBias, tBiaspBias,
                 binfo.actual_seqlen_q - m_block * kBlockM
             );
@@ -1492,24 +1476,19 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
        }
 
        if constexpr (Has_mask) {
-            // FLASH_NAMESPACE::copy_MN</*Is_even_MN=*/true, /*Clear_OOB_MN=*/false>(
-            //     gmem_tiled_copy_Mask,
-            //     tMaskgMask, tMasksMask,
-            //     tMaskcMask, tMaskpMask,
-            //     binfo.actual_seqlen_q - m_block * kBlockM
-            // );
-            // cute::cp_async_fence();
-            // FLASH_NAMESPACE::cp_async_wait<0>();
-            // // Do OR-reduce on the mask to see if any active threads for next iteration.
-
-            FLASH_NAMESPACE::copy_mask_with_or_reduce</*Is_even_MN=*/true, /*Clear_OOB_MN=*/false, /*To_type=*/Element>(
+            FLASH_NAMESPACE::copy_MN</*Is_even_MN=*/true, /*Clear_OOB_MN=*/false>(
                 gmem_tiled_copy_Mask,
                 tMaskgMask, tMasksMask,
-                any_active_next,
                 tMaskcMask, tMaskpMask,
                 binfo.actual_seqlen_q - m_block * kBlockM
             );
-            // We don't need to syncthreads here because copy_mask is already or_syncthreads.
+            __syncthreads();
+            // Do OR-reduce on the mask to see if any active threads for next iteration.
+            FLASH_NAMESPACE::mask_or_reduce(
+                tMasksMask,
+                any_active_next,
+                smem_thr_copy_Mask
+            );
        }
 
        if (any_active_next) {
@@ -1519,9 +1498,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
             tKVcKV, tKVpKV
         );
         if constexpr (Has_bias) {
-            FLASH_NAMESPACE::copy_bias</*Is_even_MN=*/true, /*Clear_OOB_MN=*/false>(
+            FLASH_NAMESPACE::copy_MN</*Is_even_MN=*/true, /*Clear_OOB_MN=*/false>(
                 gmem_tiled_copy_Bias,
-                tBiasgBias, tBiassBias,
+                tBiasgBias, tBiassBias,
                 tBiascBias, tBiaspBias,
                 binfo.actual_seqlen_q - m_block * kBlockM
             );
