Merged
Changes from 8 commits
8 changes: 4 additions & 4 deletions benchmarks/backward_equivalence.py
@@ -87,11 +87,11 @@ def prepare_dynamic_mask(
      attn_bias, keep_window_size, dim=-1, largest=True, sorted=False
  )
  valid_topk = topk_values != min_dtype
- attn_mask = torch.zeros_like(attn_bias, dtype=dtype, device=attn_bias.device)
- attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk.to(dtype))
- attn_bias = attn_bias.masked_fill(attn_mask == 0.0, min_dtype)
+ attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
+ attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk)
+ attn_bias = attn_bias.masked_fill(attn_mask == False, min_dtype)
  else:
- attn_mask = torch.ones_like(attn_bias, dtype=dtype, device=attn_bias.device)
+ attn_mask = torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
  return attn_bias, attn_mask
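Note: reconstructed from the hunk above, the helper now builds a boolean top-k mask and keeps the floating dtype only for the bias. A minimal sketch of the updated function, assuming the signature and the outer condition (only the lines shown in the diff are taken from the source):

import torch

def prepare_dynamic_mask(attn_bias, keep_window_size, min_dtype):
    # Sketch only: signature and the outer condition are assumptions, not from the diff.
    if keep_window_size < attn_bias.shape[-1]:
        topk_values, topk_indices = torch.topk(
            attn_bias, keep_window_size, dim=-1, largest=True, sorted=False
        )
        valid_topk = topk_values != min_dtype
        # The mask is now torch.bool instead of the compute dtype.
        attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
        attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk)
        attn_bias = attn_bias.masked_fill(attn_mask == False, min_dtype)  # == False is equivalent to ~attn_mask
    else:
        attn_mask = torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
    return attn_bias, attn_mask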


8 changes: 4 additions & 4 deletions benchmarks/backward_performance.py
@@ -109,11 +109,11 @@ def prepare_dynamic_mask(
      attn_bias, keep_window_size, dim=-1, largest=True, sorted=False
  )
  valid_topk = topk_values != min_dtype
- attn_mask = torch.zeros_like(attn_bias, dtype=dtype, device=attn_bias.device)
- attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk.to(dtype))
- attn_bias = attn_bias.masked_fill(attn_mask == 0.0, min_dtype)
+ attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
+ attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk)
+ attn_bias = attn_bias.masked_fill(attn_mask == False, min_dtype)
  else:
- attn_mask = torch.ones_like(attn_bias, dtype=dtype, device=attn_bias.device)
+ attn_mask = torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
  return attn_bias, attn_mask


8 changes: 4 additions & 4 deletions benchmarks/forward_equivalence.py
@@ -87,11 +87,11 @@ def prepare_dynamic_mask(
      attn_bias, keep_window_size, dim=-1, largest=True, sorted=False
  )
  valid_topk = topk_values != min_dtype
- attn_mask = torch.zeros_like(attn_bias, dtype=dtype, device=attn_bias.device)
- attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk.to(dtype))
- attn_bias = attn_bias.masked_fill(attn_mask == 0.0, min_dtype)
+ attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
+ attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk)
+ attn_bias = attn_bias.masked_fill(attn_mask == False, min_dtype)
  else:
- attn_mask = torch.ones_like(attn_bias, dtype=dtype, device=attn_bias.device)
+ attn_mask = torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
  return attn_bias, attn_mask


8 changes: 4 additions & 4 deletions benchmarks/forward_performance.py
@@ -109,11 +109,11 @@ def prepare_dynamic_mask(
      attn_bias, keep_window_size, dim=-1, largest=True, sorted=False
  )
  valid_topk = topk_values != min_dtype
- attn_mask = torch.zeros_like(attn_bias, dtype=dtype, device=attn_bias.device)
- attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk.to(dtype))
- attn_bias = attn_bias.masked_fill(attn_mask == 0.0, min_dtype)
+ attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
+ attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk)
+ attn_bias = attn_bias.masked_fill(attn_mask == False, min_dtype)
  else:
- attn_mask = torch.ones_like(attn_bias, dtype=dtype, device=attn_bias.device)
+ attn_mask = torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
  return attn_bias, attn_mask
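The same four-line change appears in all four benchmark helpers. A quick standalone sanity check (not part of the benchmarks) that the boolean mask preserves the old float-mask semantics:

import torch

torch.manual_seed(0)
dtype = torch.float32          # the benchmarks use fp16/bf16 on CUDA; float32 keeps this runnable on CPU
min_dtype = torch.finfo(dtype).min
attn_bias = torch.randn(2, 4, 8, 128, dtype=dtype)
keep_window_size = 32

topk_values, topk_indices = torch.topk(
    attn_bias, keep_window_size, dim=-1, largest=True, sorted=False
)
valid_topk = topk_values != min_dtype

# Old path: mask held 1.0 / 0.0 in the compute dtype.
old_mask = torch.zeros_like(attn_bias, dtype=dtype).scatter(-1, topk_indices, valid_topk.to(dtype))
old_bias = attn_bias.masked_fill(old_mask == 0.0, min_dtype)

# New path: mask holds True / False.
new_mask = torch.zeros_like(attn_bias, dtype=torch.bool).scatter(-1, topk_indices, valid_topk)
new_bias = attn_bias.masked_fill(new_mask == False, min_dtype)

assert torch.equal(old_bias, new_bias)
assert torch.equal(old_mask.bool(), new_mask)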


12 changes: 6 additions & 6 deletions csrc/flash_dmattn/flash_api.cpp
@@ -361,7 +361,7 @@ mha_fwd(
  TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type");
  TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
  TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
- TORCH_CHECK(mask.dtype() == q_dtype, "mask must have the same dtype as inputs");
+ TORCH_CHECK(mask.dtype() == torch::kBool, "mask must have dtype bool");
  TORCH_CHECK(bias.dtype() == q_dtype, "bias must have the same dtype as inputs");

  CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(mask); CHECK_DEVICE(bias);
@@ -512,7 +512,7 @@ mha_varlen_fwd(
  TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type");
  TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
  TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
- TORCH_CHECK(mask.dtype() == q_dtype, "mask must have the same dtype as inputs");
+ TORCH_CHECK(mask.dtype() == torch::kBool, "mask must have dtype bool");
  TORCH_CHECK(bias.dtype() == q_dtype, "bias must have the same dtype as inputs");
  TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
  TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");
@@ -749,7 +749,7 @@ mha_bwd(
  TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type");
  TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
  TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
- TORCH_CHECK(mask.dtype() == q_dtype, "query and mask must have the same dtype");
+ TORCH_CHECK(mask.dtype() == torch::kBool, "mask must have dtype bool");
  TORCH_CHECK(bias.dtype() == q_dtype, "query and bias must have the same dtype");
  TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
  TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
@@ -951,7 +951,7 @@ mha_varlen_bwd(
  TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type");
  TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
  TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
- TORCH_CHECK(mask.dtype() == q_dtype, "query and mask must have the same dtype");
+ TORCH_CHECK(mask.dtype() == torch::kBool, "mask must have dtype bool");
  TORCH_CHECK(bias.dtype() == q_dtype, "query and bias must have the same dtype");
  TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
  TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
@@ -1136,7 +1136,7 @@ mha_varlen_bwd(
  PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.doc() = "FlashDynamicMaskAttention";
  m.def("fwd", &FLASH_NAMESPACE::mha_fwd, "Forward pass");
- m.def("varlen_fwd", &FLASH_NAMESPACE::mha_varlen_fwd, "Forward pass with variable length");
+ // m.def("varlen_fwd", &FLASH_NAMESPACE::mha_varlen_fwd, "Forward pass with variable length");
  m.def("bwd", &FLASH_NAMESPACE::mha_bwd, "Backward pass");
- m.def("varlen_bwd", &FLASH_NAMESPACE::mha_varlen_bwd, "Backward pass with variable length");
+ // m.def("varlen_bwd", &FLASH_NAMESPACE::mha_varlen_bwd, "Backward pass with variable length");
}
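With the kBool checks above, callers must now pass the mask as torch.bool; an fp16/bf16 mask will trip the TORCH_CHECK. A hedged caller-side sketch (the extension import name and argument order are assumptions, not taken from this file):

import torch
# import flash_dmattn_cuda as ext   # assumed extension name; use the project's actual binding

# Suppose an existing pipeline produced a float mask with 1.0 = attend, 0.0 = masked.
float_mask = (torch.rand(1, 8, 256, 256) > 0.5).to(torch.float16)

bool_mask = float_mask.to(torch.bool)   # nonzero -> True, zero -> False
assert bool_mask.dtype == torch.bool
# out = ext.fwd(q, k, v, bool_mask, bias, ...)   # exact signature: see mha_fwd above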
16 changes: 9 additions & 7 deletions csrc/flash_dmattn/src/flash_bwd_kernel.h
@@ -144,7 +144,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
  make_stride(params.v_row_stride, _1{})
  );
  Tensor gMask = make_tensor(
- make_gmem_ptr(reinterpret_cast<Element *>(params.mask_ptr) + row_offset_mask),
+ make_gmem_ptr(reinterpret_cast<const bool *>(params.mask_ptr) + row_offset_mask),
  Shape<Int<kBlockM>, Int<kBlockN>>{},
  make_stride(params.mask_row_stride, _1{})
  );
@@ -552,14 +552,15 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
  // cute::copy(gmem_tiled_copy_QKV, tKgK, tKrK);
  // // if (cute::thread(1, 0)) { print(tKrK); }

- FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
+ FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true, /*Bool_to_Element=*/true, Element>(
  gmem_tiled_copy_MaskBias,
  tMaskgMask, tMasksMask,
  tMaskcMask,
  binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN
  );
- cute::cp_async_fence();
- FLASH_NAMESPACE::cp_async_wait<0>();
+ // cute::cp_async_fence();
+ // FLASH_NAMESPACE::cp_async_wait<0>();
+ __syncthreads();

  // Do OR-reduce on the mask to see if any active threads
  Tensor tSsMask_copy_view = smem_thr_copy_PdS.retile_S(tSsMask);
@@ -807,14 +807,15 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
  if (m_block > m_block_min) {
  // Advance gMask
  tMaskgMask.data() = tMaskgMask.data() + (-int(kBlockM * params.mask_row_stride));
- FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
+ FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true, /*Bool_to_Element=*/true, Element>(
  gmem_tiled_copy_MaskBias,
  tMaskgMask, tMasksMask,
  tMaskcMask,
  binfo.actual_seqlen_q - (m_block - 1) * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN
  );
- FLASH_NAMESPACE::cp_async_fence();
- FLASH_NAMESPACE::cp_async_wait<0>();
+ // FLASH_NAMESPACE::cp_async_fence();
+ // FLASH_NAMESPACE::cp_async_wait<0>();
+ __syncthreads();

  // Do OR-reduce on the mask to see if any active threads for next iteration
  any_active_local_next = false;
41 changes: 22 additions & 19 deletions csrc/flash_dmattn/src/flash_fwd_kernel.h
@@ -169,7 +169,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
  make_coord(_, 0)
  ); // (kBlockN, kHeadDim, nblocksN)
  Tensor mMask = make_tensor(
- make_gmem_ptr(reinterpret_cast<Element*>(params.mask_ptr) + binfo.mask_offset(params.mask_batch_stride, params.mask_row_stride, bidb)),
+ make_gmem_ptr(reinterpret_cast<const bool*>(params.mask_ptr) + binfo.mask_offset(params.mask_batch_stride, params.mask_row_stride, bidb)),
  make_shape(params.h_k, binfo.actual_seqlen_q, binfo.actual_seqlen_k),
  make_stride(params.mask_head_stride, params.mask_row_stride, _1{})
  );
@@ -344,15 +344,15 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
  }
  // Reverse iteration over N blocks
  int n_block = n_block_max - 1;
- FLASH_NAMESPACE::copy_MN<Is_even_MN>(
+
+ FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true, /*Bool_to_Element=*/true, Element>(
  gmem_tiled_copy_MaskBias,
  tMaskgMask(_, _, _, n_block), tMasksMask,
  tMaskcMask,
  binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN
  );
- cute::cp_async_fence();
- FLASH_NAMESPACE::cp_async_wait<0>();
+ // cute::cp_async_fence();
+ // FLASH_NAMESPACE::cp_async_wait<0>();
  __syncthreads();

  // Do OR-reduce on the mask to see if any active threads
@@ -470,14 +470,14 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
  }

  if (n_block > n_block_min) {
- FLASH_NAMESPACE::copy_MN</*Is_even_MN=*/true>(
+ FLASH_NAMESPACE::copy_MN</*Is_even_MN=*/true, /*Clear_OOB_MN=*/true, /*Bool_to_Element=*/true, Element>(
  gmem_tiled_copy_MaskBias,
  tMaskgMask(_, _, _, n_block - 1), tMasksMask,
  tMaskcMask,
  binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN
  );
- cute::cp_async_fence();
- FLASH_NAMESPACE::cp_async_wait<0>();
+ // cute::cp_async_fence();
+ // FLASH_NAMESPACE::cp_async_wait<0>();
  __syncthreads();

  // Do OR-reduce on the mask to see if any active threads for next iteration
@@ -593,14 +593,14 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
  }

  if (n_block > n_block_min) {
- FLASH_NAMESPACE::copy_MN</*Is_even_MN=*/true>(
+ FLASH_NAMESPACE::copy_MN</*Is_even_MN=*/true, /*Clear_OOB_MN=*/true, /*Bool_to_Element=*/true, Element>(
  gmem_tiled_copy_MaskBias,
  tMaskgMask(_, _, _, n_block - 1), tMasksMask,
  tMaskcMask,
  binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN
  );
- cute::cp_async_fence();
- FLASH_NAMESPACE::cp_async_wait<0>();
+ // cute::cp_async_fence();
+ // FLASH_NAMESPACE::cp_async_wait<0>();
  __syncthreads();

  // Do OR-reduce on the mask to see if any active threads for next iteration
@@ -873,7 +873,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
  make_stride(params.v_row_stride, _1{})
  );
  Tensor gMask = make_tensor(
- make_gmem_ptr(reinterpret_cast<Element *>(params.mask_ptr) + col_offset_mask),
+ make_gmem_ptr(reinterpret_cast<const bool *>(params.mask_ptr) + col_offset_mask),
  Shape<Int<kBlockM>, Int<kBlockN>>{},
  make_stride(params.mask_row_stride, _1{})
  );
@@ -999,14 +999,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons

  int n_block = n_block_max - 1;

- FLASH_NAMESPACE::copy_MN<Is_even_MN>(
+ FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true, /*Bool_to_Element=*/true, Element>(
  gmem_tiled_copy_MaskBias,
  tMaskgMask, tMasksMask,
  tMaskcMask,
  binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN
  );
- cute::cp_async_fence();
- FLASH_NAMESPACE::cp_async_wait<0>();
+ // cute::cp_async_fence();
+ // FLASH_NAMESPACE::cp_async_wait<0>();
  __syncthreads();

  // Do OR-reduce on the mask to see if any active threads for next iteration
@@ -1146,14 +1146,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
  tMaskgMask.data() = tMaskgMask.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.mask_batch_stride + (block_table_offset_next - block_table_offset_cur);
  tBiasgBias.data() = tBiasgBias.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.bias_batch_stride + (block_table_offset_next - block_table_offset_cur);
  }
- FLASH_NAMESPACE::copy_MN</*Is_even_MN=*/true>(
+ FLASH_NAMESPACE::copy_MN</*Is_even_MN=*/true, /*Clear_OOB_MN=*/true, /*Bool_to_Element=*/true, Element>(
  gmem_tiled_copy_MaskBias,
  tMaskgMask, tMasksMask,
  tMaskcMask,
  binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN
  );
- cute::cp_async_fence();
- FLASH_NAMESPACE::cp_async_wait<0>();
+ // cute::cp_async_fence();
+ // FLASH_NAMESPACE::cp_async_wait<0>();
  __syncthreads();

  // Do OR-reduce on the mask to see if any active threads for next iteration
@@ -1287,12 +1287,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
  tMaskgMask.data() = tMaskgMask.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.mask_batch_stride + (block_table_offset_next - block_table_offset_cur);
  tBiasgBias.data() = tBiasgBias.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.bias_batch_stride + (block_table_offset_next - block_table_offset_cur);
  }
- FLASH_NAMESPACE::copy_MN</*Is_even_MN=*/true>(
+ FLASH_NAMESPACE::copy_MN</*Is_even_MN=*/true, /*Clear_OOB_MN=*/true, /*Bool_to_Element=*/true, Element>(
  gmem_tiled_copy_MaskBias,
  tMaskgMask, tMasksMask,
  tMaskcMask,
  binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN
  );
+ // cute::cp_async_fence();
+ // FLASH_NAMESPACE::cp_async_wait<0>();
+ __syncthreads();

  // Do OR-reduce on the mask to see if any active threads for next iteration
  any_active_local_next = false;
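One practical effect of these kernels reading the mask as bool: the global-memory mask tensor drops from 2 bytes per element (fp16/bf16) to 1 byte (torch.bool), roughly halving mask storage and the mask's memory traffic. A quick illustrative calculation (shapes are made up, not from this PR):

import torch

batch, heads, seqlen = 1, 8, 4096
numel = batch * heads * seqlen * seqlen

fp16_mask_bytes = numel * torch.finfo(torch.float16).bits // 8   # 2 bytes per element
bool_mask_bytes = numel * 1                                       # torch.bool stores 1 byte per element
print(fp16_mask_bytes / 2**20, "MiB vs", bool_mask_bytes / 2**20, "MiB")   # 256.0 MiB vs 128.0 MiB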
13 changes: 11 additions & 2 deletions csrc/flash_dmattn/src/utils.h
@@ -521,7 +521,7 @@ __forceinline__ __device__ void copy(
  ////////////////////////////////////////////////////////////////////////////////////////////////////

  template <
- bool Is_even_MN=true, bool Clear_OOB_MN=true,
+ bool Is_even_MN=true, bool Clear_OOB_MN=true, bool Bool_to_Element=false, typename To_type=void,
  typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
  typename Engine2, typename Layout2
  >
@@ -543,7 +543,16 @@ __forceinline__ __device__ void copy_MN(
  #pragma unroll
  for (int n = 0; n < size<2>(S); ++n) {
  if (Is_even_MN || get<1>(identity_MN(0, m, n)) < max_N) {
- cute::copy(tiled_copy, S(_, m, n), D(_, m, n));
+ if constexpr (Bool_to_Element) {
+     #pragma unroll
+     for (int i = 0; i < size<0>(S); ++i) {
+         D(i, m, n) = static_cast<bool>(S(i, m, n))
+             ? static_cast<To_type>(1.0f)
+             : static_cast<To_type>(0.0f);
+     }
+ } else {
+     cute::copy(tiled_copy, S(_, m, n), D(_, m, n));
+ }
  } else if (Clear_OOB_MN) {
  cute::clear(D(_, m, n));
  }
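In PyTorch terms, the Bool_to_Element branch above is the moral equivalent of mask.to(dtype): each True becomes 1.0 and each False becomes 0.0 in the compute element type. A tiny standalone sketch of that equivalence (illustrative only, not kernel code):

import torch

mask_bool = torch.tensor([[True, False, True]])

# What the Bool_to_Element branch produces per element:
# True -> Element(1.0), False -> Element(0.0)
mask_f16 = torch.where(
    mask_bool,
    torch.tensor(1.0, dtype=torch.float16),
    torch.tensor(0.0, dtype=torch.float16),
)
assert torch.equal(mask_f16, mask_bool.to(torch.float16))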