Merged · Changes from 10 commits
benchmarks/backward_equivalence.py (4 additions, 4 deletions)
@@ -87,11 +87,11 @@ def prepare_dynamic_mask(
attn_bias, keep_window_size, dim=-1, largest=True, sorted=False
)
valid_topk = topk_values != min_dtype
- attn_mask = torch.zeros_like(attn_bias, dtype=dtype, device=attn_bias.device)
- attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk.to(dtype))
- attn_bias = attn_bias.masked_fill(attn_mask == 0.0, min_dtype)
+ attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
+ attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk)
+ attn_bias = attn_bias.masked_fill(attn_mask == False, min_dtype)
else:
- attn_mask = torch.ones_like(attn_bias, dtype=dtype, device=attn_bias.device)
+ attn_mask = torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
return attn_bias, attn_mask


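The change above switches the benchmark's attention mask from the input dtype to torch.bool: the top-k selection now scatters boolean flags instead of 0/1 values, and the bias is filled where the mask is False. Below is a minimal, self-contained sketch of the new flow (the function name, shapes, and keep_window_size value are illustrative, not taken from the benchmark); the same edit is repeated in the other three benchmark scripts that follow.

import torch

def prepare_dynamic_mask_bool(attn_bias, keep_window_size):
    # attn_bias: (batch, num_heads, query_len, key_len), any floating dtype
    min_dtype = torch.finfo(attn_bias.dtype).min
    if attn_bias.shape[-1] > keep_window_size:
        topk_values, topk_indices = torch.topk(
            attn_bias, keep_window_size, dim=-1, largest=True, sorted=False
        )
        valid_topk = topk_values != min_dtype
        # The mask is now boolean: True = keep this key position, False = drop it.
        attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool)
        attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk)
        # `~attn_mask` is equivalent to the `attn_mask == False` used in the diff.
        attn_bias = attn_bias.masked_fill(~attn_mask, min_dtype)
    else:
        attn_mask = torch.ones_like(attn_bias, dtype=torch.bool)
    return attn_bias, attn_mask

bias, mask = prepare_dynamic_mask_bool(torch.randn(1, 2, 16, 256), keep_window_size=64)
print(mask.dtype, int(mask[0, 0, 0].sum()))  # torch.bool 64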
benchmarks/backward_performance.py (4 additions, 4 deletions)
@@ -109,11 +109,11 @@ def prepare_dynamic_mask(
attn_bias, keep_window_size, dim=-1, largest=True, sorted=False
)
valid_topk = topk_values != min_dtype
- attn_mask = torch.zeros_like(attn_bias, dtype=dtype, device=attn_bias.device)
- attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk.to(dtype))
- attn_bias = attn_bias.masked_fill(attn_mask == 0.0, min_dtype)
+ attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
+ attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk)
+ attn_bias = attn_bias.masked_fill(attn_mask == False, min_dtype)
else:
- attn_mask = torch.ones_like(attn_bias, dtype=dtype, device=attn_bias.device)
+ attn_mask = torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
return attn_bias, attn_mask


benchmarks/forward_equivalence.py (4 additions, 4 deletions)
@@ -87,11 +87,11 @@ def prepare_dynamic_mask(
attn_bias, keep_window_size, dim=-1, largest=True, sorted=False
)
valid_topk = topk_values != min_dtype
- attn_mask = torch.zeros_like(attn_bias, dtype=dtype, device=attn_bias.device)
- attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk.to(dtype))
- attn_bias = attn_bias.masked_fill(attn_mask == 0.0, min_dtype)
+ attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
+ attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk)
+ attn_bias = attn_bias.masked_fill(attn_mask == False, min_dtype)
else:
- attn_mask = torch.ones_like(attn_bias, dtype=dtype, device=attn_bias.device)
+ attn_mask = torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
return attn_bias, attn_mask


benchmarks/forward_performance.py (4 additions, 4 deletions)
@@ -109,11 +109,11 @@ def prepare_dynamic_mask(
attn_bias, keep_window_size, dim=-1, largest=True, sorted=False
)
valid_topk = topk_values != min_dtype
- attn_mask = torch.zeros_like(attn_bias, dtype=dtype, device=attn_bias.device)
- attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk.to(dtype))
- attn_bias = attn_bias.masked_fill(attn_mask == 0.0, min_dtype)
+ attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
+ attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk)
+ attn_bias = attn_bias.masked_fill(attn_mask == False, min_dtype)
else:
- attn_mask = torch.ones_like(attn_bias, dtype=dtype, device=attn_bias.device)
+ attn_mask = torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
return attn_bias, attn_mask


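All four benchmark scripts receive the identical edit, so the equivalence and performance runs exercise the boolean mask end to end. A quick illustrative check (random data, made-up shapes) that the boolean mask keeps exactly the positions the old floating-point mask kept:

import torch

attn_bias = torch.randn(2, 4, 32, 512)
keep = 128
topk_values, topk_indices = torch.topk(attn_bias, keep, dim=-1, largest=True, sorted=False)
valid_topk = topk_values != torch.finfo(attn_bias.dtype).min

old_mask = torch.zeros_like(attn_bias).scatter(-1, topk_indices, valid_topk.to(attn_bias.dtype))
new_mask = torch.zeros_like(attn_bias, dtype=torch.bool).scatter(-1, topk_indices, valid_topk)
assert torch.equal(new_mask, old_mask != 0.0)  # same keep/drop pattern, no dtype round-trip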
csrc/flash_dmattn/flash_api.cpp (6 additions, 6 deletions)
@@ -361,7 +361,7 @@ mha_fwd(
TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type");
TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
- TORCH_CHECK(mask.dtype() == q_dtype, "mask must have the same dtype as inputs");
+ TORCH_CHECK(mask.dtype() == torch::kBool, "mask must have dtype bool");
TORCH_CHECK(bias.dtype() == q_dtype, "bias must have the same dtype as inputs");

CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(mask); CHECK_DEVICE(bias);
@@ -512,7 +512,7 @@ mha_varlen_fwd(
TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type");
TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
- TORCH_CHECK(mask.dtype() == q_dtype, "mask must have the same dtype as inputs");
+ TORCH_CHECK(mask.dtype() == torch::kBool, "mask must have dtype bool");
TORCH_CHECK(bias.dtype() == q_dtype, "bias must have the same dtype as inputs");
TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");
@@ -749,7 +749,7 @@ mha_bwd(
TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type");
TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
- TORCH_CHECK(mask.dtype() == q_dtype, "query and mask must have the same dtype");
+ TORCH_CHECK(mask.dtype() == torch::kBool, "mask must have dtype bool");
TORCH_CHECK(bias.dtype() == q_dtype, "query and bias must have the same dtype");
TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
@@ -951,7 +951,7 @@ mha_varlen_bwd(
TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type");
TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
- TORCH_CHECK(mask.dtype() == q_dtype, "query and mask must have the same dtype");
+ TORCH_CHECK(mask.dtype() == torch::kBool, "mask must have dtype bool");
TORCH_CHECK(bias.dtype() == q_dtype, "query and bias must have the same dtype");
TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
@@ -1136,7 +1136,7 @@ mha_varlen_bwd(
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "FlashDynamicMaskAttention";
m.def("fwd", &FLASH_NAMESPACE::mha_fwd, "Forward pass");
m.def("varlen_fwd", &FLASH_NAMESPACE::mha_varlen_fwd, "Forward pass with variable length");
// m.def("varlen_fwd", &FLASH_NAMESPACE::mha_varlen_fwd, "Forward pass with variable length");
m.def("bwd", &FLASH_NAMESPACE::mha_bwd, "Backward pass");
m.def("varlen_bwd", &FLASH_NAMESPACE::mha_varlen_bwd, "Backward pass with variable length");
// m.def("varlen_bwd", &FLASH_NAMESPACE::mha_varlen_bwd, "Backward pass with variable length");
}
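With these checks, the host API's dtype contract becomes: q/k/v in fp16 or bf16, bias in the same dtype as q, and the mask strictly bool; the varlen_fwd and varlen_bwd bindings are commented out, leaving fwd and bwd exposed. A Python-side mirror of that contract is sketched below (illustrative only; it does not call the extension, and the helper name and tensor shapes are made up):

import torch

def check_flash_dmattn_dtypes(q, k, v, mask, bias):
    assert q.dtype in (torch.float16, torch.bfloat16), "only fp16 and bf16 are supported"
    assert k.dtype == q.dtype and v.dtype == q.dtype, "q, k, v must share a dtype"
    assert mask.dtype == torch.bool, "mask must now be bool (it previously matched q.dtype)"
    assert bias.dtype == q.dtype, "bias still matches the input dtype"

q = torch.randn(1, 128, 8, 64, dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)
mask = torch.ones(1, 8, 128, 128, dtype=torch.bool)
bias = torch.zeros(1, 8, 128, 128, dtype=torch.float16)
check_flash_dmattn_dtypes(q, k, v, mask, bias)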
csrc/flash_dmattn/src/flash_bwd_kernel.h (9 additions, 7 deletions)
@@ -144,7 +144,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
make_stride(params.v_row_stride, _1{})
);
Tensor gMask = make_tensor(
- make_gmem_ptr(reinterpret_cast<Element *>(params.mask_ptr) + row_offset_mask),
+ make_gmem_ptr(reinterpret_cast<const bool *>(params.mask_ptr) + row_offset_mask),
Shape<Int<kBlockM>, Int<kBlockN>>{},
make_stride(params.mask_row_stride, _1{})
);
@@ -552,14 +552,15 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
// cute::copy(gmem_tiled_copy_QKV, tKgK, tKrK);
// // if (cute::thread(1, 0)) { print(tKrK); }

- FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
+ FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true, /*Bool_to_Element=*/true, Element>(
gmem_tiled_copy_MaskBias,
tMaskgMask, tMasksMask,
tMaskcMask,
binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN
);
- cute::cp_async_fence();
- FLASH_NAMESPACE::cp_async_wait<0>();
+ // cute::cp_async_fence();
+ // FLASH_NAMESPACE::cp_async_wait<0>();
+ __syncthreads();

// Do OR-reduce on the mask to see if any active threads
Tensor tSsMask_copy_view = smem_thr_copy_PdS.retile_S(tSsMask);
@@ -807,14 +808,15 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
if (m_block > m_block_min) {
// Advance gMask
tMaskgMask.data() = tMaskgMask.data() + (-int(kBlockM * params.mask_row_stride));
- FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
+ FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true, /*Bool_to_Element=*/true, Element>(
gmem_tiled_copy_MaskBias,
tMaskgMask, tMasksMask,
tMaskcMask,
binfo.actual_seqlen_q - (m_block - 1) * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN
);
- FLASH_NAMESPACE::cp_async_fence();
- FLASH_NAMESPACE::cp_async_wait<0>();
+ // FLASH_NAMESPACE::cp_async_fence();
+ // FLASH_NAMESPACE::cp_async_wait<0>();
+ __syncthreads();

// Do OR-reduce on the mask to see if any active threads for next iteration
any_active_local_next = false;
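In the backward kernel the mask tile is now read from global memory as const bool and widened to the compute element type while it is staged into shared memory (the new Bool_to_Element path in copy_MN), so the cp.async fence/wait pair is commented out in favor of a plain __syncthreads(). The OR-reduce that follows checks whether any position in the (kBlockM x kBlockN) tile is active; a rough Python model of that check (tile sizes are illustrative, not the kernel's code):

import torch

def block_is_active(mask_tile):
    # mask_tile: (kBlockM, kBlockN) bool tile, as staged into shared memory
    return bool(mask_tile.any())  # mirrors the kernel's OR-reduce over the tile

tile = torch.zeros(64, 128, dtype=torch.bool)
print(block_is_active(tile))   # False -> no active positions in this tile
tile[3, 17] = True
print(block_is_active(tile))   # True  -> the tile contains active positions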
csrc/flash_dmattn/src/flash_fwd_kernel.h (22 additions, 19 deletions)
@@ -169,7 +169,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
make_coord(_, 0)
); // (kBlockN, kHeadDim, nblocksN)
Tensor mMask = make_tensor(
- make_gmem_ptr(reinterpret_cast<Element*>(params.mask_ptr) + binfo.mask_offset(params.mask_batch_stride, params.mask_row_stride, bidb)),
+ make_gmem_ptr(reinterpret_cast<const bool*>(params.mask_ptr) + binfo.mask_offset(params.mask_batch_stride, params.mask_row_stride, bidb)),
make_shape(params.h_k, binfo.actual_seqlen_q, binfo.actual_seqlen_k),
make_stride(params.mask_head_stride, params.mask_row_stride, _1{})
);
@@ -344,15 +344,15 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
}
// Reverse iteration over N blocks
int n_block = n_block_max - 1;
- FLASH_NAMESPACE::copy_MN<Is_even_MN>(

+ FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true, /*Bool_to_Element=*/true, Element>(
gmem_tiled_copy_MaskBias,
tMaskgMask(_, _, _, n_block), tMasksMask,
tMaskcMask,
binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN
);
- cute::cp_async_fence();
- FLASH_NAMESPACE::cp_async_wait<0>();
+ // cute::cp_async_fence();
+ // FLASH_NAMESPACE::cp_async_wait<0>();
__syncthreads();

// Do OR-reduce on the mask to see if any active threads
@@ -470,14 +470,14 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
}

if (n_block > n_block_min) {
- FLASH_NAMESPACE::copy_MN</*Is_even_MN=*/true>(
+ FLASH_NAMESPACE::copy_MN</*Is_even_MN=*/true, /*Clear_OOB_MN=*/true, /*Bool_to_Element=*/true, Element>(
gmem_tiled_copy_MaskBias,
tMaskgMask(_, _, _, n_block - 1), tMasksMask,
tMaskcMask,
binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN
);
- cute::cp_async_fence();
- FLASH_NAMESPACE::cp_async_wait<0>();
+ // cute::cp_async_fence();
+ // FLASH_NAMESPACE::cp_async_wait<0>();
__syncthreads();

// Do OR-reduce on the mask to see if any active threads for next iteration
@@ -593,14 +593,14 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
}

if (n_block > n_block_min) {
- FLASH_NAMESPACE::copy_MN</*Is_even_MN=*/true>(
+ FLASH_NAMESPACE::copy_MN</*Is_even_MN=*/true, /*Clear_OOB_MN=*/true, /*Bool_to_Element=*/true, Element>(
gmem_tiled_copy_MaskBias,
tMaskgMask(_, _, _, n_block - 1), tMasksMask,
tMaskcMask,
binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN
);
- cute::cp_async_fence();
- FLASH_NAMESPACE::cp_async_wait<0>();
+ // cute::cp_async_fence();
+ // FLASH_NAMESPACE::cp_async_wait<0>();
__syncthreads();

// Do OR-reduce on the mask to see if any active threads for next iteration
@@ -873,7 +873,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
make_stride(params.v_row_stride, _1{})
);
Tensor gMask = make_tensor(
- make_gmem_ptr(reinterpret_cast<Element *>(params.mask_ptr) + col_offset_mask),
+ make_gmem_ptr(reinterpret_cast<const bool *>(params.mask_ptr) + col_offset_mask),
Shape<Int<kBlockM>, Int<kBlockN>>{},
make_stride(params.mask_row_stride, _1{})
);
@@ -999,14 +999,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons

int n_block = n_block_max - 1;

- FLASH_NAMESPACE::copy_MN<Is_even_MN>(
+ FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true, /*Bool_to_Element=*/true, Element>(
gmem_tiled_copy_MaskBias,
tMaskgMask, tMasksMask,
tMaskcMask,
binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN
);
- cute::cp_async_fence();
- FLASH_NAMESPACE::cp_async_wait<0>();
+ // cute::cp_async_fence();
+ // FLASH_NAMESPACE::cp_async_wait<0>();
__syncthreads();

// Do OR-reduce on the mask to see if any active threads for next iteration
@@ -1146,14 +1146,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
tMaskgMask.data() = tMaskgMask.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.mask_batch_stride + (block_table_offset_next - block_table_offset_cur);
tBiasgBias.data() = tBiasgBias.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.bias_batch_stride + (block_table_offset_next - block_table_offset_cur);
}
- FLASH_NAMESPACE::copy_MN</*Is_even_MN=*/true>(
+ FLASH_NAMESPACE::copy_MN</*Is_even_MN=*/true, /*Clear_OOB_MN=*/true, /*Bool_to_Element=*/true, Element>(
gmem_tiled_copy_MaskBias,
tMaskgMask, tMasksMask,
tMaskcMask,
binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN
);
- cute::cp_async_fence();
- FLASH_NAMESPACE::cp_async_wait<0>();
+ // cute::cp_async_fence();
+ // FLASH_NAMESPACE::cp_async_wait<0>();
__syncthreads();

// Do OR-reduce on the mask to see if any active threads for next iteration
@@ -1287,12 +1287,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
tMaskgMask.data() = tMaskgMask.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.mask_batch_stride + (block_table_offset_next - block_table_offset_cur);
tBiasgBias.data() = tBiasgBias.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.bias_batch_stride + (block_table_offset_next - block_table_offset_cur);
}
- FLASH_NAMESPACE::copy_MN</*Is_even_MN=*/true>(
+ FLASH_NAMESPACE::copy_MN</*Is_even_MN=*/true, /*Clear_OOB_MN=*/true, /*Bool_to_Element=*/true, Element>(
gmem_tiled_copy_MaskBias,
tMaskgMask, tMasksMask,
tMaskcMask,
binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN
);
+ // cute::cp_async_fence();
+ // FLASH_NAMESPACE::cp_async_wait<0>();
+ __syncthreads();

// Do OR-reduce on the mask to see if any active threads for next iteration
any_active_local_next = false;
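The forward kernel, including the split-KV path, gets the same treatment: the mask pointer is reinterpreted as const bool*, copy_MN is instantiated with Bool_to_Element, and the cp.async fence/wait pairs around the mask copy are commented out in favor of __syncthreads(). For orientation, here is a generic sketch of how a boolean mask and an additive bias typically enter attention scoring; this is not the kernel's exact tiled pipeline, and the shapes are made up:

import torch

def masked_scores(q, k, mask, bias):
    # q: (seqlen_q, d), k: (seqlen_k, d); mask: (seqlen_q, seqlen_k) bool; bias: float, same shape as mask
    scores = (q @ k.T) / (q.shape[-1] ** 0.5) + bias
    return scores.masked_fill(~mask, float("-inf"))

q, k = torch.randn(64, 32), torch.randn(128, 32)
mask = torch.rand(64, 128) > 0.1
bias = torch.randn(64, 128)
probs = torch.softmax(masked_scores(q, k, mask, bias), dim=-1)
print(probs.shape, bool((probs[~mask] == 0).all()))  # torch.Size([64, 128]) True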
csrc/flash_dmattn/src/utils.h (9 additions, 2 deletions)
@@ -521,7 +521,7 @@ __forceinline__ __device__ void copy(
////////////////////////////////////////////////////////////////////////////////////////////////////

template <
- bool Is_even_MN=true, bool Clear_OOB_MN=true,
+ bool Is_even_MN=true, bool Clear_OOB_MN=true, bool Bool_to_Element=false, typename To_type=void,
typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
typename Engine2, typename Layout2
>
@@ -543,7 +543,14 @@ __forceinline__ __device__ void copy_MN(
#pragma unroll
for (int n = 0; n < size<2>(S); ++n) {
if (Is_even_MN || get<1>(identity_MN(0, m, n)) < max_N) {
- cute::copy(tiled_copy, S(_, m, n), D(_, m, n));
+ if constexpr (Bool_to_Element) {
+ #pragma unroll
+ for (int i = 0; i < size<0>(S); ++i) {
+ D(i, m, n) = static_cast<bool>(S(i, m, n)) ? To_type(1) : To_type(0);
+ }
+ } else {
+ cute::copy(tiled_copy, S(_, m, n), D(_, m, n));
+ }
} else if (Clear_OOB_MN) {
cute::clear(D(_, m, n));
}
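copy_MN gains two template parameters, Bool_to_Element and To_type: when enabled, in-bounds elements are converted one by one from bool to To_type(1)/To_type(0) instead of being copied bit-for-bit, while out-of-bounds tiles are still cleared when Clear_OOB_MN is set. A PyTorch model of that semantics is sketched below (not the CUDA code; the tile shape and dtype are illustrative, and the bounds handling is simplified to a single 2-D slice):

import torch

def copy_mn_bool_to_element(src_bool, max_m, max_n, dtype=torch.float16):
    # Clear_OOB_MN=True: start from zeros so out-of-bounds rows/columns stay cleared.
    dst = torch.zeros(src_bool.shape, dtype=dtype)
    m = min(max_m, src_bool.shape[0])
    n = min(max_n, src_bool.shape[1])
    # Bool_to_Element=True: widen bool -> To_type(1) / To_type(0) for in-bounds elements.
    dst[:m, :n] = src_bool[:m, :n].to(dtype)
    return dst

tile = torch.rand(4, 8) > 0.5
out = copy_mn_bool_to_element(tile, max_m=3, max_n=8)
print(out.dtype, out[3].abs().sum().item())  # torch.float16 0.0 (row 3 is out of bounds)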