22 changes: 14 additions & 8 deletions csrc/flash_dmattn/src/utils.h
@@ -521,15 +521,16 @@ __forceinline__ __device__ void copy(
////////////////////////////////////////////////////////////////////////////////////////////////////

template <
- bool Is_even_MN=true, bool Clear_OOB_MN=true, bool Bool_to_Element=false, typename To_type=void,
- typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
- typename Engine2, typename Layout2
+ bool Is_even_MN=true, bool Clear_OOB_MN=false, bool Bool_to_Element=false, typename To_type=void,
+ // typename TiledCopy,

Copilot AI (Sep 16, 2025):

The commented-out typename TiledCopy parameter should be removed entirely rather than left as a comment, as it's no longer used in the function signature.

Suggested change:
- // typename TiledCopy,

+ typename Engine0, typename Layout0, typename Engine1, typename Layout1,
+ typename Engine2, typename Layout2, typename Engine3, typename Layout3
>
__forceinline__ __device__ void copy_MN(
- TiledCopy tiled_copy,
+ // TiledCopy tiled_copy,

Copilot AI (Sep 16, 2025):

The commented-out TiledCopy tiled_copy parameter should be removed entirely rather than left as a comment, as it's no longer used in the function.

Tensor<Engine0, Layout0> const &S, Tensor<Engine1, Layout1> &D,
- Tensor<Engine2, Layout2> const &identity_MN,
- const int max_M=0, const int max_N=0
+ Tensor<Engine2, Layout2> const &identity_MN, Tensor<Engine3, Layout3> const &predicate_N,
+ const int max_M=0
) {
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); // (MMA, MMA_M, MMA_N)
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); // (MMA, MMA_M, MMA_N)
@@ -542,14 +543,19 @@ __forceinline__ __device__ void copy_MN(
if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_M) {
#pragma unroll
for (int n = 0; n < size<2>(S); ++n) {
- if (Is_even_MN || get<1>(identity_MN(0, m, n)) < max_N) {
+ if (Is_even_MN || predicate_N(n)) {
if constexpr (Bool_to_Element) {
#pragma unroll
for (int i = 0; i < size<0>(S); ++i) {
D(i, m, n) = static_cast<bool>(S(i, m, n)) ? To_type(1) : To_type(0);
}
} else {
- cute::copy(tiled_copy, S(_, m, n), D(_, m, n));
+ // Using vectorized load will cause out-of-bounds access when !Is_even_MN && !predicate_N(n)
+ // cute::copy(tiled_copy, S(_, m, n), D(_, m, n));
+ #pragma unroll
+ for (int i = 0; i < size<0>(S); ++i) {
+ D(i, m, n) = S(i, m, n);
+ }
}
} else if (Clear_OOB_MN) {
cute::clear(D(_, m, n));
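
To make the new guard logic easier to follow, here is a minimal Python reference model of what this hunk makes copy_MN do per tile element. It is a sketch under stated assumptions: the real kernel works on (MMA, MMA_M, MMA_N) fragments, derives the row bound through identity_MN, and the hunk is truncated before the row-out-of-bounds branch, all of which are collapsed here into a plain 2D tile.

```python
import torch

def copy_mn_reference(S, D, predicate_N, max_M,
                      is_even_MN=False, clear_OOB_MN=False, bool_to_element=False):
    # Simplified 2D model: rows are bounded by max_M, columns by the precomputed
    # predicate_N, and out-of-bounds destination entries are cleared only when
    # clear_OOB_MN is set. Rows at or beyond max_M are left untouched here
    # because the diff above is truncated before that branch.
    rows, cols = S.shape
    for m in range(rows):
        if is_even_MN or m < max_M:
            for n in range(cols):
                if is_even_MN or predicate_N[n]:
                    if bool_to_element:
                        # Bool_to_Element path: promote a boolean mask to 0/1 values.
                        D[m, n] = 1 if bool(S[m, n]) else 0
                    else:
                        # Scalar copy per element, mirroring the new loop above.
                        D[m, n] = S[m, n]
                elif clear_OOB_MN:
                    D[m, n] = 0
    return D

# Tiny usage example: an 8x8 tile whose last 3 columns and last 2 rows are out of bounds.
S = torch.arange(64, dtype=torch.float32).reshape(8, 8)
D = torch.full((8, 8), -1.0)
pred = torch.tensor([True] * 5 + [False] * 3)
copy_mn_reference(S, D, pred, max_M=6, clear_OOB_MN=True)
```

The switch from the vectorized cute::copy to a scalar loop follows from the comment the hunk itself adds: the vectorized load can go out of bounds when !Is_even_MN && !predicate_N(n), so the element-wise copy trades a little throughput for correctness at the ragged edge.
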
31 changes: 10 additions & 21 deletions flash_dmattn/flash_dmattn_interface.py
@@ -14,7 +14,7 @@ def maybe_contiguous(x: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
def _sanitize_tensors(*tensors: Optional[torch.Tensor], nan: float = 0.0, posinf: float = 1e6, neginf: float = -1e6) -> None:
for t in tensors:
if t is not None and isinstance(t, torch.Tensor):
- torch.nan_to_num(t, nan=nan, posinf=posinf, neginf=neginf, out=t)
+ torch.nan_to_num_(t, nan=nan, posinf=posinf, neginf=neginf)



Copilot AI (Sep 17, 2025), commenting on lines 14 to 19:

The call changed from torch.nan_to_num(..., out=t) to the in-place torch.nan_to_num_(), but the function still exposes nan, posinf, and neginf defaults, which suggests these were meant to be configurable. However, the call sites at lines 98 and 173 now pass specific dtype-based values, making the defaults in the signature potentially misleading.

def _get_block_size_n(device, head_dim, is_causal):
@@ -95,7 +95,7 @@ def _flash_dmattn_forward(
softcap,
return_softmax,
)
- _sanitize_tensors(out)
+ _sanitize_tensors(out, nan=0.0, posinf=torch.finfo(out.dtype).max, neginf=torch.finfo(out.dtype).min)
return out, softmax_lse, S_dmask


@@ -170,7 +170,7 @@ def _flash_dmattn_backward(
softcap,
deterministic,
)
- _sanitize_tensors(dq, dk, dv, dbias)
+ _sanitize_tensors(dq, dk, dv, dbias, nan=0.0, posinf=torch.finfo(dq.dtype).max, neginf=torch.finfo(dq.dtype).min)
return softmax_d
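
Taken together, the helper above and these two call sites amount to the pattern below. This is a minimal sketch using only what is visible in this diff; the 1e6 defaults and the finfo-based bounds come straight from the changed lines.

```python
import torch

def _sanitize_tensors(*tensors, nan=0.0, posinf=1e6, neginf=-1e6):
    # Scrub NaN/+Inf/-Inf in place; None entries are skipped.
    for t in tensors:
        if t is not None and isinstance(t, torch.Tensor):
            torch.nan_to_num_(t, nan=nan, posinf=posinf, neginf=neginf)

# Call-site pattern from the forward/backward wrappers: clamp to the
# representable range of the tensor's own dtype instead of the 1e6 defaults.
out = torch.tensor([float("nan"), float("inf"), -float("inf")], dtype=torch.float16)
_sanitize_tensors(out, nan=0.0,
                  posinf=torch.finfo(out.dtype).max,
                  neginf=torch.finfo(out.dtype).min)
print(out)  # -> 0, 65504, -65504 for float16
```

Since the call sites in this diff now pass finfo-based bounds explicitly, the 1e6/-1e6 defaults in the signature are effectively dead, which is the point the review comment above is making.
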


@@ -227,8 +227,6 @@ def forward(
return_softmax: Optional[bool],
is_grad_enabled: bool = True,
):
- # q, k, v are expected to be of shape (batch_size, seqlen, num_heads, head_size)
- seqlen_k = k.shape[1]
is_grad = is_grad_enabled and any(
x.requires_grad for x in [q, k, v]
)
@@ -249,14 +247,6 @@ def forward(
k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])

- if seqlen_k % 128 != 0:
- k = torch.nn.functional.pad(k, [0, 0, 0, 0, 0, 128 - seqlen_k % 128])
- v = torch.nn.functional.pad(v, [0, 0, 0, 0, 0, 128 - seqlen_k % 128])
- if mask is not None:
- mask = torch.nn.functional.pad(mask, [0, 128 - seqlen_k % 128], value=False)
- if bias is not None:
- bias = torch.nn.functional.pad(bias, [0, 128 - seqlen_k % 128], value=torch.finfo(bias.dtype).min)

out_padded, softmax_lse, S_dmask = _wrapped_flash_dmattn_forward(
q,
k,
@@ -271,7 +261,6 @@

if is_grad:
ctx.save_for_backward(q, k, v, mask, bias, out_padded, softmax_lse)
- ctx.seqlen_k = seqlen_k
ctx.softmax_scale = softmax_scale
ctx.is_causal = is_causal
ctx.softcap = softcap
@@ -288,7 +277,7 @@ def backward(
*args: Any,
):
q, k, v, mask, bias, out, softmax_lse = ctx.saved_tensors
- dq, dk, dv, dbias = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v), torch.empty_like(bias)
+ dq, dk, dv, dbias = torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v), torch.zeros_like(bias)

Copilot AI (Sep 16, 2025):

Using torch.zeros_like() initializes all tensors with zeros, which is unnecessary overhead since these gradient tensors will be fully written by the backward kernel. Consider using torch.empty_like() for better performance.

Suggested change:
- dq, dk, dv, dbias = torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v), torch.zeros_like(bias)
+ dq, dk, dv, dbias = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v), torch.empty_like(bias)


Copilot AI (Sep 17, 2025):

Using torch.zeros_like() instead of torch.empty_like() initializes the tensors with zeros, which adds unnecessary overhead since these tensors will be completely overwritten by the backward computation.

Suggested change:
- dq, dk, dv, dbias = torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v), torch.zeros_like(bias)
+ dq, dk, dv, dbias = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v), torch.empty_like(bias)

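The two review comments above disagree with this change on performance grounds. A quick way to measure the overhead they mention is the sketch below (assuming a CUDA device and shapes in the rough range this kernel targets); whether torch.empty_like is actually safe here depends on whether the backward kernel writes every element of dq/dk/dv/dbias, which this diff alone does not settle.

```python
import time
import torch

x = torch.randn(8, 4096, 16, 128, device="cuda")

def time_alloc(fn, iters=100):
    # Time repeated allocations of a gradient-sized buffer.
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn(x)
    torch.cuda.synchronize()
    return time.perf_counter() - start

print("zeros_like:", time_alloc(torch.zeros_like))
print("empty_like:", time_alloc(torch.empty_like))
```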

head_size_og = dout.size(3)
dout_padded = dout
@@ -318,11 +307,6 @@ def backward(
dk = dk[..., : dout.shape[-1]]
dv = dv[..., : dout.shape[-1]]

- if ctx.seqlen_k % 128 != 0:
- dk = dk[:, : ctx.seqlen_k, :, :]
- dv = dv[:, : ctx.seqlen_k, :, :]
- dbias = dbias[..., : ctx.seqlen_k]

return dq, dk, dv, None, dbias, None, None, None, None, None, None


@@ -339,11 +323,16 @@ def flash_dmattn_func(
return_attn_probs: Optional[bool] = None,
):
"""
- Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
+ Supports multi-query attention and grouped-query attention (MQA/GQA) by passing in KV with fewer heads

Copilot AI (Sep 16, 2025):

The word 'Similarity' on line 331 should be 'Similarly' (missing 'l').

than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.

+ Similarity, also supports attn_mask and attn_bias with head dimension of 1, nheads_k or nheads for MQA/GQA.
+ For example, if Q has 6 heads, K, V have 2 heads, then attn_mask and attn_bias can have head dimension
+ of 1, 2 or 6. If it is 1, all heads use the same mask/bias; if it is 2, head 0, 1, 2 of Q use head 0
+ of mask/bias, head 3, 4, 5 of Q use head 1 of mask/bias. If it is 6, each head uses its own mask/bias.

If is_causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1 1 1 1 0
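
As a usage illustration of the head-broadcasting rules described in this docstring, here is a hedged sketch. The keyword names attn_mask, attn_bias, and is_causal are taken from the docstring and signature above; the (batch, seqlen, nheads, headdim) layout for q/k/v, the (batch, nheads, seqlen_q, seqlen_k) layout for mask/bias, and the import path are assumptions, not something this diff confirms.

```python
import torch
from flash_dmattn.flash_dmattn_interface import flash_dmattn_func  # assumed import path

batch, seqlen_q, seqlen_k, headdim = 2, 2, 5, 64
nheads_q, nheads_kv = 6, 2  # GQA: Q heads 0-2 share KV head 0, Q heads 3-5 share KV head 1

# q/k/v layout assumed to be (batch, seqlen, nheads, headdim).
q = torch.randn(batch, seqlen_q, nheads_q, headdim, device="cuda", dtype=torch.bfloat16)
k = torch.randn(batch, seqlen_k, nheads_kv, headdim, device="cuda", dtype=torch.bfloat16)
v = torch.randn(batch, seqlen_k, nheads_kv, headdim, device="cuda", dtype=torch.bfloat16)

# Mask/bias head dimension may be 1, nheads_kv, or nheads_q; here one mask/bias per KV head.
# Layout assumed to be (batch, heads, seqlen_q, seqlen_k).
attn_mask = torch.ones(batch, nheads_kv, seqlen_q, seqlen_k, device="cuda", dtype=torch.bool)
attn_bias = torch.zeros(batch, nheads_kv, seqlen_q, seqlen_k, device="cuda", dtype=torch.bfloat16)

# With is_causal=True and seqlen_q=2, seqlen_k=5, the causal mask is aligned to the
# bottom-right corner, matching the matrix shown in the docstring above.
out = flash_dmattn_func(q, k, v, attn_mask=attn_mask, attn_bias=attn_bias, is_causal=True)
print(out.shape)  # expected: (2, 2, 6, 64)
```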