
Commit 37098b7

LoserCheems and Copilot authored
Update flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py
Co-authored-by: Copilot <[email protected]>
1 parent 406f673 commit 37098b7

File tree

1 file changed (+1, -1 lines)


flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py

Lines changed: 1 addition & 1 deletion
@@ -182,7 +182,7 @@ def _get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.T
     """
     seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
     indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
-    # NOTE: Similar to the `.item()` in prepare_fa2_from_position_ids, with torch compile,
+    # NOTE: Similar to the `.item()` in prepare_fdma_kwargs_from_position_ids, with torch compile,
     # this might cause a graph break
     max_seqlen_in_batch = seqlens_in_batch.max().item()
     cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
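For context, below is a minimal, self-contained sketch of the helper this hunk touches, pieced together from the lines visible in the diff. The (batch, seq_len) 0/1 attention-mask convention, the return order, and the sample values are assumptions for illustration, not the library's confirmed API.

import torch
import torch.nn.functional as F

def _get_unpad_data(attention_mask: torch.Tensor):
    # attention_mask: (batch, seq_len), 1 for real tokens, 0 for padding (assumed convention)
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    # NOTE: Similar to the `.item()` in prepare_fdma_kwargs_from_position_ids, with torch compile,
    # this might cause a graph break
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    # Return order is an assumption for this sketch
    return indices, cu_seqlens, max_seqlen_in_batch

# Usage: two sequences of lengths 3 and 2, right-padded to length 4
mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]])
indices, cu_seqlens, max_seqlen = _get_unpad_data(mask)
# indices    -> tensor([0, 1, 2, 4, 5])               flat positions of non-padding tokens
# cu_seqlens -> tensor([0, 3, 5], dtype=torch.int32)  cumulative sequence lengths
# max_seqlen -> 3

The `.item()` call materializes the maximum sequence length as a Python int, which is why the updated NOTE warns that torch.compile may insert a graph break at that point.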
