@@ -14,7 +14,7 @@ def maybe_contiguous(x: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
 def _sanitize_tensors(*tensors: Optional[torch.Tensor], nan: float = 0.0, posinf: float = 1e6, neginf: float = -1e6) -> None:
     for t in tensors:
         if t is not None and isinstance(t, torch.Tensor):
-            torch.nan_to_num(t, nan=nan, posinf=posinf, neginf=neginf, out=t)
+            torch.nan_to_num_(t, nan=nan, posinf=posinf, neginf=neginf)


 def _get_block_size_n(device, head_dim, is_causal):
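
For context on the hunk above: torch.nan_to_num_ is the in-place variant of torch.nan_to_num, so the new call rewrites t directly, just like the old out=t form. A minimal standalone sketch of the sanitization step (plain CPU tensor, not the actual kernel path):

import torch

# Replace NaN/inf in place using the defaults from _sanitize_tensors.
t = torch.tensor([float("nan"), float("inf"), float("-inf"), 1.0])
torch.nan_to_num_(t, nan=0.0, posinf=1e6, neginf=-1e6)
# t is now [0.0, 1e6, -1e6, 1.0]; None entries are skipped by the isinstance check.
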
@@ -95,7 +95,7 @@ def _flash_dmattn_forward(
         softcap,
         return_softmax,
     )
-    _sanitize_tensors(out)
+    _sanitize_tensors(out, nan=0.0, posinf=torch.finfo(out.dtype).max, neginf=torch.finfo(out.dtype).min)
     return out, softmax_lse, S_dmask


@@ -170,7 +170,7 @@ def _flash_dmattn_backward(
         softcap,
         deterministic,
     )
-    _sanitize_tensors(dq, dk, dv, dbias)
+    _sanitize_tensors(dq, dk, dv, dbias, nan=0.0, posinf=torch.finfo(dq.dtype).max, neginf=torch.finfo(dq.dtype).min)
     return softmax_d


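
The two hunks above replace the fixed ±1e6 defaults with the finite range of the output/gradient dtype. A standalone sketch of the effect, assuming a float16 tensor (illustration only, not the kernel code):

import torch

out = torch.tensor([float("nan"), float("inf"), float("-inf"), 2.0], dtype=torch.float16)
finfo = torch.finfo(out.dtype)  # for float16: finfo.max == 65504.0, finfo.min == -65504.0
torch.nan_to_num_(out, nan=0.0, posinf=finfo.max, neginf=finfo.min)
# out is now [0.0, 65504.0, -65504.0, 2.0]; the backward path applies the same call to dq, dk, dv, dbias.
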
@@ -255,7 +255,7 @@ def forward(
         if mask is not None:
             mask = torch.nn.functional.pad(mask, [0, 128 - seqlen_k % 128], value=False)
         if bias is not None:
-            bias = torch.nn.functional.pad(bias, [0, 128 - seqlen_k % 128], value=torch.finfo(bias.dtype).min)
+            bias = torch.nn.functional.pad(bias, [0, 128 - seqlen_k % 128], value=0.0)

         out_padded, softmax_lse, S_dmask = _wrapped_flash_dmattn_forward(
             q,
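
For the padding hunk: seqlen_k is rounded up to a multiple of 128, and the padded key positions are excluded via value=False in the mask, so the bias value in that region never reaches the softmax; padding it with 0.0 instead of torch.finfo(bias.dtype).min presumably avoids feeding near -inf values into later arithmetic. A sketch with hypothetical shapes (batch=1, nheads=2, seqlen_q=4, seqlen_k=100):

import torch
import torch.nn.functional as F

seqlen_k = 100
mask = torch.ones(1, 2, 4, seqlen_k, dtype=torch.bool)
bias = torch.randn(1, 2, 4, seqlen_k)
pad = 128 - seqlen_k % 128                 # 28 extra key positions
mask = F.pad(mask, [0, pad], value=False)  # padded keys are masked out
bias = F.pad(bias, [0, pad], value=0.0)    # their bias value is therefore irrelevant
print(mask.shape, bias.shape)              # both become (..., 128) in the last dimension
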
@@ -288,7 +288,7 @@ def backward(
         *args: Any,
     ):
         q, k, v, mask, bias, out, softmax_lse = ctx.saved_tensors
-        dq, dk, dv, dbias = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v), torch.empty_like(bias)
+        dq, dk, dv, dbias = torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v), torch.zeros_like(bias)

         head_size_og = dout.size(3)
         dout_padded = dout
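
The empty_like to zeros_like change above matters because torch.empty_like returns uninitialized memory; if the backward kernel leaves some positions of dq/dk/dv/dbias unwritten (an assumption about the motivation here), zero-initialized buffers keep those entries finite. A quick standalone comparison:

import torch

q = torch.randn(2, 3)
dq_uninit = torch.empty_like(q)  # arbitrary memory contents, possibly NaN/inf
dq_zero = torch.zeros_like(q)    # guaranteed to start at 0.0
print(torch.isfinite(dq_zero).all())  # tensor(True); no such guarantee for dq_uninit
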
@@ -339,11 +339,16 @@ def flash_dmattn_func(
     return_attn_probs: Optional[bool] = None,
 ):
     """
-    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
+    Supports multi-query attention and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
     than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
     For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
     0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.

+    Similarly, attn_mask and attn_bias are also supported with a head dimension of 1, nheads_k, or nheads for MQA/GQA.
+    For example, if Q has 6 heads and K, V have 2 heads, then attn_mask and attn_bias can have a head dimension
+    of 1, 2, or 6. If it is 1, all heads share the same mask/bias; if it is 2, heads 0, 1, 2 of Q use head 0
+    of the mask/bias and heads 3, 4, 5 of Q use head 1. If it is 6, each head uses its own mask/bias.
+
     If is_causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
     For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
         1 1 1 1 0