Update docstrings in attention functions for consistency #259
LoserCheems merged 1 commit into main from
Conversation
Pull request overview
This PR updates the docstrings of the Triton attention interface functions to make the return-value descriptions more consistent, especially around the optional return_lse behavior.
Changes:
- Replaced separate `:return out:`/`:return lse:` entries with a single conditional `:returns:` description for multiple attention wrapper functions.
- Standardized the wording describing the `return_lse` tuple return behavior across the dense/sparse/gated variants.
```diff
- :return out: Attention output tensor of shape [total_seqlen_q, num_heads_q, head_dim].
- :return lse: Logsumexp tensor of shape [total_seqlen_q, num_heads_q] if return_lse is True. Otherwise, not returned.
+ :returns: If return_lse is False, returns out with shape [total_seqlen_q, num_heads_q, head_dim]. If return_lse is True, returns a tuple (out, lse), where lse has shape [total_seqlen_q, num_heads_q].
```
In varlen mode, the returned lse shape is documented as [total_seqlen_q, num_heads_q], but _flash_dense_attn_varlen_base_forward allocates/returns it as (num_heads_q, total_seqlen_q). Please update the docstring so the axes match the actual return value (or transpose lse before returning, but that would be an API change).
Suggested change:

```diff
- :returns: If return_lse is False, returns out with shape [total_seqlen_q, num_heads_q, head_dim]. If return_lse is True, returns a tuple (out, lse), where lse has shape [total_seqlen_q, num_heads_q].
+ :returns: If return_lse is False, returns out with shape [total_seqlen_q, num_heads_q, head_dim]. If return_lse is True, returns a tuple (out, lse), where lse has shape [num_heads_q, total_seqlen_q].
```
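The axis order matters to callers, since a consumer that trusts the docstring would index lse the wrong way around. A minimal NumPy sketch of the documented varlen contract, using a hypothetical stand-in function (the name, the zero-filled tensors, and the example sizes are illustrative, not the library's actual API), including the `(num_heads_q, total_seqlen_q)` lse layout the reviewer describes:

```python
import numpy as np

def fake_varlen_attn(total_seqlen_q, num_heads_q, head_dim, return_lse=False):
    """Hypothetical stand-in mimicking the varlen return contract under review."""
    out = np.zeros((total_seqlen_q, num_heads_q, head_dim), dtype=np.float32)
    # In varlen mode the kernel allocates lse as (num_heads_q, total_seqlen_q),
    # not (total_seqlen_q, num_heads_q) as the original docstring claimed.
    lse = np.zeros((num_heads_q, total_seqlen_q), dtype=np.float32)
    return (out, lse) if return_lse else out

# return_lse=False: a single tensor comes back.
out = fake_varlen_attn(128, 8, 64)

# return_lse=True: a tuple comes back, with lse in head-major layout.
out, lse = fake_varlen_attn(128, 8, 64, return_lse=True)

# Callers wanting a [total_seqlen_q, num_heads_q] view must transpose themselves:
lse_tq = lse.T
```

This also illustrates why fixing the docstring (rather than transposing before returning) is the non-breaking option: existing callers already depend on the head-major layout.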
```diff
- :return out: Attention output tensor of shape [total_seqlen_q, num_heads_q, head_dim].
- :return lse: Logsumexp tensor of shape [total_seqlen_q, num_heads_q] if return_lse is True. Otherwise, not returned.
+ :returns: If return_lse is False, returns out with shape [total_seqlen_q, num_heads_q, head_dim]. If return_lse is True, returns a tuple (out, lse), where lse has shape [total_seqlen_q, num_heads_q].
```
In varlen mode, the docstring states lse has shape [total_seqlen_q, num_heads_q], but _flash_sparse_attn_varlen_base_forward returns lse with shape (num_heads_q, total_seqlen_q). Please correct the documented shape to match the actual output.
Suggested change:

```diff
- :returns: If return_lse is False, returns out with shape [total_seqlen_q, num_heads_q, head_dim]. If return_lse is True, returns a tuple (out, lse), where lse has shape [total_seqlen_q, num_heads_q].
+ :returns: If return_lse is False, returns out with shape [total_seqlen_q, num_heads_q, head_dim]. If return_lse is True, returns a tuple (out, lse), where lse has shape [num_heads_q, total_seqlen_q].
```
```diff
- :return out: Attention output tensor of shape [total_seqlen_q, num_heads_q, head_dim].
- :return lse: Logsumexp tensor of shape [total_seqlen_q, num_heads_q] if return_lse is True. Otherwise, not returned.
+ :returns: If return_lse is False, returns out with shape [total_seqlen_q, num_heads_q, head_dim]. If return_lse is True, returns a tuple (out, lse), where lse has shape [total_seqlen_q, num_heads_q].
```
The varlen gated attention path returns lse with shape (num_heads_q, total_seqlen_q) (see _flash_gated_attn_varlen_base_forward), but the docstring documents it as [total_seqlen_q, num_heads_q]. Please update the return description so consumers don’t swap axes.
Suggested change:

```diff
- :returns: If return_lse is False, returns out with shape [total_seqlen_q, num_heads_q, head_dim]. If return_lse is True, returns a tuple (out, lse), where lse has shape [total_seqlen_q, num_heads_q].
+ :returns: If return_lse is False, returns out with shape [total_seqlen_q, num_heads_q, head_dim]. If return_lse is True, returns a tuple (out, lse), where lse has shape [num_heads_q, total_seqlen_q].
```
```diff
- :return out: Attention output tensor of shape [batch_size, seqlen_q, num_heads, head_dim].
- :return lse: Logsumexp tensor of shape [batch_size, num_heads, seqlen_q] if return_lse is True. Otherwise, not returned.
+ :returns: If return_lse is False, returns out with shape [batch_size, seqlen_q, num_heads, head_dim]. If return_lse is True, returns a tuple (out, lse), where lse has shape [batch_size, num_heads, seqlen_q].
```
This file (and most of the repo) uses reST field lists like `:return ...:` (e.g., flash_sparse_attn/ops/triton/activations.py:33-38), but these docstrings now introduce `:returns:`. Unless there is a doc build requirement for `:returns:`, consider keeping `:return:`/`:return <name>:` to stay consistent with the existing documentation style across the codebase.
Suggested change:

```diff
- :returns: If return_lse is False, returns out with shape [batch_size, seqlen_q, num_heads, head_dim]. If return_lse is True, returns a tuple (out, lse), where lse has shape [batch_size, num_heads, seqlen_q].
+ :return: If return_lse is False, returns out with shape [batch_size, seqlen_q, num_heads, head_dim]. If return_lse is True, returns a tuple (out, lse), where lse has shape [batch_size, num_heads, seqlen_q].
```