Merged
18 changes: 6 additions & 12 deletions flash_sparse_attn/ops/triton/interface.py
@@ -597,8 +597,7 @@ def flash_dense_attn_func(
:param window_size: Optional tuple (window_size_q, window_size_k) for local attention. If None, no local masking is applied.
:param return_lse: Whether to return the logsumexp tensor for numerical stability analysis. If True, returns a tuple (out, lse). If False, returns only out.

-    :return out: Attention output tensor of shape [batch_size, seqlen_q, num_heads, head_dim].
-    :return lse: Logsumexp tensor of shape [batch_size, num_heads, seqlen_q] if return_lse is True. Otherwise, not returned.
+    :returns: If return_lse is False, returns out with shape [batch_size, seqlen_q, num_heads, head_dim]. If return_lse is True, returns a tuple (out, lse), where lse has shape [batch_size, num_heads, seqlen_q].
Copilot AI Mar 23, 2026

This file (and most of the repo) uses reST field lists like :return ...: (e.g., flash_sparse_attn/ops/triton/activations.py:33-38), but these docstrings now introduce :returns:. Unless there’s a doc build requirement for :returns:, consider keeping :return:/:return <name>: to stay consistent with existing documentation style across the codebase.

Suggested change
-    :returns: If return_lse is False, returns out with shape [batch_size, seqlen_q, num_heads, head_dim]. If return_lse is True, returns a tuple (out, lse), where lse has shape [batch_size, num_heads, seqlen_q].
+    :return: If return_lse is False, returns out with shape [batch_size, seqlen_q, num_heads, head_dim]. If return_lse is True, returns a tuple (out, lse), where lse has shape [batch_size, num_heads, seqlen_q].
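For context, Sphinx's Python domain accepts both `:return:` and `:returns:` as aliases and renders them identically, so the reviewer's point is purely about consistency with the rest of the repo. A minimal sketch of the `:return:` field-list style being suggested (the function, parameters, and return values here are placeholders, not the real API):

```python
def toy_attn(query, return_lse=False):
    """Toy stand-in illustrating the repo's reST field-list style.

    :param query: Input value (placeholder for the real query tensor).
    :param return_lse: Whether to also return a logsumexp placeholder.
    :return: query if return_lse is False, else the tuple (query, lse).
    """
    lse = 0.0  # placeholder for the real logsumexp tensor
    return (query, lse) if return_lse else query
```

Since both spellings build the same output, matching the existing `:return:` convention mainly helps grep-ability across the codebase.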

"""
return FlashDenseAttnFunc.apply(
query,
@@ -643,8 +642,7 @@ def flash_dense_attn_varlen_func(
:param seqused_k: Optional tensor of shape [total_seqlen_k] indicating the actual sequence lengths for keys/values. If provided, overrides cu_seqlens_k for masking.
:param return_lse: Whether to return the logsumexp tensor for numerical stability analysis. If True, returns a tuple (out, lse). If False, returns only out.

-    :return out: Attention output tensor of shape [total_seqlen_q, num_heads_q, head_dim].
-    :return lse: Logsumexp tensor of shape [total_seqlen_q, num_heads_q] if return_lse is True. Otherwise, not returned.
+    :returns: If return_lse is False, returns out with shape [total_seqlen_q, num_heads_q, head_dim]. If return_lse is True, returns a tuple (out, lse), where lse has shape [total_seqlen_q, num_heads_q].

Copilot AI Mar 23, 2026


In varlen mode, the returned lse shape is documented as [total_seqlen_q, num_heads_q], but _flash_dense_attn_varlen_base_forward allocates/returns it as (num_heads_q, total_seqlen_q). Please update the docstring so the axes match the actual return value (or transpose lse before returning, but that would be an API change).

Suggested change
-    :returns: If return_lse is False, returns out with shape [total_seqlen_q, num_heads_q, head_dim]. If return_lse is True, returns a tuple (out, lse), where lse has shape [total_seqlen_q, num_heads_q].
+    :returns: If return_lse is False, returns out with shape [total_seqlen_q, num_heads_q, head_dim]. If return_lse is True, returns a tuple (out, lse), where lse has shape [num_heads_q, total_seqlen_q].
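The axis swap the reviewer flags is easy to hit downstream: a caller reading the old docstring would index lse token-major while the varlen kernels return it heads-major. A hedged sketch of the fix-up on the caller's side, using plain nested lists as stand-ins for tensors (`transpose_lse` is a hypothetical helper, not part of the library):

```python
def transpose_lse(lse):
    """Swap a heads-major (num_heads_q, total_seqlen_q) nested list
    into token-major (total_seqlen_q, num_heads_q) order."""
    return [list(row) for row in zip(*lse)]

# lse as the varlen kernels actually return it: one row per query head.
num_heads_q, total_seqlen_q = 2, 3
lse = [[0.1, 0.2, 0.3],
       [0.4, 0.5, 0.6]]

lse_token_major = transpose_lse(lse)
assert len(lse_token_major) == total_seqlen_q
assert len(lse_token_major[0]) == num_heads_q
```

With real tensors the same reshuffle would be a transpose of the first two axes; the point is only that consumers must know which layout the function actually returns, which is why the docstring fix matters.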

"""
return FlashDenseAttnVarlenFunc.apply(
query,
Expand Down Expand Up @@ -685,8 +683,7 @@ def flash_sparse_attn_func(
:param window_size: Optional tuple (window_size_q, window_size_k) for local attention. If None, no local masking is applied.
:param return_lse: Whether to return the logsumexp tensor for numerical stability analysis. If True, returns a tuple (out, lse). If False, returns only out.

-    :return out: Attention output tensor of shape [batch_size, seqlen_q, num_heads, head_dim].
-    :return lse: Logsumexp tensor of shape [batch_size, num_heads, seqlen_q] if return_lse is True. Otherwise, not returned.
+    :returns: If return_lse is False, returns out with shape [batch_size, seqlen_q, num_heads, head_dim]. If return_lse is True, returns a tuple (out, lse), where lse has shape [batch_size, num_heads, seqlen_q].
"""
return FlashSparseAttnFunc.apply(
query,
@@ -734,8 +731,7 @@ def flash_sparse_attn_varlen_func(
:param seqused_k: Optional tensor of shape [total_seqlen_k] indicating the actual sequence lengths for keys/values. If provided, overrides cu_seqlens_k for masking.
:param return_lse: Whether to return the logsumexp tensor for numerical stability analysis. If True, returns a tuple (out, lse). If False, returns only out.

-    :return out: Attention output tensor of shape [total_seqlen_q, num_heads_q, head_dim].
-    :return lse: Logsumexp tensor of shape [total_seqlen_q, num_heads_q] if return_lse is True. Otherwise, not returned.
+    :returns: If return_lse is False, returns out with shape [total_seqlen_q, num_heads_q, head_dim]. If return_lse is True, returns a tuple (out, lse), where lse has shape [total_seqlen_q, num_heads_q].

Copilot AI Mar 23, 2026


In varlen mode, the docstring states lse has shape [total_seqlen_q, num_heads_q], but _flash_sparse_attn_varlen_base_forward returns lse with shape (num_heads_q, total_seqlen_q). Please correct the documented shape to match the actual output.

Suggested change
-    :returns: If return_lse is False, returns out with shape [total_seqlen_q, num_heads_q, head_dim]. If return_lse is True, returns a tuple (out, lse), where lse has shape [total_seqlen_q, num_heads_q].
+    :returns: If return_lse is False, returns out with shape [total_seqlen_q, num_heads_q, head_dim]. If return_lse is True, returns a tuple (out, lse), where lse has shape [num_heads_q, total_seqlen_q].

"""
return FlashSparseAttnVarlenFunc.apply(
query,
@@ -787,8 +783,7 @@ def flash_gated_attn_func(
:param window_size: Optional tuple (window_size_q, window_size_k) for local attention. If None, no local masking is applied.
:param return_lse: Whether to return the logsumexp tensor for numerical stability analysis. If True, returns a tuple (out, lse). If False, returns only out.

-    :return out: Attention output tensor of shape [batch_size, seqlen_q, num_heads, head_dim].
-    :return lse: Logsumexp tensor of shape [batch_size, num_heads, seqlen_q] if return_lse is True. Otherwise, not returned.
+    :returns: If return_lse is False, returns out with shape [batch_size, seqlen_q, num_heads, head_dim]. If return_lse is True, returns a tuple (out, lse), where lse has shape [batch_size, num_heads, seqlen_q].
"""
return FlashGatedAttnFunc.apply(
query,
@@ -851,8 +846,7 @@ def flash_gated_attn_varlen_func(
:param seqused_k: Optional tensor of shape [total_seqlen_k] indicating the actual sequence lengths for keys/values. If provided, overrides cu_seqlens_k for masking.
:param return_lse: Whether to return the logsumexp tensor for numerical stability analysis. If True, returns a tuple (out, lse). If False, returns only out.

-    :return out: Attention output tensor of shape [total_seqlen_q, num_heads_q, head_dim].
-    :return lse: Logsumexp tensor of shape [total_seqlen_q, num_heads_q] if return_lse is True. Otherwise, not returned.
+    :returns: If return_lse is False, returns out with shape [total_seqlen_q, num_heads_q, head_dim]. If return_lse is True, returns a tuple (out, lse), where lse has shape [total_seqlen_q, num_heads_q].

Copilot AI Mar 23, 2026


The varlen gated attention path returns lse with shape (num_heads_q, total_seqlen_q) (see _flash_gated_attn_varlen_base_forward), but the docstring documents it as [total_seqlen_q, num_heads_q]. Please update the return description so consumers don’t swap axes.

Suggested change
-    :returns: If return_lse is False, returns out with shape [total_seqlen_q, num_heads_q, head_dim]. If return_lse is True, returns a tuple (out, lse), where lse has shape [total_seqlen_q, num_heads_q].
+    :returns: If return_lse is False, returns out with shape [total_seqlen_q, num_heads_q, head_dim]. If return_lse is True, returns a tuple (out, lse), where lse has shape [num_heads_q, total_seqlen_q].

"""
return FlashGatedAttnVarlenFunc.apply(
query,