Skip to content

Commit 0fdf0c0

Browse files
committed
Add return type annotations for attention functions
1 parent c79c27e commit 0fdf0c0

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

flash_sparse_attn/ops/triton/interface.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional, Tuple
1+
from typing import Optional, Tuple, Union
22

33
import torch
44

@@ -585,7 +585,7 @@ def flash_dense_attn_func(
585585
softmax_scale: Optional[float] = None,
586586
window_size: Tuple[Optional[int], Optional[int]] = (None, None),
587587
return_lse: bool = False,
588-
):
588+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
589589
"""
590590
Flash dense attention function that computes the attention output and optionally the logsumexp.
591591
@@ -624,7 +624,7 @@ def flash_dense_attn_varlen_func(
624624
seqused_q: Optional[torch.Tensor] = None,
625625
seqused_k: Optional[torch.Tensor] = None,
626626
return_lse: bool = False,
627-
):
627+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
628628
"""
629629
Flash dense attention function for variable-length sequences that computes the attention output and optionally the logsumexp.
630630
@@ -670,7 +670,7 @@ def flash_sparse_attn_func(
670670
softmax_threshold: Optional[float] = None,
671671
window_size: Tuple[Optional[int], Optional[int]] = (None, None),
672672
return_lse: bool = False,
673-
):
673+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
674674
"""
675675
Flash sparse attention function that computes the attention output and optionally the logsumexp.
676676
@@ -712,7 +712,7 @@ def flash_sparse_attn_varlen_func(
712712
seqused_q: Optional[torch.Tensor] = None,
713713
seqused_k: Optional[torch.Tensor] = None,
714714
return_lse: bool = False,
715-
):
715+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
716716
"""
717717
Flash sparse attention function for variable-length sequences that computes the attention output and optionally the logsumexp.
718718
@@ -765,7 +765,7 @@ def flash_gated_attn_func(
765765
is_adapt_gate: bool = True,
766766
window_size: Tuple[Optional[int], Optional[int]] = (None, None),
767767
return_lse: bool = False,
768-
):
768+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
769769
"""
770770
Flash gated attention function that computes the attention output and optionally the logsumexp.
771771
@@ -822,7 +822,7 @@ def flash_gated_attn_varlen_func(
822822
seqused_q: Optional[torch.Tensor] = None,
823823
seqused_k: Optional[torch.Tensor] = None,
824824
return_lse: bool = False,
825-
):
825+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
826826
"""
827827
Flash gated attention function for variable-length sequences that computes the attention output and optionally the logsumexp.
828828

0 commit comments

Comments (0)