
Commit e69b1c7

Improves numerical stability and initialization
Uses in-place nan_to_num_ operation for better memory efficiency. Updates tensor sanitization to use dtype-specific infinity bounds instead of fixed values, preventing potential overflow issues. Changes tensor initialization from empty_like to zeros_like to ensure deterministic starting values for gradients. Fixes bias padding value from minimum float to zero for better numerical behavior. Enhances documentation to clarify support for flexible mask and bias head dimensions in MQA/GQA scenarios.
1 parent cb78583 commit e69b1c7
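
To make the sanitization change in this commit concrete, here is a minimal, self-contained sketch in plain PyTorch. It is not the library's exact code: the hypothetical helper sanitize_tensors_ folds the dtype-specific bounds into the helper itself, whereas the diff below keeps the helper generic and passes torch.finfo bounds at each call site.

    import torch
    from typing import Optional

    def sanitize_tensors_(*tensors: Optional[torch.Tensor], nan: float = 0.0) -> None:
        # Replace NaN/inf in place, clamping infinities to the widest finite
        # value each tensor's own dtype can represent (no temporary allocation).
        for t in tensors:
            if isinstance(t, torch.Tensor) and t.is_floating_point():
                torch.nan_to_num_(
                    t,
                    nan=nan,
                    posinf=torch.finfo(t.dtype).max,
                    neginf=torch.finfo(t.dtype).min,
                )

    out = torch.randn(2, 4, dtype=torch.float16)
    out[0, 0] = float("inf")
    sanitize_tensors_(out)
    print(out[0, 0].item())  # 65504.0, the float16 maximum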

File tree

1 file changed (+11, -6 lines)


flash_dmattn/flash_dmattn_interface.py

Lines changed: 11 additions & 6 deletions
@@ -14,7 +14,7 @@ def maybe_contiguous(x: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
 def _sanitize_tensors(*tensors: Optional[torch.Tensor], nan: float = 0.0, posinf: float = 1e6, neginf: float = -1e6) -> None:
     for t in tensors:
         if t is not None and isinstance(t, torch.Tensor):
-            torch.nan_to_num(t, nan=nan, posinf=posinf, neginf=neginf, out=t)
+            torch.nan_to_num_(t, nan=nan, posinf=posinf, neginf=neginf)


 def _get_block_size_n(device, head_dim, is_causal):
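
For context on why the later hunks switch from the helper's fixed ±1e6 defaults to dtype-specific bounds: 1e6 is not even representable in float16, whose finite range tops out at 65504, which matches the overflow concern in the commit message. A quick check in plain PyTorch:

    import torch

    for dtype in (torch.float16, torch.bfloat16, torch.float32):
        info = torch.finfo(dtype)
        print(dtype, info.min, info.max)
    # torch.float16   -65504.0    65504.0   (below the old fixed 1e6 bound)
    # torch.bfloat16  ~-3.39e38   ~3.39e38
    # torch.float32   ~-3.40e38   ~3.40e38
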
@@ -95,7 +95,7 @@ def _flash_dmattn_forward(
         softcap,
         return_softmax,
     )
-    _sanitize_tensors(out)
+    _sanitize_tensors(out, nan=0.0, posinf=torch.finfo(out.dtype).max, neginf=torch.finfo(out.dtype).min)
     return out, softmax_lse, S_dmask


@@ -170,7 +170,7 @@ def _flash_dmattn_backward(
         softcap,
         deterministic,
     )
-    _sanitize_tensors(dq, dk, dv, dbias)
+    _sanitize_tensors(dq, dk, dv, dbias, nan=0.0, posinf=torch.finfo(dq.dtype).max, neginf=torch.finfo(dq.dtype).min)
     return softmax_d


@@ -255,7 +255,7 @@ def forward(
         if mask is not None:
             mask = torch.nn.functional.pad(mask, [0, 128 - seqlen_k % 128], value=False)
         if bias is not None:
-            bias = torch.nn.functional.pad(bias, [0, 128 - seqlen_k % 128], value=torch.finfo(bias.dtype).min)
+            bias = torch.nn.functional.pad(bias, [0, 128 - seqlen_k % 128], value=0.0)

         out_padded, softmax_lse, S_dmask = _wrapped_flash_dmattn_forward(
             q,
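
The padding change above can be exercised on its own. A small sketch, assuming a hypothetical (batch, heads, seqlen_q, seqlen_k) layout: padded key positions are already excluded by the False mask entries, so the bias pad value only needs to be numerically harmless, and 0.0 avoids the risk that torch.finfo(bias.dtype).min carries once the bias is added to attention scores.

    import torch
    import torch.nn.functional as F

    seqlen_k = 300
    pad = 128 - seqlen_k % 128                 # 84, same expression as in the diff

    mask = torch.ones(1, 2, 8, seqlen_k, dtype=torch.bool)
    bias = torch.randn(1, 2, 8, seqlen_k, dtype=torch.float16)

    # Pad the key dimension up to a multiple of 128; masked-out padding keys
    # contribute nothing, so a 0.0 bias pad is benign.
    mask = F.pad(mask, [0, pad], value=False)
    bias = F.pad(bias, [0, pad], value=0.0)

    print(mask.shape[-1], bias.shape[-1])      # 384 384
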
@@ -288,7 +288,7 @@ def backward(
         *args: Any,
     ):
         q, k, v, mask, bias, out, softmax_lse = ctx.saved_tensors
-        dq, dk, dv, dbias = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v), torch.empty_like(bias)
+        dq, dk, dv, dbias = torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v), torch.zeros_like(bias)

         head_size_og = dout.size(3)
         dout_padded = dout
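
The empty_like to zeros_like switch is about deterministic starting values for the gradient buffers. A minimal demonstration, independent of the attention kernels themselves:

    import torch

    q = torch.randn(2, 8, 16, 64, dtype=torch.float16)

    dq_uninit = torch.empty_like(q)    # uninitialized memory: contents are arbitrary
    dq_zero = torch.zeros_like(q)      # deterministic all-zero buffer to accumulate into

    print(dq_zero.abs().max().item())  # always 0.0; dq_uninit offers no such guarantee
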
@@ -339,11 +339,16 @@ def flash_dmattn_func(
     return_attn_probs: Optional[bool] = None,
 ):
     """
-    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
+    Supports multi-query attention and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
     than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
     For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
     0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.

+    Similarly, it also supports attn_mask and attn_bias with a head dimension of 1, nheads_k, or nheads for MQA/GQA.
+    For example, if Q has 6 heads and K, V have 2 heads, then attn_mask and attn_bias can have a head dimension
+    of 1, 2, or 6. If it is 1, all heads use the same mask/bias; if it is 2, heads 0, 1, 2 of Q use head 0
+    of mask/bias and heads 3, 4, 5 of Q use head 1 of mask/bias. If it is 6, each head uses its own mask/bias.
+
     If is_causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
     For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
     1 1 1 1 0
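
To make the new mask/bias head-dimension rule concrete, here is a small shape sketch with hypothetical batch and sequence sizes, using plain tensor ops rather than the library call: a bias with nheads_k heads follows the same grouping as the KV heads.

    import torch

    nheads, nheads_k = 6, 2
    groups = nheads // nheads_k                # 3 query heads share each KV head

    bias_k = torch.randn(1, nheads_k, 4, 4)    # (batch, nheads_k, seqlen_q, seqlen_k)

    # Per the docstring: Q heads 0-2 read bias head 0 and Q heads 3-5 read bias head 1,
    # i.e. the same mapping as repeating each bias head `groups` times along dim 1.
    bias_full = bias_k.repeat_interleave(groups, dim=1)

    assert bias_full.shape[1] == nheads
    assert torch.equal(bias_full[:, 2], bias_k[:, 0])
    assert torch.equal(bias_full[:, 3], bias_k[:, 1])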
