Add return type annotations for attention functions #260
Conversation
Pull request overview
Adds explicit return type annotations to the public Triton attention wrapper functions in interface.py to improve readability and enable static type checking when return_lse switches the return value between out and (out, lse).
Changes:
- Import Union from typing for the return annotations.
- Annotate return types for the dense/sparse/gated attention wrappers (fixed-length and varlen variants) as Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]].
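The annotated signature pattern looks roughly like the following sketch. The function name and body are illustrative stand-ins (the real wrappers in interface.py dispatch to Triton kernels); only the annotation pattern is the point:

```python
from typing import Tuple, Union

import torch


def dense_attn_func(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    return_lse: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """Return out, or (out, lse) when return_lse is True."""
    # Plain softmax attention as a stand-in for the Triton kernel.
    scores = q @ k.transpose(-2, -1)
    out = torch.softmax(scores, dim=-1) @ v
    if return_lse:
        return out, torch.logsumexp(scores, dim=-1)
    return out
```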
```diff
     seqused_k: Optional[torch.Tensor] = None,
     return_lse: bool = False,
-):
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
```
Docstring return-shape for lse in this varlen API looks incorrect: the underlying varlen forward allocates/returns lse as (num_heads_q, total_seqlen_q) (heads-first), not (total_seqlen_q, num_heads_q). Please update the :returns: section to match the actual tensor layout so users don’t mis-handle dimensions.
```diff
     seqused_k: Optional[torch.Tensor] = None,
     return_lse: bool = False,
-):
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
```
Docstring return-shape for lse in this varlen API looks incorrect: the underlying varlen forward allocates/returns lse as (num_heads_q, total_seqlen_q) (heads-first), not (total_seqlen_q, num_heads_q). Please update the :returns: section to match the actual tensor layout so users don’t mis-handle dimensions.
```diff
     seqused_k: Optional[torch.Tensor] = None,
     return_lse: bool = False,
-):
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
```
Docstring return-shape for lse in this varlen API looks incorrect: the underlying varlen forward allocates/returns lse as (num_heads_q, total_seqlen_q) (heads-first), not (total_seqlen_q, num_heads_q). Please update the :returns: section to match the actual tensor layout so users don’t mis-handle dimensions.
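Assuming the heads-first layout these comments describe, a caller indexes lse with heads on the leading dimension. A minimal sketch (shapes are illustrative):

```python
import torch

# Hypothetical sizes for illustration: total_seqlen_q packed query tokens
# across the batch, num_heads_q query heads. Per the comments above, the
# varlen forward returns lse heads-first: (num_heads_q, total_seqlen_q).
num_heads_q, total_seqlen_q = 8, 128
lse = torch.randn(num_heads_q, total_seqlen_q)

per_token = lse[:, 10]  # all heads' logsumexp for packed token 10
per_head = lse[3]       # head 3 across all packed tokens
assert per_token.shape == (num_heads_q,)
assert per_head.shape == (total_seqlen_q,)
```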
```diff
     window_size: Tuple[Optional[int], Optional[int]] = (None, None),
     return_lse: bool = False,
-):
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
```
The new return type annotation is a broad Union, so type checkers can’t infer the return type from return_lse (e.g., callers still see Tensor | tuple[Tensor, Tensor] even in if return_lse: branches). If the goal is improved static checking, consider adding @overload signatures using Literal[False]/Literal[True] for return_lse (and applying the same pattern to the other *_attn*_func wrappers).
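The suggested overload pattern can be sketched as follows. The wrapper name and body are hypothetical stand-ins for the real functions in interface.py; the point is that Literal[False]/Literal[True] overloads let a type checker narrow the return type at each call site instead of reporting the broad Union:

```python
from typing import Literal, Tuple, Union, overload

import torch


@overload
def attn_func(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
              return_lse: Literal[False] = ...) -> torch.Tensor: ...
@overload
def attn_func(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
              return_lse: Literal[True]) -> Tuple[torch.Tensor, torch.Tensor]: ...
def attn_func(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, return_lse: bool = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    # Plain softmax attention as a stand-in for the real Triton kernel.
    scores = q @ k.transpose(-2, -1)
    out = torch.softmax(scores, dim=-1) @ v
    if return_lse:
        return out, torch.logsumexp(scores, dim=-1)
    return out
```

With these overloads, `attn_func(q, k, v)` is typed as `Tensor` and `attn_func(q, k, v, return_lse=True)` as `Tuple[Tensor, Tensor]`, so callers no longer need an `isinstance` check or cast to satisfy the checker.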
Enhance type safety by adding return type annotations to various attention functions. This change improves code clarity and helps with static type checking. No functional changes were made.