
Commit 554e7e0

Align docs with sparse attention rename
Updates API reference to reflect the flash_sparse_attn branding so installation instructions, imports, and backend descriptions stay consistent with the renamed package.
1 parent 7e3faab commit 554e7e0

File tree

1 file changed: +40, -40 lines changed


docs/api_reference.md

Lines changed: 40 additions & 40 deletions
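For downstream code, the rename amounts to swapping the old package and entry-point names for the new ones shown in the diff below. A minimal migration sketch (the try/except fallback to the pre-rename names is illustrative only, not something this commit adds):

```python
# Minimal migration sketch. The fallback import of the pre-rename package is
# illustrative; this commit only renames the documented entry points.
try:
    from flash_sparse_attn import flash_sparse_attn_func_auto  # renamed package
except ImportError:
    # Environments that still ship the pre-rename package
    from flash_dmattn import flash_dmattn_func_auto as flash_sparse_attn_func_auto

attn_func = flash_sparse_attn_func_auto()  # auto-selects cuda > triton > flex
```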
@@ -1,9 +1,9 @@
-# Flash Dynamic Mask Attention API Reference
+# Flash Sparse Attention API Reference
 

 ## Overview

-Flash Dynamic Mask Attention is a high-performance attention implementation that combines the memory efficiency of Flash Attention with the sparse compute benefits of Dynamic Mask Attention. It supports CUDA, Triton, and Flex Attention backends and dynamic masking for very long sequences.
+Flash Sparse Attention is a high-performance attention implementation that combines the memory efficiency of Flash Attention with the sparse compute benefits of Dynamic Mask Attention. It supports CUDA, Triton, and Flex Attention backends and dynamic masking for very long sequences.


 ## Table of Contents
@@ -12,37 +12,37 @@ Flash Dynamic Mask Attention is a high-performance attention implementation that
 2. [Quick Start](#quick-start)
 3. [Backend Selection and Comparison](#backend-selection-and-comparison)
 4. [API Reference](#api-reference)
-   - [CUDA Backend: flash_dmattn_func](#flash_dmattn_func-cuda-backend)
-   - [Triton Backend: triton_dmattn_func](#triton_dmattn_func-triton-backend)
-   - [Flex Backend: flex_dmattn_func](#flex_dmattn_func-flex-backend)
+   - [CUDA Backend: flash_sparse_attn_func](#flash_sparse_attn_func-cuda-backend)
+   - [Triton Backend: triton_sparse_attn_func](#triton_sparse_attn_func-triton-backend)
+   - [Flex Backend: flex_sparse_attn_func](#flex_sparse_attn_func-flex-backend)
 5. [Integrations](#integrations)
    - [Transformers Integration](#transformers-integration)
 6. [Common Issues and Solutions](#common-issues-and-solutions)


 ## Installation

-Please refer to the [README](https://github.com/SmallDoges/flash-dmattn/blob/main/README.md#install) for detailed installation instructions.
+Please refer to the [README](https://github.com/SmallDoges/flash-sparse-attention/blob/main/README.md#install) for detailed installation instructions.

 ```bash
 # With CUDA backend
-pip install flash-dmattn
+pip install flash-sparse-attn

 # Or install from source
 pip install -e .

 # Triton/Flex only
-FLASH_DMATTN_SKIP_CUDA_BUILD=1 pip install -e .
+FLASH_SPARSE_ATTENTION_SKIP_CUDA_BUILD=1 pip install -e .
 ```


 ## Quick Start

-Use `flash_dmattn_func_auto` to automatically select the best available backend without manual checking.
+Use `flash_sparse_attn_func_auto` to automatically select the best available backend without manual checking.

 ```python
 import torch
-from flash_dmattn import flash_dmattn_func_auto
+from flash_sparse_attn import flash_sparse_attn_func_auto

 # Prepare input tensors
 batch, seqlen, num_heads, head_dim = 2, 1024, 8, 64
@@ -51,27 +51,27 @@ k = torch.randn(batch, seqlen, num_heads, head_dim, dtype=torch.bfloat16, device
 v = torch.randn(batch, seqlen, num_heads, head_dim, dtype=torch.bfloat16, device='cuda')

 # Get attention function (auto-select backend, priority: cuda > triton > flex)
-attn_func = flash_dmattn_func_auto()
+attn_func = flash_sparse_attn_func_auto()

 # Compute attention
 output = attn_func(q, k, v, is_causal=True)
 print(f"Output shape: {output.shape}") # (2, 1024, 8, 64)

 # Or force a specific backend
-attn_func = flash_dmattn_func_auto(backend="cuda") # or "triton", "flex"
+attn_func = flash_sparse_attn_func_auto(backend="cuda") # or "triton", "flex"
 output = attn_func(q, k, v, is_causal=True)
 ```

 > [!NOTE]
-> `flash_dmattn_func_auto` returns a callable attention function, not the attention output.
+> `flash_sparse_attn_func_auto` returns a callable attention function, not the attention output.


 ## Backend Selection and Comparison

 ### Check Available Backends

 ```python
-from flash_dmattn import get_available_backends, CUDA_AVAILABLE, TRITON_AVAILABLE, FLEX_AVAILABLE
+from flash_sparse_attn import get_available_backends, CUDA_AVAILABLE, TRITON_AVAILABLE, FLEX_AVAILABLE

 # List all available backends
 print(get_available_backends()) # e.g., ["cuda", "triton", "flex"]
@@ -101,19 +101,19 @@ print(f"CUDA: {CUDA_AVAILABLE}, Triton: {TRITON_AVAILABLE}, Flex: {FLEX_AVAILABL

 ### When to Use Each Backend

-**CUDA Backend** ([details](#flash_dmattn_func-cuda-backend))
+**CUDA Backend** ([details](#flash_sparse_attn_func-cuda-backend))
 - ✅ Training workloads requiring full gradient support
 - ✅ Production inference requiring maximum performance
 - ✅ Applications needing deterministic behavior
 - ❌ Avoid: when custom CUDA extensions cannot be built

-**Triton Backend** ([details](#triton_dmattn_func-triton-backend))
+**Triton Backend** ([details](#triton_sparse_attn_func-triton-backend))
 - ✅ Training when CUDA extension unavailable
 - ✅ Development and prototyping
 - ✅ Cross-platform compatibility needs
 - ✅ Good balance of performance and ease of installation

-**Flex Backend** ([details](#flex_dmattn_func-flex-backend))
+**Flex Backend** ([details](#flex_sparse_attn_func-flex-backend))
 - ✅ Inference-only applications
 - ✅ Research with latest PyTorch features
 - ✅ Quick experimentation without custom builds
@@ -123,15 +123,15 @@ print(f"CUDA: {CUDA_AVAILABLE}, Triton: {TRITON_AVAILABLE}, Flex: {FLEX_AVAILABL
 ### Import Available Functions

 ```python
-from flash_dmattn import (
+from flash_sparse_attn import (
     # Automatic backend selection
     get_available_backends,
-    flash_dmattn_func_auto,
+    flash_sparse_attn_func_auto,

     # Backend-specific functions
-    flash_dmattn_func, # CUDA backend
-    triton_dmattn_func, # Triton backend
-    flex_dmattn_func, # Flex backend
+    flash_sparse_attn_func, # CUDA backend
+    triton_sparse_attn_func, # Triton backend
+    flex_sparse_attn_func, # Flex backend

     # Backend availability flags
     CUDA_AVAILABLE,
@@ -140,20 +140,20 @@ from flash_dmattn import (
 )

 # Transformers integration
-from flash_dmattn.integrations.flash_dynamic_mask_attention import (
-    flash_dynamic_mask_attention_forward
+from flash_sparse_attn.integrations.flash_sparse_attention import (
+    flash_sparse_attention_forward
 )
 ```


 ## API Reference

-### flash_dmattn_func (CUDA backend)
+### flash_sparse_attn_func (CUDA backend)

 Main attention function. Supports multi-head and grouped-query attention (when the number of KV heads is smaller than the number of Q heads). Requires the CUDA extension to be built and available.

 ```python
-def flash_dmattn_func(
+def flash_sparse_attn_func(
     query: torch.Tensor, # (batch, seqlen_q, num_heads, head_dim)
     key: torch.Tensor, # (batch, seqlen_k, num_kv_heads, head_dim)
     value: torch.Tensor, # (batch, seqlen_k, num_kv_heads, head_dim)
@@ -182,12 +182,12 @@ def flash_dmattn_func(

 - output: (B, Q, H, D)

-### triton_dmattn_func (Triton backend)
+### triton_sparse_attn_func (Triton backend)

 Triton-based implementation that provides good performance without requiring custom CUDA kernels.

 ```python
-def triton_dmattn_func(
+def triton_sparse_attn_func(
     query: torch.Tensor, # (batch, seqlen_q, num_heads, head_dim)
     key: torch.Tensor, # (batch, seqlen_k, num_heads, head_dim)
     value: torch.Tensor, # (batch, seqlen_k, num_heads, head_dim)
@@ -198,12 +198,12 @@ def triton_dmattn_func(
 ) -> torch.Tensor
 ```

-### flex_dmattn_func (Flex Attention backend)
+### flex_sparse_attn_func (Flex Attention backend)

 Flex Attention-based implementation using PyTorch's native flex attention with dynamic masking support.

 ```python
-def flex_dmattn_func(
+def flex_sparse_attn_func(
     query: torch.Tensor, # (batch, seqlen_q, num_heads, head_dim)
     key: torch.Tensor, # (batch, seqlen_k, num_heads, head_dim)
     value: torch.Tensor, # (batch, seqlen_k, num_heads, head_dim)
@@ -221,13 +221,13 @@ def flex_dmattn_func(

 Integration function for HuggingFace Transformers models that provides seamless flash dynamic mask attention support.

-#### flash_dynamic_mask_attention_forward
+#### flash_sparse_attention_forward


 ```python
-from flash_dmattn.integrations.flash_dynamic_mask_attention import flash_dynamic_mask_attention_forward
+from flash_sparse_attn.integrations.flash_sparse_attention import flash_sparse_attention_forward

-def flash_dynamic_mask_attention_forward(
+def flash_sparse_attention_forward(
     module: torch.nn.Module, # The attention module
     query: torch.Tensor, # (batch_size, num_heads, query_len, head_dim)
     key: torch.Tensor, # (batch_size, num_kv_heads, key_len, head_dim)
@@ -254,7 +254,7 @@ def flash_dynamic_mask_attention_forward(
 - is_causal: Whether to apply causal mask
 - window_size: Size of window to keep
 - layer_idx: Layer index for logging
-- implementation: Implementation to use ("flash_dmattn" or None)
+- implementation: Implementation to use ("flash_sparse_attn" or None)

 #### Returns

@@ -268,7 +268,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from typing import Optional, Callable, tuple
 from transformers.cache_utils import Cache
-from flash_dmattn.integrations.flash_dynamic_mask_attention import flash_dynamic_mask_attention_forward
+from flash_sparse_attn.integrations.flash_sparse_attention import flash_sparse_attention_forward

 class DynamicMaskAttention(nn.Module):
     def __init__(self, config, layer_idx: Optional[int] = None):
@@ -332,7 +332,7 @@ class DynamicMaskAttention(nn.Module):
         attn_bias = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2).to(hidden_states.dtype)

         # Choose attention implementation
-        attention_interface: Callable = flash_dynamic_mask_attention_forward
+        attention_interface: Callable = flash_sparse_attention_forward

         attn_output, attn_weights = attention_interface(
             self,
@@ -362,7 +362,7 @@ This example shows:

 ```python
 try:
-    from flash_dmattn import flash_dmattn_func_auto, get_available_backends
+    from flash_sparse_attn import flash_sparse_attn_func_auto, get_available_backends
     print("✅ Imported successfully", get_available_backends())
 except ImportError as e:
     print(f"❌ Import failed: {e}")
@@ -385,10 +385,10 @@ except ImportError as e:

 ```python
 import torch
-from flash_dmattn import flash_dmattn_func_auto
+from flash_sparse_attn import flash_sparse_attn_func_auto

 torch.autograd.set_detect_anomaly(True)
-attn = flash_dmattn_func_auto()
+attn = flash_sparse_attn_func_auto()
 output = attn(q, k, v, attn_mask=attn_mask, attn_bias=attn_bias, is_causal=True)
 if torch.isnan(output).any():
     print("⚠️ NaN detected in attention output")
@@ -404,7 +404,7 @@ def print_memory_stats():
     print(f"max alloc: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")

 print_memory_stats()
-attn = flash_dmattn_func_auto()
+attn = flash_sparse_attn_func_auto()
 output = attn(q, k, v)
 print_memory_stats()
 ```
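The CUDA backend section above notes grouped-query attention support (fewer KV heads than Q heads) but shows no worked call. A minimal sketch following the Quick Start shapes; the 8-query-head / 2-KV-head split is illustrative:

```python
import torch
from flash_sparse_attn import flash_sparse_attn_func_auto

# Grouped-query attention: KV heads < Q heads, as described for the CUDA backend.
# Head counts are illustrative; shapes and dtype follow the Quick Start above.
batch, seqlen, num_heads, num_kv_heads, head_dim = 2, 1024, 8, 2, 64
q = torch.randn(batch, seqlen, num_heads, head_dim, dtype=torch.bfloat16, device="cuda")
k = torch.randn(batch, seqlen, num_kv_heads, head_dim, dtype=torch.bfloat16, device="cuda")
v = torch.randn(batch, seqlen, num_kv_heads, head_dim, dtype=torch.bfloat16, device="cuda")

attn_func = flash_sparse_attn_func_auto(backend="cuda")  # force the CUDA backend
output = attn_func(q, k, v, is_causal=True)
print(output.shape)  # expected (2, 1024, 8, 64), i.e. (B, Q, H, D) per the reference above
```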
