
Commit 1cf1385

Merge pull request #166 from SmallDoges/convert-mask-from-float-to-bool
[FEATURE SUPPORT] Convert attention mask storage from float to bool
2 parents ab06c18 + b98f2a9 commit 1cf1385

File tree

11 files changed: +80 −499 lines changed


benchmarks/backward_equivalence.py

Lines changed: 4 additions & 4 deletions
@@ -87,11 +87,11 @@ def prepare_dynamic_mask(
             attn_bias, keep_window_size, dim=-1, largest=True, sorted=False
         )
         valid_topk = topk_values != min_dtype
-        attn_mask = torch.zeros_like(attn_bias, dtype=dtype, device=attn_bias.device)
-        attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk.to(dtype))
-        attn_bias = attn_bias.masked_fill(attn_mask == 0.0, min_dtype)
+        attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
+        attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk)
+        attn_bias = attn_bias.masked_fill(~attn_mask, min_dtype)
     else:
-        attn_mask = torch.ones_like(attn_bias, dtype=dtype, device=attn_bias.device)
+        attn_mask = torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
     return attn_bias, attn_mask
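For reference, a minimal self-contained sketch of the updated helper as the benchmarks now use it. The function signature and the keep_window_size guard are assumed from surrounding context, not shown in this hunk:

import torch

def prepare_dynamic_mask(attn_bias, keep_window_size=None):
    # attn_bias: (batch, num_kv_heads, seqlen_q, seqlen_k) in fp16/bf16/fp32.
    min_dtype = torch.finfo(attn_bias.dtype).min
    if keep_window_size is not None and attn_bias.shape[-1] > keep_window_size:
        topk_values, topk_indices = torch.topk(
            attn_bias, keep_window_size, dim=-1, largest=True, sorted=False
        )
        valid_topk = topk_values != min_dtype
        # The mask is now stored as bool: True = attend, False = masked out.
        attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
        attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk)
        attn_bias = attn_bias.masked_fill(~attn_mask, min_dtype)
    else:
        attn_mask = torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
    return attn_bias, attn_mask

The same change is repeated verbatim in the other three benchmark scripts below.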

benchmarks/backward_performance.py

Lines changed: 4 additions & 4 deletions
@@ -109,11 +109,11 @@ def prepare_dynamic_mask(
             attn_bias, keep_window_size, dim=-1, largest=True, sorted=False
         )
         valid_topk = topk_values != min_dtype
-        attn_mask = torch.zeros_like(attn_bias, dtype=dtype, device=attn_bias.device)
-        attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk.to(dtype))
-        attn_bias = attn_bias.masked_fill(attn_mask == 0.0, min_dtype)
+        attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
+        attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk)
+        attn_bias = attn_bias.masked_fill(~attn_mask, min_dtype)
     else:
-        attn_mask = torch.ones_like(attn_bias, dtype=dtype, device=attn_bias.device)
+        attn_mask = torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
     return attn_bias, attn_mask

benchmarks/forward_equivalence.py

Lines changed: 4 additions & 4 deletions
@@ -87,11 +87,11 @@ def prepare_dynamic_mask(
             attn_bias, keep_window_size, dim=-1, largest=True, sorted=False
         )
         valid_topk = topk_values != min_dtype
-        attn_mask = torch.zeros_like(attn_bias, dtype=dtype, device=attn_bias.device)
-        attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk.to(dtype))
-        attn_bias = attn_bias.masked_fill(attn_mask == 0.0, min_dtype)
+        attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
+        attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk)
+        attn_bias = attn_bias.masked_fill(~attn_mask, min_dtype)
     else:
-        attn_mask = torch.ones_like(attn_bias, dtype=dtype, device=attn_bias.device)
+        attn_mask = torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
     return attn_bias, attn_mask

benchmarks/forward_performance.py

Lines changed: 4 additions & 4 deletions
@@ -109,11 +109,11 @@ def prepare_dynamic_mask(
             attn_bias, keep_window_size, dim=-1, largest=True, sorted=False
         )
         valid_topk = topk_values != min_dtype
-        attn_mask = torch.zeros_like(attn_bias, dtype=dtype, device=attn_bias.device)
-        attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk.to(dtype))
-        attn_bias = attn_bias.masked_fill(attn_mask == 0.0, min_dtype)
+        attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
+        attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk)
+        attn_bias = attn_bias.masked_fill(~attn_mask, min_dtype)
     else:
-        attn_mask = torch.ones_like(attn_bias, dtype=dtype, device=attn_bias.device)
+        attn_mask = torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
     return attn_bias, attn_mask

csrc/flash_dmattn/flash_api.cpp

Lines changed: 6 additions & 6 deletions
@@ -361,7 +361,7 @@ mha_fwd(
     TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type");
     TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
     TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
-    TORCH_CHECK(mask.dtype() == q_dtype, "mask must have the same dtype as inputs");
+    TORCH_CHECK(mask.dtype() == torch::kBool, "mask must have dtype bool");
     TORCH_CHECK(bias.dtype() == q_dtype, "bias must have the same dtype as inputs");

     CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(mask); CHECK_DEVICE(bias);
@@ -512,7 +512,7 @@ mha_varlen_fwd(
     TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type");
     TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
     TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
-    TORCH_CHECK(mask.dtype() == q_dtype, "mask must have the same dtype as inputs");
+    TORCH_CHECK(mask.dtype() == torch::kBool, "mask must have dtype bool");
     TORCH_CHECK(bias.dtype() == q_dtype, "bias must have the same dtype as inputs");
     TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
     TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");
@@ -749,7 +749,7 @@ mha_bwd(
     TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type");
     TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
     TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
-    TORCH_CHECK(mask.dtype() == q_dtype, "query and mask must have the same dtype");
+    TORCH_CHECK(mask.dtype() == torch::kBool, "mask must have dtype bool");
     TORCH_CHECK(bias.dtype() == q_dtype, "query and bias must have the same dtype");
     TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
     TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
@@ -951,7 +951,7 @@ mha_varlen_bwd(
     TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type");
     TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
     TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
-    TORCH_CHECK(mask.dtype() == q_dtype, "query and mask must have the same dtype");
+    TORCH_CHECK(mask.dtype() == torch::kBool, "mask must have dtype bool");
     TORCH_CHECK(bias.dtype() == q_dtype, "query and bias must have the same dtype");
     TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
     TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
@@ -1136,7 +1136,7 @@ mha_varlen_bwd(
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.doc() = "FlashDynamicMaskAttention";
     m.def("fwd", &FLASH_NAMESPACE::mha_fwd, "Forward pass");
-    m.def("varlen_fwd", &FLASH_NAMESPACE::mha_varlen_fwd, "Forward pass with variable length");
+    // m.def("varlen_fwd", &FLASH_NAMESPACE::mha_varlen_fwd, "Forward pass with variable length");
     m.def("bwd", &FLASH_NAMESPACE::mha_bwd, "Backward pass");
-    m.def("varlen_bwd", &FLASH_NAMESPACE::mha_varlen_bwd, "Backward pass with variable length");
+    // m.def("varlen_bwd", &FLASH_NAMESPACE::mha_varlen_bwd, "Backward pass with variable length");
 }
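With these checks in place, callers must hand the extension a torch.bool mask. A hedged caller-side sketch (the helper name is illustrative and not part of this commit; the only contract change shown here is the mask dtype):

import torch

def as_kernel_mask(attn_mask):
    # Legacy callers kept a {0, 1} float mask in the query dtype; the dtype checks
    # above now reject that, so convert once before calling the fwd/bwd bindings.
    # Nonzero entries mean "attend"; zero entries are masked out.
    return attn_mask if attn_mask.dtype == torch.bool else attn_mask != 0

Note that the varlen_fwd and varlen_bwd bindings are commented out in this commit, so only fwd and bwd remain exposed from the extension module.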

csrc/flash_dmattn/src/flash_bwd_kernel.h

Lines changed: 9 additions & 7 deletions
@@ -144,7 +144,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         make_stride(params.v_row_stride, _1{})
     );
     Tensor gMask = make_tensor(
-        make_gmem_ptr(reinterpret_cast<Element *>(params.mask_ptr) + row_offset_mask),
+        make_gmem_ptr(reinterpret_cast<const bool *>(params.mask_ptr) + row_offset_mask),
         Shape<Int<kBlockM>, Int<kBlockN>>{},
         make_stride(params.mask_row_stride, _1{})
     );
@@ -552,14 +552,15 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     // cute::copy(gmem_tiled_copy_QKV, tKgK, tKrK);
     // // if (cute::thread(1, 0)) { print(tKrK); }

-    FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
+    FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true, /*Bool_to_Element=*/true, Element>(
         gmem_tiled_copy_MaskBias,
         tMaskgMask, tMasksMask,
         tMaskcMask,
         binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN
     );
-    cute::cp_async_fence();
-    FLASH_NAMESPACE::cp_async_wait<0>();
+    // cute::cp_async_fence();
+    // FLASH_NAMESPACE::cp_async_wait<0>();
+    __syncthreads();

     // Do OR-reduce on the mask to see if any active threads
     Tensor tSsMask_copy_view = smem_thr_copy_PdS.retile_S(tSsMask);
@@ -807,14 +808,15 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     if (m_block > m_block_min) {
         // Advance gMask
         tMaskgMask.data() = tMaskgMask.data() + (-int(kBlockM * params.mask_row_stride));
-        FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
+        FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true, /*Bool_to_Element=*/true, Element>(
             gmem_tiled_copy_MaskBias,
             tMaskgMask, tMasksMask,
             tMaskcMask,
             binfo.actual_seqlen_q - (m_block - 1) * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN
         );
-        FLASH_NAMESPACE::cp_async_fence();
-        FLASH_NAMESPACE::cp_async_wait<0>();
+        // FLASH_NAMESPACE::cp_async_fence();
+        // FLASH_NAMESPACE::cp_async_wait<0>();
+        __syncthreads();

         // Do OR-reduce on the mask to see if any active threads for next iteration
         any_active_local_next = false;

csrc/flash_dmattn/src/flash_fwd_kernel.h

Lines changed: 22 additions & 19 deletions
@@ -169,7 +169,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         make_coord(_, 0)
     ); // (kBlockN, kHeadDim, nblocksN)
     Tensor mMask = make_tensor(
-        make_gmem_ptr(reinterpret_cast<Element*>(params.mask_ptr) + binfo.mask_offset(params.mask_batch_stride, params.mask_row_stride, bidb)),
+        make_gmem_ptr(reinterpret_cast<const bool*>(params.mask_ptr) + binfo.mask_offset(params.mask_batch_stride, params.mask_row_stride, bidb)),
         make_shape(params.h_k, binfo.actual_seqlen_q, binfo.actual_seqlen_k),
         make_stride(params.mask_head_stride, params.mask_row_stride, _1{})
     );
@@ -344,15 +344,15 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     }
     // Reverse iteration over N blocks
     int n_block = n_block_max - 1;

-    FLASH_NAMESPACE::copy_MN<Is_even_MN>(
+    FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true, /*Bool_to_Element=*/true, Element>(
         gmem_tiled_copy_MaskBias,
         tMaskgMask(_, _, _, n_block), tMasksMask,
         tMaskcMask,
         binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN
     );
-    cute::cp_async_fence();
-    FLASH_NAMESPACE::cp_async_wait<0>();
+    // cute::cp_async_fence();
+    // FLASH_NAMESPACE::cp_async_wait<0>();
     __syncthreads();

     // Do OR-reduce on the mask to see if any active threads
@@ -470,14 +470,14 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     }

     if (n_block > n_block_min) {
-        FLASH_NAMESPACE::copy_MN</*Is_even_MN=*/true>(
+        FLASH_NAMESPACE::copy_MN</*Is_even_MN=*/true, /*Clear_OOB_MN=*/true, /*Bool_to_Element=*/true, Element>(
             gmem_tiled_copy_MaskBias,
             tMaskgMask(_, _, _, n_block - 1), tMasksMask,
             tMaskcMask,
             binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN
         );
-        cute::cp_async_fence();
-        FLASH_NAMESPACE::cp_async_wait<0>();
+        // cute::cp_async_fence();
+        // FLASH_NAMESPACE::cp_async_wait<0>();
         __syncthreads();

         // Do OR-reduce on the mask to see if any active threads for next iteration
@@ -593,14 +593,14 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     }

     if (n_block > n_block_min) {
-        FLASH_NAMESPACE::copy_MN</*Is_even_MN=*/true>(
+        FLASH_NAMESPACE::copy_MN</*Is_even_MN=*/true, /*Clear_OOB_MN=*/true, /*Bool_to_Element=*/true, Element>(
             gmem_tiled_copy_MaskBias,
             tMaskgMask(_, _, _, n_block - 1), tMasksMask,
             tMaskcMask,
             binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN
         );
-        cute::cp_async_fence();
-        FLASH_NAMESPACE::cp_async_wait<0>();
+        // cute::cp_async_fence();
+        // FLASH_NAMESPACE::cp_async_wait<0>();
         __syncthreads();

         // Do OR-reduce on the mask to see if any active threads for next iteration
@@ -873,7 +873,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         make_stride(params.v_row_stride, _1{})
     );
     Tensor gMask = make_tensor(
-        make_gmem_ptr(reinterpret_cast<Element *>(params.mask_ptr) + col_offset_mask),
+        make_gmem_ptr(reinterpret_cast<const bool *>(params.mask_ptr) + col_offset_mask),
         Shape<Int<kBlockM>, Int<kBlockN>>{},
         make_stride(params.mask_row_stride, _1{})
     );
@@ -999,14 +999,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons

     int n_block = n_block_max - 1;

-    FLASH_NAMESPACE::copy_MN<Is_even_MN>(
+    FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true, /*Bool_to_Element=*/true, Element>(
         gmem_tiled_copy_MaskBias,
         tMaskgMask, tMasksMask,
         tMaskcMask,
         binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN
     );
-    cute::cp_async_fence();
-    FLASH_NAMESPACE::cp_async_wait<0>();
+    // cute::cp_async_fence();
+    // FLASH_NAMESPACE::cp_async_wait<0>();
     __syncthreads();

     // Do OR-reduce on the mask to see if any active threads for next iteration
@@ -1146,14 +1146,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         tMaskgMask.data() = tMaskgMask.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.mask_batch_stride + (block_table_offset_next - block_table_offset_cur);
         tBiasgBias.data() = tBiasgBias.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.bias_batch_stride + (block_table_offset_next - block_table_offset_cur);
     }
-    FLASH_NAMESPACE::copy_MN</*Is_even_MN=*/true>(
+    FLASH_NAMESPACE::copy_MN</*Is_even_MN=*/true, /*Clear_OOB_MN=*/true, /*Bool_to_Element=*/true, Element>(
         gmem_tiled_copy_MaskBias,
         tMaskgMask, tMasksMask,
         tMaskcMask,
         binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN
     );
-    cute::cp_async_fence();
-    FLASH_NAMESPACE::cp_async_wait<0>();
+    // cute::cp_async_fence();
+    // FLASH_NAMESPACE::cp_async_wait<0>();
     __syncthreads();

     // Do OR-reduce on the mask to see if any active threads for next iteration
@@ -1287,12 +1287,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         tMaskgMask.data() = tMaskgMask.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.mask_batch_stride + (block_table_offset_next - block_table_offset_cur);
         tBiasgBias.data() = tBiasgBias.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.bias_batch_stride + (block_table_offset_next - block_table_offset_cur);
     }
-    FLASH_NAMESPACE::copy_MN</*Is_even_MN=*/true>(
+    FLASH_NAMESPACE::copy_MN</*Is_even_MN=*/true, /*Clear_OOB_MN=*/true, /*Bool_to_Element=*/true, Element>(
         gmem_tiled_copy_MaskBias,
         tMaskgMask, tMasksMask,
         tMaskcMask,
         binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN
     );
+    // cute::cp_async_fence();
+    // FLASH_NAMESPACE::cp_async_wait<0>();
+    __syncthreads();

     // Do OR-reduce on the mask to see if any active threads for next iteration
     any_active_local_next = false;
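One practical effect of reinterpreting params.mask_ptr as const bool* is that the mask now occupies one byte per element in global memory instead of two for fp16/bf16, roughly halving mask storage and read traffic; the kernels widen it to the compute dtype only while staging tiles into shared memory. A quick host-side check of the storage side (plain PyTorch, not part of this commit):

import torch

mask_fp16 = torch.ones(1, 8, 2048, 2048, dtype=torch.float16)
mask_bool = mask_fp16 != 0  # bool storage: 1 byte per element vs 2 for fp16
assert mask_bool.element_size() == 1 and mask_fp16.element_size() == 2
print(mask_fp16.numel() * mask_fp16.element_size(), "->",
      mask_bool.numel() * mask_bool.element_size(), "bytes")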

csrc/flash_dmattn/src/utils.h

Lines changed: 9 additions & 2 deletions
@@ -521,7 +521,7 @@ __forceinline__ __device__ void copy(
 ////////////////////////////////////////////////////////////////////////////////////////////////////

 template <
-    bool Is_even_MN=true, bool Clear_OOB_MN=true,
+    bool Is_even_MN=true, bool Clear_OOB_MN=true, bool Bool_to_Element=false, typename To_type=void,
     typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
     typename Engine2, typename Layout2
 >
@@ -543,7 +543,14 @@ __forceinline__ __device__ void copy_MN(
     #pragma unroll
     for (int n = 0; n < size<2>(S); ++n) {
         if (Is_even_MN || get<1>(identity_MN(0, m, n)) < max_N) {
-            cute::copy(tiled_copy, S(_, m, n), D(_, m, n));
+            if constexpr (Bool_to_Element) {
+                #pragma unroll
+                for (int i = 0; i < size<0>(S); ++i) {
+                    D(i, m, n) = static_cast<bool>(S(i, m, n)) ? To_type(1) : To_type(0);
+                }
+            } else {
+                cute::copy(tiled_copy, S(_, m, n), D(_, m, n));
+            }
         } else if (Clear_OOB_MN) {
             cute::clear(D(_, m, n));
         }
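To make the new template parameters concrete, here is a host-side Python reference of what the inner column loop now does. This is a sketch of the semantics only; the tile shapes, the outer m loop, and the identity/predicate tensors are simplified away:

import torch

def copy_MN_reference(src, max_N, bool_to_element=False, to_dtype=torch.float16,
                      is_even_MN=False, clear_OOB_MN=True):
    # src: a (rows, cols) tile read from global memory (bool when bool_to_element=True).
    dst_dtype = to_dtype if bool_to_element else src.dtype
    dst = torch.zeros(src.shape, dtype=dst_dtype)
    for n in range(src.shape[1]):
        if is_even_MN or n < max_N:
            col = src[:, n]
            # Bool_to_Element widens the bool mask to To_type(1)/To_type(0) as it
            # lands in shared memory; otherwise it is a plain tile copy.
            dst[:, n] = col.bool().to(dst_dtype) if bool_to_element else col
        elif clear_OOB_MN:
            dst[:, n] = 0  # mirrors cute::clear on out-of-bounds columns
    return dst

This also appears to be why the kernels above swap the cp_async fence/wait pair for a plain __syncthreads(): the Bool_to_Element path issues ordinary element-wise stores rather than cp.async transactions, so a block-wide barrier, not an async fence, is what makes the staged mask visible to all threads.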
