@@ -416,14 +416,17 @@ def forward(
             q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
             k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
             v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
-        seqlen_k_og = k.shape[1]
-        if seqlen_k_og % 8 != 0:
-            k = torch.nn.functional.pad(k, [0, 0, 0, 0, 0, 8 - seqlen_k_og % 8])
-            v = torch.nn.functional.pad(v, [0, 0, 0, 0, 0, 8 - seqlen_k_og % 8])
-            if mask is not None:
-                mask = torch.nn.functional.pad(mask, [0, 8 - seqlen_k_og % 8], value=False)
-            if bias is not None:
-                bias = torch.nn.functional.pad(bias, [0, 8 - seqlen_k_og % 8], value=0.0)
+        seqlen_k_rounded = round_multiple(k.shape[1], 128)
+        if mask is not None and mask.shape[-1] != seqlen_k_rounded:
+            if mask.shape[-1] == 1:
+                mask = mask.expand(*mask.shape[:-1], seqlen_k_rounded)
+            else:
+                mask = torch.nn.functional.pad(mask, [0, seqlen_k_rounded - mask.shape[-1]])
+        if bias is not None and bias.shape[-1] != seqlen_k_rounded:
+            if bias.shape[-1] == 1:
+                bias = bias.expand(*bias.shape[:-1], seqlen_k_rounded)
+            else:
+                bias = torch.nn.functional.pad(bias, [0, seqlen_k_rounded - bias.shape[-1]])

         out_padded, softmax_lse, S_dmask = _wrapped_flash_dmattn_forward(
             q,
@@ -443,7 +446,6 @@ def forward(
         ctx.is_causal = is_causal
         ctx.softcap = softcap
         ctx.deterministic = deterministic
-        ctx.seqlen_k_og = seqlen_k_og

         out = out_padded[..., :head_size_og]

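The forward path no longer pads K/V along the sequence dimension; instead it rounds the key length up to a multiple of 128 and brings only the mask/bias to that width, expanding a broadcastable (size-1) last dimension and right-padding a full-width one. Below is a minimal sketch of that shape handling, assuming `round_multiple` rounds up to the next multiple (its definition lies outside this hunk):

```python
import torch

def round_multiple(x: int, m: int) -> int:
    # Assumed behavior of the helper used in the diff: round x up to the next multiple of m.
    return (x + m - 1) // m * m

batch, nheads, seqlen_q, seqlen_k = 2, 4, 64, 200
seqlen_k_rounded = round_multiple(seqlen_k, 128)  # 256

# Broadcastable mask (size-1 key dimension): expanded to the rounded width, no copy needed.
mask = torch.ones(batch, 1, seqlen_q, 1, dtype=torch.bool)
mask = mask.expand(*mask.shape[:-1], seqlen_k_rounded)

# Full-width bias: right-padded with zeros out to the rounded key length,
# so the extra key columns contribute no bias to the attention scores.
bias = torch.randn(batch, nheads, seqlen_q, seqlen_k)
bias = torch.nn.functional.pad(bias, [0, seqlen_k_rounded - bias.shape[-1]])

assert mask.shape[-1] == bias.shape[-1] == seqlen_k_rounded
```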
@@ -488,11 +490,8 @@ def backward(
         dk = dk[..., : dout.shape[-1]]
         dv = dv[..., : dout.shape[-1]]

-        if ctx.seqlen_k_og % 8 != 0:
-            dk = dk[:, : ctx.seqlen_k_og, :, :]
-            dv = dv[:, : ctx.seqlen_k_og, :, :]
-            if dbias is not None:
-                dbias = dbias[..., : ctx.seqlen_k_og]
+        if dbias is not None:
+            dbias = dbias[..., : k.shape[1]]

         return dq, dk, dv, None, dbias, None, None, None, None, None, None

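Since K and V are no longer padded along the sequence dimension, the backward pass only has to undo the mask/bias widening: the bias gradient comes back at the rounded key width and is cropped to the caller's original key length. A small sketch of that cleanup, with illustrative shapes:

```python
import torch

batch, nheads, seqlen_q, seqlen_k, seqlen_k_rounded = 2, 4, 64, 200, 256
k = torch.randn(batch, seqlen_k, nheads, 64)

# The gradient produced by the kernel has the rounded key width ...
dbias = torch.randn(batch, nheads, seqlen_q, seqlen_k_rounded)
# ... and is sliced back to the original key length before being returned.
dbias = dbias[..., : k.shape[1]]
assert dbias.shape[-1] == seqlen_k
```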
@@ -646,10 +645,10 @@ def flash_dmattn_func(
         key: torch.Tensor. The key tensor of shape (batch_size, seqlen, nheads_k, headdim)
         value: torch.Tensor. The value tensor of shape (batch_size, seqlen, nheads_k, headdim)
         attn_mask: torch.Tensor, optional. The attention mask boolean tensor of
-            shape (batch_size, {nheads|nheads_k|1}, {seqlen_q|0}, seqlen_k) to apply to the attention scores.
+            shape ({batch_size|1}, {nheads|nheads_k|1}, {seqlen_q|1}, {seqlen_k|1}) to apply to the attention scores.
             If None, no mask is applied.
         attn_bias: torch.Tensor, optional. The attention bias float tensor of
-            shape (batch_size, {nheads|nheads_k|1}, {seqlen_q|0}, seqlen_k) to add to the attention scores.
+            shape (batch_size, {nheads|nheads_k|1}, {seqlen_q|1}, {seqlen_k|1}) to add to the attention scores.
             If None, no bias is applied.
         softmax_scale: float. The scaling of QK^T before applying softmax.
             Default to 1 / sqrt(headdim).
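With the relaxed shapes documented above, any leading dimension of `attn_mask`/`attn_bias` may be 1 and is broadcast against the attention scores. Below is a hypothetical call illustrating those shapes; the import path and the device/dtype choices are assumptions, not taken from this diff:

```python
import math
import torch
from flash_dmattn import flash_dmattn_func  # assumed import path

batch, seqlen_q, seqlen_k, nheads, headdim = 2, 128, 320, 8, 64
q = torch.randn(batch, seqlen_q, nheads, headdim, device="cuda", dtype=torch.bfloat16)
k = torch.randn(batch, seqlen_k, nheads, headdim, device="cuda", dtype=torch.bfloat16)
v = torch.randn(batch, seqlen_k, nheads, headdim, device="cuda", dtype=torch.bfloat16)

# Per the docstring: the mask may broadcast over batch, heads, queries, and keys.
attn_mask = torch.ones(1, 1, 1, seqlen_k, dtype=torch.bool, device="cuda")
# The bias keeps an explicit batch dimension; heads and queries broadcast here.
attn_bias = torch.zeros(batch, 1, 1, seqlen_k, dtype=torch.bfloat16, device="cuda")

out = flash_dmattn_func(
    q, k, v,
    attn_mask=attn_mask,
    attn_bias=attn_bias,
    softmax_scale=1.0 / math.sqrt(headdim),
)
print(out.shape)  # expected: (batch, seqlen_q, nheads, headdim)
```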