
Commit 413622f

Enable per-call override of causal flag and window
Allows the window size to be passed as a per-call argument and forwards it instead of always using the module default. Respects an is_causal flag provided via kwargs, falling back to the module value when absent. Clarifies the documented attention mask/bias shapes to include 2D masks and per-head forms. Improves configurability and fixes previously ignored overrides.
1 parent c1815ca commit 413622f

File tree

1 file changed: +12 -8 lines changed

flash_dmattn/integrations/flash_dynamic_mask_attention.py

Lines changed: 12 additions & 8 deletions
@@ -17,6 +17,7 @@ def flash_dynamic_mask_attention_forward(
     attention_mask: Optional[torch.Tensor],
     attention_bias: Optional[torch.Tensor],
     scaling: Optional[float] = None,
+    window_size: Optional[int] = None,
     softcap: Optional[float] = None,
     **kwargs,
 ) -> tuple[torch.Tensor, None]:
@@ -29,14 +30,16 @@ def flash_dynamic_mask_attention_forward(
         query (torch.Tensor): The query tensor of shape (batch_size, num_heads, query_len, head_dim).
         key (torch.Tensor): The key tensor of shape (batch_size, num_kv_heads, key_len, head_dim).
         value (torch.Tensor): The value tensor of shape (batch_size, num_kv_heads, key_len, head_dim).
-        attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape (batch_size, {num_heads|num_kv_heads|1}, query_len, key_len).
-        attention_bias (Optional[torch.Tensor]): The attention bias float tensor of shape (batch_size, {num_heads|num_kv_heads|1}, query_len, key_len), if attention_mask is None, also supports (batch_size, {num_heads|num_kv_heads|1}, key_len).
+        attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape
+            (batch_size, seq_len) or (batch_size, {num_heads|num_kv_heads|1}, {query_len|0}, key_len).
+        attention_bias (Optional[torch.Tensor]): The attention bias float tensor of shape
+            (batch_size, {num_heads|num_kv_heads|1}, {query_len|0}, key_len).
         scaling (Optional[float]): The scaling factor for the attention scores.
+        window_size (Optional[int]): The size of the window to keep.
         softcap (Optional[float]): The softcap value for the attention scores.
         **kwargs: Additional keyword arguments.
             Includes:
             - is_causal (bool): Whether to apply a causal mask.
-            - window_size (int): The size of the window to keep.
             - layer_idx (int): The index of the layer (for logging purposes).
             - implementation (str): The implementation to use ("flash_dmattn" or None).
@@ -82,9 +85,10 @@ def flash_dynamic_mask_attention_forward(
     else:
         target_dtype = next(layer for layer in module.modules() if isinstance(layer, torch.nn.Linear)).weight.dtype
 
-    # FDMA always relies on the value set in the module, so remove it if present in kwargs to avoid passing it twice
-    kwargs.pop("is_causal", None)
-    kwargs.pop("window_size", None)
+    # Instead of relying on the value set in the module directly, we use the is_causal passed in kwargs if it is presented
+    is_causal = kwargs.pop("is_causal", None)
+    if is_causal is None:
+        is_causal = module.is_causal
 
     attn_output = _flash_dynamic_mask_attention_forward(
         query,
@@ -94,10 +98,10 @@ def flash_dynamic_mask_attention_forward(
         attention_bias,
         query_length=query_len,
         key_length=key_len,
-        is_causal=module.is_causal,
+        is_causal=is_causal,
         softmax_scale=scaling,
         softcap=softcap,
-        window_size=module.window_size,
+        window_size=window_size,
         target_dtype=target_dtype,
         implementation="flash_dmattn",
         layer_idx=module.layer_idx if hasattr(module, "layer_idx") else None,
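
With this change, is_causal and window_size become per-call overrides rather than module-only settings. The snippet below is a minimal usage sketch, not an official example: it assumes flash_dmattn is installed with a working CUDA build, that the leading module parameter follows the Hugging Face attention-interface convention, and StubAttention is a hypothetical stand-in that exposes only the attributes this forward reads.

# Minimal sketch of the new per-call overrides (hypothetical setup, not library code).
import torch
from flash_dmattn.integrations.flash_dynamic_mask_attention import (
    flash_dynamic_mask_attention_forward,
)


class StubAttention(torch.nn.Module):
    """Hypothetical stand-in exposing only the attributes the forward reads."""

    def __init__(self):
        super().__init__()
        self.is_causal = True  # module default, used only when no per-call flag is passed
        self.layer_idx = 0     # optional, used for logging


module = StubAttention()
batch, heads, q_len, kv_len, head_dim = 2, 8, 128, 128, 64
device, dtype = "cuda", torch.bfloat16

query = torch.randn(batch, heads, q_len, head_dim, device=device, dtype=dtype)
key = torch.randn(batch, heads, kv_len, head_dim, device=device, dtype=dtype)
value = torch.randn(batch, heads, kv_len, head_dim, device=device, dtype=dtype)

# 2D boolean padding mask (batch_size, seq_len); per the updated docstring, 4D forms of
# shape (batch_size, {num_heads|num_kv_heads|1}, {query_len|0}, key_len) are also accepted.
attention_mask = torch.ones(batch, kv_len, dtype=torch.bool, device=device)

attn_output, _ = flash_dynamic_mask_attention_forward(
    module,
    query=query,
    key=key,
    value=value,
    attention_mask=attention_mask,
    attention_bias=None,
    scaling=head_dim**-0.5,
    window_size=1024,   # new per-call argument, forwarded to the kernel
    is_causal=False,    # per-call override; falls back to module.is_causal if omitted
)

Before this commit, is_causal and window_size passed in kwargs were popped and discarded, so the module defaults always won; now the per-call values are honored.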
