@@ -27,10 +27,11 @@


# `globals()` is not compatible with dynamo, hence we have to define them in global scope ourselves
_flash_fn = None
_flash_varlen_fn = None
_fdma_fn = None
_fdma_varlen_fn = None
_pad_fn = None
_unpad_fn = None
_create_mask_fn = None

# function that processes kwargs, generalized to handle any supported kwarg within the function
_process_flash_kwargs_fn = None
@@ -53,13 +54,12 @@ def _lazy_imports(implementation: Optional[str]):
"""
is_fdma = is_flash_dmattn_available()

pad_input, unpad_input = _pad_input, _unpad_input

if (implementation == "flash_dmattn" and is_fdma) or (implementation is None and is_fdma):
from flash_dmattn import flash_dmattn_func, flash_dmattn_varlen_func
from flash_dmattn.utils.padding import pad_input, unpad_input
from flash_dmattn.utils.mask import create_mask

return flash_dmattn_func, flash_dmattn_varlen_func, pad_input, unpad_input
return flash_dmattn_func, flash_dmattn_varlen_func, pad_input, unpad_input, create_mask


def _lazy_define_process_function(flash_function):
@@ -90,15 +90,15 @@ def lazy_import_flash_dynamic_mask_attention(implementation: Optional[str], forc
NOTE: For fullgraph, this needs to be called before compile, while no fullgraph can
work without preloading. See `load_and_register_kernel` in `integrations.hub_kernels`.
"""
global _flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn
if force_import or any(k is None for k in [_flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn]):
_flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn = _lazy_imports(implementation)
global _fdma_fn, _fdma_varlen_fn, _pad_fn, _unpad_fn, _create_mask_fn
if force_import or any(k is None for k in [_fdma_fn, _fdma_varlen_fn, _pad_fn, _unpad_fn, _create_mask_fn]):
_fdma_fn, _fdma_varlen_fn, _pad_fn, _unpad_fn, _create_mask_fn = _lazy_imports(implementation)

global _process_flash_kwargs_fn
if force_import or _process_flash_kwargs_fn is None:
_process_flash_kwargs_fn = _lazy_define_process_function(_flash_varlen_fn)
_process_flash_kwargs_fn = _lazy_define_process_function(_fdma_varlen_fn)

return (_flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn), _process_flash_kwargs_fn
return (_fdma_fn, _fdma_varlen_fn, _pad_fn, _unpad_fn, _create_mask_fn), _process_flash_kwargs_fn
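
Regarding the NOTE above about compile ordering, a minimal usage sketch (the stand-in model, the import of this module into scope, and the compile call are illustrative assumptions, not part of this diff):

    import torch
    import torch.nn as nn

    model = nn.Linear(16, 16)  # stand-in module for a model whose attention uses FDMA

    # resolve the kernels once, before any fullgraph compile
    (fdma_fn, fdma_varlen_fn, pad_fn, unpad_fn, create_mask_fn), process_flash_kwargs_fn = (
        lazy_import_flash_dynamic_mask_attention("flash_dmattn", force_import=True)
    )
    compiled = torch.compile(model, fullgraph=True)  # the globals above are now populated, so dynamo never hits a lazy import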


def _index_first_axis(tensor, indices):
@@ -113,57 +113,6 @@ def _index_first_axis(tensor, indices):
return reshaped_tensor[indices]


def _unpad_input(hidden_states, attention_mask, unused_mask=None):
"""
unpad_input function for flash attention variants that do not have them within their pkg themselves, e.g. fa3.

Arguments:
hidden_states: (batch, seqlen, ...)
attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.

Return:
hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
indices: (total_nnz), the indices of masked tokens from the flattened input sequence.
cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
max_seqlen_in_batch: int
seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask.
"""
all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask
seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))

return (
_index_first_axis(hidden_states, indices),
indices,
cu_seqlens,
max_seqlen_in_batch,
used_seqlens_in_batch,
)


def _pad_input(hidden_states, indices, batch, seqlen):
"""
pad_input function for flash attention variants that do not have them within their pkg themselves, e.g. fa3.

Arguments:
hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
batch: int, batch size for the padded sequence.
seqlen: int, maximum sequence length for the padded sequence.

Return:
hidden_states: (batch, seqlen, ...)
"""
dim = hidden_states.shape[1:]
output = torch.zeros((batch * seqlen), *dim, device=hidden_states.device, dtype=hidden_states.dtype)
output[indices] = hidden_states
return output.view(batch, seqlen, *dim)


def _get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]:
"""
Retrieves indexing data required to repad unpadded (ragged) tensors.
@@ -527,7 +476,7 @@ def _flash_dynamic_mask_attention_forward(
"If shape of attention_mask is (batch_size, seq_len), attention_bias has to be None."
)

(flash_fn, flash_varlen_fn, pad_fn, unpad_fn), process_flash_kwargs_fn = lazy_import_flash_dynamic_mask_attention(implementation)
(fdma_fn, fdma_varlen_fn, pad_fn, unpad_fn, create_mask_fn), process_flash_kwargs_fn = lazy_import_flash_dynamic_mask_attention(implementation)

# PEFT possibly silently casts tensors to fp32, this potentially reconverts to correct dtype or is a no op
query_states, key_states, value_states, attention_bias = fdma_peft_integration_check(
@@ -546,10 +495,10 @@
**kwargs,
)

# We will use `flash_varlen_fn` to prevent cross-example attention and also allow padding free approach under two cases:
# We will use `fdma_varlen_fn` to prevent cross-example attention and also allow padding free approach under two cases:
# Case 1. If position ids is provided and the position ids indicate packed sequences, see `_is_packed_sequence`.
# Case 2. Some models pass directly pre-computed `cu_seqlens` so we don't need to infer it from position ids. It is safe to
# use `flash_varlen_fn` knowing we already have all necessary the kwargs.
# use `fdma_varlen_fn` knowing we already have all necessary the kwargs.

Copilot AI Oct 23, 2025

Remove the extra article 'the' before 'kwargs' in line 501. Should read 'all necessary kwargs' instead of 'all necessary the kwargs'.

Suggested change
# use `fdma_varlen_fn` knowing we already have all necessary the kwargs.
# use `fdma_varlen_fn` knowing we already have all necessary kwargs.

#
# NOTE: it is user's responsibility to take care of flattening `position_ids` if that's needed by the model.
# See #39121 for more information.
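
To make Case 1 concrete, a rough illustration of packed position_ids and the cu_seqlens they imply (values are illustrative; the actual helpers, such as `_is_packed_sequence` and the real cu_seqlens preparation, live elsewhere in this file and are not shown in the hunk):

    import torch

    # one packed row holding three sequences of lengths 3, 2, 4
    position_ids = torch.tensor([[0, 1, 2, 0, 1, 0, 1, 2, 3]])
    starts = (position_ids[0] == 0).nonzero().flatten()                          # tensor([0, 3, 5])
    lengths = torch.diff(starts, append=torch.tensor([position_ids.shape[1]]))   # tensor([3, 2, 4])
    cu_seqlens = torch.nn.functional.pad(lengths.cumsum(0), (1, 0))              # tensor([0, 3, 5, 9])
    # cumulative sequence lengths like these (typically cast to int32) are what the varlen kernel consumes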
@@ -569,7 +518,7 @@
if "mps" in str(q.device):
cu_seq_lens_k = cu_seq_lens_k.clone()

out_unpad = flash_varlen_fn(
out_unpad = fdma_varlen_fn(
q,
k,
v,
@@ -600,7 +549,7 @@
if "mps" in str(q.device):
cu_seq_lens_k = cu_seq_lens_k.clone()

out = flash_varlen_fn(
out = fdma_varlen_fn(
q,
k,
v,
@@ -624,24 +573,17 @@
and window_size is not None
and key_length > window_size
):
min_dtype = torch.finfo(query_states.dtype).min
if attention_mask is not None:
topk_values, topk_indices = torch.topk(
attention_bias.masked_fill(~attention_mask, min_dtype).detach(),
window_size, dim=-1, largest=True, sorted=False
)
attention_mask = torch.zeros_like(
attention_bias, dtype=torch.bool, device=attention_bias.device
).scatter_(-1, topk_indices, topk_values != min_dtype)
else:
topk_values, topk_indices = torch.topk(
attention_bias.detach(), window_size, dim=-1, largest=True, sorted=False
)
attention_mask = torch.zeros_like(
attention_bias, dtype=torch.bool, device=attention_bias.device
).scatter_(-1, topk_indices, topk_values != min_dtype)
attention_mask = create_mask_fn(
attention_bias,
attention_mask,
batch_size=query_states.size(0),
query_len=query_length,
key_len=key_length,
window_size=window_size,
min_dtype=torch.finfo(attention_bias.dtype).min,
)

out = flash_fn(
out = fdma_fn(
query_states,
key_states,
value_states,
108 changes: 108 additions & 0 deletions flash_dmattn/utils/mask.py
@@ -0,0 +1,108 @@
# Copyright 2025 Jingze Shi and Liangdong Wang. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch


def dynamic_mask(
attention_bias: torch.Tensor,
attention_mask: Optional[torch.Tensor],
window_size: int,
min_dtype: float,
):
r"""
This function generates a dynamic mask based on the top-k attention bias.
Args:
attention_bias (torch.Tensor): The attention bias tensor of shape
({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape
({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
window_size (int): The number of top elements to consider for the mask.
min_dtype (float): The minimum value to use for masking.
Returns:
attention_mask (Tensor): The attention mask tensor of shape
({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
"""
attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype) if attention_mask is not None else attention_bias
topk_values, topk_indices = torch.topk(
attention_bias.detach(),
window_size, dim=-1, largest=True, sorted=False
)
attention_mask = torch.zeros_like(
attention_bias, dtype=torch.bool, device=attention_bias.device
Comment on lines +41 to +47

Copilot AI Oct 23, 2025

[nitpick] This line creates a potentially unnecessary copy of attention_bias when attention_mask is None. Consider using an early return pattern or storing the result in a new variable to make the intent clearer and avoid reassignment.

Suggested change
attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype) if attention_mask is not None else attention_bias
topk_values, topk_indices = torch.topk(
attention_bias.detach(),
window_size, dim=-1, largest=True, sorted=False
)
attention_mask = torch.zeros_like(
attention_bias, dtype=torch.bool, device=attention_bias.device
masked_attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype) if attention_mask is not None else attention_bias
topk_values, topk_indices = torch.topk(
masked_attention_bias.detach(),
window_size, dim=-1, largest=True, sorted=False
)
attention_mask = torch.zeros_like(
masked_attention_bias, dtype=torch.bool, device=masked_attention_bias.device

).scatter_(-1, topk_indices, topk_values != min_dtype)
return attention_mask
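
As a rough usage sketch for dynamic_mask (shapes and values below are illustrative, not part of the new file):

    import torch
    from flash_dmattn.utils.mask import dynamic_mask

    bias = torch.randn(2, 8, 1, 128)                     # ({batch_size}, {num_heads}, {query_len|1}, key_len)
    valid = torch.ones(2, 8, 1, 128, dtype=torch.bool)   # optional validity mask with a broadcastable shape
    mask = dynamic_mask(bias, valid, window_size=32, min_dtype=torch.finfo(bias.dtype).min)
    # mask is boolean with at most 32 True entries along the last dimension per (batch, head, query) row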


def create_mask(
attention_bias: torch.Tensor,
attention_mask: Optional[torch.Tensor],
batch_size: int,
query_len: int,
key_len: int,
window_size: int,
min_dtype: float,
) -> torch.Tensor:
r"""
This function creates a mask tensor for Flash Dynamic Mask Attention.
If attention_mask is not of shape (batch_size, seq_len), it needs to match the shape of attention_bias.
Args:
attention_bias (torch.Tensor): The attention bias tensor of shape
({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, {key_len|1}).
attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape
(batch_size, seq_len) or ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, {key_len|1}).
batch_size (int): The batch size.
query_len (int): The sequence length of the query.
key_len (int): The sequence length of the key.
window_size (int): The number of top elements to consider for the attention mask.
min_dtype (float): The minimum value to use for masking.
Returns:
attention_mask (Tensor): The attention mask tensor of shape
({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, {key_len|1}).
"""

# If attention_mask is of shape (batch_size, seq_len), reshape it to (batch_size, 1, 1, key_len)
if attention_mask is not None and attention_mask.dim() == 2:
if attention_mask.shape[-1] == key_len:
attention_mask = attention_mask.view(batch_size, 1, 1, key_len)
elif attention_mask.shape[-1] == query_len:
pad_len = key_len - query_len
if pad_len > 0:
pad_mask = torch.ones(
(batch_size, 1, 1, pad_len),
dtype=torch.bool,
device=attention_mask.device,
)
Comment on lines +89 to +93

Copilot AI Oct 23, 2025

Padding mask is initialized with torch.ones (all True), which marks padded positions as valid. This contradicts the typical attention mask convention where False/0 indicates invalid positions. Consider using torch.zeros instead to mark padding as invalid.

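If the zeros-based convention suggested above were adopted, the padding branch would look roughly like the sketch below; whether ones or zeros is appropriate depends on whether the leading key_len - query_len positions correspond to valid cached key/value tokens, which this hunk does not state:

    pad_mask = torch.zeros(
        (batch_size, 1, 1, pad_len),
        dtype=torch.bool,
        device=attention_mask.device,
    )  # would mark the leading (non-query) key positions as invalid instead of valid
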
attention_mask = torch.cat(
[pad_mask, attention_mask.view(batch_size, 1, 1, query_len)],
dim=-1,
)
else:
attention_mask = attention_mask.view(batch_size, 1, 1, query_len)
else:
raise ValueError(
f"attention_mask shape {attention_mask.shape} is not compatible with key_len {key_len} or query_len {query_len}."
)

attention_mask = dynamic_mask(attention_bias, attention_mask, window_size, min_dtype)

return attention_mask
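
A rough end-to-end sketch of create_mask with a (batch_size, seq_len) padding mask (values are illustrative; the import path follows the new file's location shown in this diff):

    import torch
    from flash_dmattn.utils.mask import create_mask

    batch, heads, q_len, k_len = 2, 8, 16, 16
    bias = torch.randn(batch, heads, q_len, k_len)
    padding_mask = torch.ones(batch, k_len, dtype=torch.bool)  # (batch_size, seq_len) convention
    padding_mask[1, -4:] = False                               # last 4 keys of row 1 are padding

    mask = create_mask(
        bias,
        padding_mask,
        batch_size=batch,
        query_len=q_len,
        key_len=k_len,
        window_size=8,
        min_dtype=torch.finfo(bias.dtype).min,
    )
    # mask: (batch, heads, q_len, k_len) boolean, at most 8 True keys per query row, none on padded keys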
2 changes: 1 addition & 1 deletion flash_dmattn/utils/padding.py
@@ -167,4 +167,4 @@ def upad_input(
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
)