|
1 | | -from typing import Optional, Tuple |
| 1 | +from typing import Optional, Tuple, Union |
2 | 2 |
|
3 | 3 | import torch |
4 | 4 |
|
@@ -585,7 +585,7 @@ def flash_dense_attn_func( |
585 | 585 | softmax_scale: Optional[float] = None, |
586 | 586 | window_size: Tuple[Optional[int], Optional[int]] = (None, None), |
587 | 587 | return_lse: bool = False, |
588 | | -): |
| 588 | +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: |
589 | 589 | """ |
590 | 590 | Flash dense attention function that computes the attention output and optionally the logsumexp. |
591 | 591 |
|
@@ -624,7 +624,7 @@ def flash_dense_attn_varlen_func( |
624 | 624 | seqused_q: Optional[torch.Tensor] = None, |
625 | 625 | seqused_k: Optional[torch.Tensor] = None, |
626 | 626 | return_lse: bool = False, |
627 | | -): |
| 627 | +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: |
628 | 628 | """ |
629 | 629 | Flash dense attention function for variable-length sequences that computes the attention output and optionally the logsumexp. |
630 | 630 |
|
@@ -670,7 +670,7 @@ def flash_sparse_attn_func( |
670 | 670 | softmax_threshold: Optional[float] = None, |
671 | 671 | window_size: Tuple[Optional[int], Optional[int]] = (None, None), |
672 | 672 | return_lse: bool = False, |
673 | | -): |
| 673 | +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: |
674 | 674 | """ |
675 | 675 | Flash sparse attention function that computes the attention output and optionally the logsumexp. |
676 | 676 |
|
@@ -712,7 +712,7 @@ def flash_sparse_attn_varlen_func( |
712 | 712 | seqused_q: Optional[torch.Tensor] = None, |
713 | 713 | seqused_k: Optional[torch.Tensor] = None, |
714 | 714 | return_lse: bool = False, |
715 | | -): |
| 715 | +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: |
716 | 716 | """ |
717 | 717 | Flash sparse attention function for variable-length sequences that computes the attention output and optionally the logsumexp. |
718 | 718 |
|
@@ -765,7 +765,7 @@ def flash_gated_attn_func( |
765 | 765 | is_adapt_gate: bool = True, |
766 | 766 | window_size: Tuple[Optional[int], Optional[int]] = (None, None), |
767 | 767 | return_lse: bool = False, |
768 | | -): |
| 768 | +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: |
769 | 769 | """ |
770 | 770 | Flash gated attention function that computes the attention output and optionally the logsumexp. |
771 | 771 |
|
@@ -822,7 +822,7 @@ def flash_gated_attn_varlen_func( |
822 | 822 | seqused_q: Optional[torch.Tensor] = None, |
823 | 823 | seqused_k: Optional[torch.Tensor] = None, |
824 | 824 | return_lse: bool = False, |
825 | | -): |
| 825 | +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: |
826 | 826 | """ |
827 | 827 | Flash gated attention function for variable-length sequences that computes the attention output and optionally the logsumexp. |
828 | 828 |
|
|
0 commit comments