Update docstrings in attention functions for consistency #259
Changes from all commits
@@ -597,8 +597,7 @@ def flash_dense_attn_func(
     :param window_size: Optional tuple (window_size_q, window_size_k) for local attention. If None, no local masking is applied.
     :param return_lse: Whether to return the logsumexp tensor for numerical stability analysis. If True, returns a tuple (out, lse). If False, returns only out.

-    :return out: Attention output tensor of shape [batch_size, seqlen_q, num_heads, head_dim].
-    :return lse: Logsumexp tensor of shape [batch_size, num_heads, seqlen_q] if return_lse is True. Otherwise, not returned.
+    :returns: If return_lse is False, returns out with shape [batch_size, seqlen_q, num_heads, head_dim]. If return_lse is True, returns a tuple (out, lse), where lse has shape [batch_size, num_heads, seqlen_q].
     """
     return FlashDenseAttnFunc.apply(
         query,
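As a sanity check on the shapes documented above, here is a hedged NumPy sketch of dense attention with a logsumexp output. This is a naive reference, not the library's kernel; the function name and layout conventions are illustrative, chosen to match the docstring's [batch_size, seqlen_q, num_heads, head_dim] / [batch_size, num_heads, seqlen_q] shapes.

```python
import numpy as np

def naive_dense_attn(q, k, v):
    """Naive dense-attention reference (illustrative, not the CUDA kernel).

    q: [batch, seqlen_q, heads, dim], k/v: [batch, seqlen_k, heads, dim].
    Returns out with shape [batch, seqlen_q, heads, dim] and
    lse with shape [batch, heads, seqlen_q], matching the docstring.
    """
    scale = q.shape[-1] ** -0.5
    # scores: [batch, heads, seqlen_q, seqlen_k]
    scores = np.einsum("bqhd,bkhd->bhqk", q, k) * scale
    # Stable logsumexp over the key axis -> [batch, heads, seqlen_q]
    m = scores.max(axis=-1)
    lse = m + np.log(np.exp(scores - m[..., None]).sum(axis=-1))
    # softmax via the lse, then weighted sum over keys
    p = np.exp(scores - lse[..., None])
    out = np.einsum("bhqk,bkhd->bqhd", p, v)
    return out, lse

rng = np.random.default_rng(0)
q = rng.standard_normal((2, 5, 3, 8))
k = rng.standard_normal((2, 7, 3, 8))
v = rng.standard_normal((2, 7, 3, 8))
out, lse = naive_dense_attn(q, k, v)
print(out.shape)  # (2, 5, 3, 8)
print(lse.shape)  # (2, 3, 5)
```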
@@ -643,8 +642,7 @@ def flash_dense_attn_varlen_func(
     :param seqused_k: Optional tensor of shape [total_seqlen_k] indicating the actual sequence lengths for keys/values. If provided, overrides cu_seqlens_k for masking.
     :param return_lse: Whether to return the logsumexp tensor for numerical stability analysis. If True, returns a tuple (out, lse). If False, returns only out.

-    :return out: Attention output tensor of shape [total_seqlen_q, num_heads_q, head_dim].
-    :return lse: Logsumexp tensor of shape [total_seqlen_q, num_heads_q] if return_lse is True. Otherwise, not returned.
+    :returns: If return_lse is False, returns out with shape [total_seqlen_q, num_heads_q, head_dim]. If return_lse is True, returns a tuple (out, lse), where lse has shape [total_seqlen_q, num_heads_q].
Copilot AI (Mar 23, 2026):

In varlen mode, the docstring states lse has shape [total_seqlen_q, num_heads_q], but _flash_sparse_attn_varlen_base_forward returns lse with shape (num_heads_q, total_seqlen_q). Please correct the documented shape to match the actual output.

Suggested change:

-    :returns: If return_lse is False, returns out with shape [total_seqlen_q, num_heads_q, head_dim]. If return_lse is True, returns a tuple (out, lse), where lse has shape [total_seqlen_q, num_heads_q].
+    :returns: If return_lse is False, returns out with shape [total_seqlen_q, num_heads_q, head_dim]. If return_lse is True, returns a tuple (out, lse), where lse has shape [num_heads_q, total_seqlen_q].
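To make the axis-order point concrete, here is a hedged sketch of a packed-varlen logsumexp laid out heads-first, i.e. (num_heads_q, total_seqlen_q), the layout the review says the varlen kernels actually produce. This is a naive reference under that assumption, not _flash_sparse_attn_varlen_base_forward itself, and the function name is illustrative.

```python
import numpy as np

def varlen_lse(q, k, cu_seqlens_q, cu_seqlens_k):
    """Naive packed-varlen logsumexp reference (illustrative).

    q: [total_seqlen_q, num_heads, dim], k: [total_seqlen_k, num_heads, dim].
    cu_seqlens_*: cumulative sequence boundaries, e.g. [0, 5, 8].
    Returns lse laid out heads-first: (num_heads, total_seqlen_q).
    """
    num_heads, total_q = q.shape[1], q.shape[0]
    scale = q.shape[-1] ** -0.5
    lse = np.empty((num_heads, total_q))
    for b in range(len(cu_seqlens_q) - 1):
        qs, qe = cu_seqlens_q[b], cu_seqlens_q[b + 1]
        ks, ke = cu_seqlens_k[b], cu_seqlens_k[b + 1]
        # scores for one packed sequence: [heads, seqlen_q, seqlen_k]
        s = np.einsum("qhd,khd->hqk", q[qs:qe], k[ks:ke]) * scale
        m = s.max(axis=-1)
        # stable logsumexp over this sequence's keys only
        lse[:, qs:qe] = m + np.log(np.exp(s - m[..., None]).sum(axis=-1))
    return lse

rng = np.random.default_rng(0)
q = rng.standard_normal((5 + 3, 2, 4))   # two sequences packed: lengths 5 and 3
k = rng.standard_normal((6 + 4, 2, 4))   # key lengths 6 and 4
lse = varlen_lse(q, k, [0, 5, 8], [0, 6, 10])
print(lse.shape)  # (2, 8) == (num_heads_q, total_seqlen_q)
```

Indexing a token's lse therefore requires lse[head, token], not lse[token, head], which is exactly the axis swap the review warns consumers about.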
Copilot AI (Mar 23, 2026):

The varlen gated attention path returns lse with shape (num_heads_q, total_seqlen_q) (see _flash_gated_attn_varlen_base_forward), but the docstring documents it as [total_seqlen_q, num_heads_q]. Please update the return description so consumers don't swap axes.

Suggested change:

-    :returns: If return_lse is False, returns out with shape [total_seqlen_q, num_heads_q, head_dim]. If return_lse is True, returns a tuple (out, lse), where lse has shape [total_seqlen_q, num_heads_q].
+    :returns: If return_lse is False, returns out with shape [total_seqlen_q, num_heads_q, head_dim]. If return_lse is True, returns a tuple (out, lse), where lse has shape [num_heads_q, total_seqlen_q].
Reviewer comment:

This file (and most of the repo) uses reST field lists like :return ...: (e.g., flash_sparse_attn/ops/triton/activations.py:33-38), but these docstrings now introduce :returns:. Unless there's a doc build requirement for :returns:, consider keeping :return: / :return <name>: to stay consistent with the existing documentation style across the codebase.
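For reference, a minimal docstring in the repo's existing :param / :return <name>: field-list style might look like the following. The function itself is illustrative, not taken from the codebase; only the documentation convention is the point.

```python
import numpy as np

def softplus(x, threshold=20.0):
    """Numerically stable softplus, documented in reST field-list style.

    :param x: Input array of any shape.
    :param threshold: Inputs above this value are passed through unchanged
        to avoid overflow in exp.
    :return out: Array of the same shape as x, log(1 + exp(x)) elementwise.
    """
    # Clamp before exponentiating so exp never overflows; large inputs
    # take the identity branch of the where().
    return np.where(x > threshold, x, np.log1p(np.exp(np.minimum(x, threshold))))

print(softplus(np.array([0.0]))[0])    # ~0.6931 (log 2)
print(softplus(np.array([100.0]))[0])  # 100.0
```

Sphinx treats :return: and :returns: as synonyms when rendering, so the choice is purely a consistency question, as the reviewer notes.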