# Flash Sparse Attention API Reference


## Overview

Flash Sparse Attention is a high-performance attention implementation that combines the memory efficiency of Flash Attention with the sparse compute benefits of Dynamic Mask Attention. It supports CUDA, Triton, and Flex Attention backends and dynamic masking for very long sequences.


## Table of Contents

1. [Installation](#installation)
2. [Quick Start](#quick-start)
3. [Backend Selection and Comparison](#backend-selection-and-comparison)
4. [API Reference](#api-reference)
   - [CUDA Backend: flash_sparse_attn_func](#flash_sparse_attn_func-cuda-backend)
   - [Triton Backend: triton_sparse_attn_func](#triton_sparse_attn_func-triton-backend)
   - [Flex Backend: flex_sparse_attn_func](#flex_sparse_attn_func-flex-attention-backend)
5. [Integrations](#integrations)
   - [Transformers Integration](#transformers-integration)
6. [Common Issues and Solutions](#common-issues-and-solutions)


## Installation

Please refer to the [README](https://github.com/SmallDoges/flash-sparse-attention/blob/main/README.md#install) for detailed installation instructions.

```bash
# With CUDA backend
pip install flash-sparse-attn

# Or install from source
pip install -e .

# Triton/Flex only
FLASH_SPARSE_ATTENTION_SKIP_CUDA_BUILD=1 pip install -e .
```


## Quick Start

Use `flash_sparse_attn_func_auto` to automatically select the best available backend without manual checking.

```python
import torch
from flash_sparse_attn import flash_sparse_attn_func_auto

# Prepare input tensors
batch, seqlen, num_heads, head_dim = 2, 1024, 8, 64
q = torch.randn(batch, seqlen, num_heads, head_dim, dtype=torch.bfloat16, device='cuda')
k = torch.randn(batch, seqlen, num_heads, head_dim, dtype=torch.bfloat16, device='cuda')
v = torch.randn(batch, seqlen, num_heads, head_dim, dtype=torch.bfloat16, device='cuda')

# Get attention function (auto-select backend, priority: cuda > triton > flex)
attn_func = flash_sparse_attn_func_auto()

# Compute attention
output = attn_func(q, k, v, is_causal=True)
print(f"Output shape: {output.shape}")  # (2, 1024, 8, 64)

# Or force a specific backend
attn_func = flash_sparse_attn_func_auto(backend="cuda")  # or "triton", "flex"
output = attn_func(q, k, v, is_causal=True)
```

> [!NOTE]
> `flash_sparse_attn_func_auto` returns a callable attention function, not the attention output.


## Backend Selection and Comparison

### Check Available Backends

```python
from flash_sparse_attn import get_available_backends, CUDA_AVAILABLE, TRITON_AVAILABLE, FLEX_AVAILABLE

# List all available backends
print(get_available_backends())  # e.g., ["cuda", "triton", "flex"]
print(f"CUDA: {CUDA_AVAILABLE}, Triton: {TRITON_AVAILABLE}, Flex: {FLEX_AVAILABLE}")
```

### When to Use Each Backend

**CUDA Backend** ([details](#flash_sparse_attn_func-cuda-backend))
- ✅ Training workloads requiring full gradient support
- ✅ Production inference requiring maximum performance
- ✅ Applications needing deterministic behavior
- ❌ Avoid when the custom CUDA extension cannot be built

**Triton Backend** ([details](#triton_sparse_attn_func-triton-backend))
- ✅ Training when the CUDA extension is unavailable
- ✅ Development and prototyping
- ✅ Cross-platform compatibility needs
- ✅ Good balance of performance and ease of installation

**Flex Backend** ([details](#flex_sparse_attn_func-flex-attention-backend))
- ✅ Inference-only applications
- ✅ Research with latest PyTorch features
- ✅ Quick experimentation without custom builds
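
As a rough sketch of how this guidance can be made explicit in code: `pick_backend` below is a hypothetical helper, not part of the package; only the availability flags and `flash_sparse_attn_func_auto` are real imports.

```python
from flash_sparse_attn import (
    CUDA_AVAILABLE,
    TRITON_AVAILABLE,
    FLEX_AVAILABLE,
    flash_sparse_attn_func_auto,
)

def pick_backend(needs_backward: bool) -> str:
    # Hypothetical helper mirroring the guidance above.
    if CUDA_AVAILABLE:
        return "cuda"    # fastest path, full gradient support
    if TRITON_AVAILABLE:
        return "triton"  # portable fallback that still supports training
    if FLEX_AVAILABLE and not needs_backward:
        return "flex"    # inference-only fallback
    raise RuntimeError("No suitable backend available for this workload")

attn_func = flash_sparse_attn_func_auto(backend=pick_backend(needs_backward=True))
```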

### Import Available Functions

```python
from flash_sparse_attn import (
    # Automatic backend selection
    get_available_backends,
    flash_sparse_attn_func_auto,

    # Backend-specific functions
    flash_sparse_attn_func,      # CUDA backend
    triton_sparse_attn_func,     # Triton backend
    flex_sparse_attn_func,       # Flex backend

    # Backend availability flags
    CUDA_AVAILABLE,
    TRITON_AVAILABLE,
    FLEX_AVAILABLE,
)

# Transformers integration
from flash_sparse_attn.integrations.flash_sparse_attention import (
    flash_sparse_attention_forward
)
```


## API Reference

### flash_sparse_attn_func (CUDA backend)

Main attention function. Supports multi-head and grouped-query attention (when the number of KV heads is smaller than the number of Q heads). Requires the CUDA extension to be built and available.

```python
def flash_sparse_attn_func(
    query: torch.Tensor,    # (batch, seqlen_q, num_heads, head_dim)
    key: torch.Tensor,      # (batch, seqlen_k, num_kv_heads, head_dim)
    value: torch.Tensor,    # (batch, seqlen_k, num_kv_heads, head_dim)
    attn_mask: Optional[torch.Tensor] = None,
    attn_bias: Optional[torch.Tensor] = None,
    is_causal: bool = False,
    ...
) -> torch.Tensor
```

#### Returns

- output: (B, Q, H, D)

### triton_sparse_attn_func (Triton backend)

Triton-based implementation that provides good performance without requiring custom CUDA kernels.

```python
def triton_sparse_attn_func(
    query: torch.Tensor,    # (batch, seqlen_q, num_heads, head_dim)
    key: torch.Tensor,      # (batch, seqlen_k, num_heads, head_dim)
    value: torch.Tensor,    # (batch, seqlen_k, num_heads, head_dim)
    attn_mask: Optional[torch.Tensor] = None,
    attn_bias: Optional[torch.Tensor] = None,
    is_causal: bool = False,
    scale: Optional[float] = None,
) -> torch.Tensor
```
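
A short usage sketch that imports the Triton backend directly instead of going through the auto selector; shapes match the Quick Start, and the keyword arguments are assumed to mirror the CUDA backend.

```python
import torch
from flash_sparse_attn import triton_sparse_attn_func

batch, seqlen, num_heads, head_dim = 2, 1024, 8, 64
q = torch.randn(batch, seqlen, num_heads, head_dim, dtype=torch.bfloat16, device="cuda")
k = torch.randn(batch, seqlen, num_heads, head_dim, dtype=torch.bfloat16, device="cuda")
v = torch.randn(batch, seqlen, num_heads, head_dim, dtype=torch.bfloat16, device="cuda")

output = triton_sparse_attn_func(q, k, v, is_causal=True)
print(output.shape)  # torch.Size([2, 1024, 8, 64])
```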

### flex_sparse_attn_func (Flex Attention backend)

Flex Attention-based implementation using PyTorch's native flex attention with dynamic masking support.

```python
def flex_sparse_attn_func(
    query: torch.Tensor,    # (batch, seqlen_q, num_heads, head_dim)
    key: torch.Tensor,      # (batch, seqlen_k, num_heads, head_dim)
    value: torch.Tensor,    # (batch, seqlen_k, num_heads, head_dim)
    attn_mask: Optional[torch.Tensor] = None,
    attn_bias: Optional[torch.Tensor] = None,
    is_causal: bool = False,
    scale: Optional[float] = None,
) -> torch.Tensor
```


## Integrations

### Transformers Integration

Integration function for HuggingFace Transformers models that provides seamless flash dynamic mask attention support.

#### flash_sparse_attention_forward

```python
from flash_sparse_attn.integrations.flash_sparse_attention import flash_sparse_attention_forward

def flash_sparse_attention_forward(
    module: torch.nn.Module,    # The attention module
    query: torch.Tensor,        # (batch_size, num_heads, query_len, head_dim)
    key: torch.Tensor,          # (batch_size, num_kv_heads, key_len, head_dim)
    ...
) -> tuple[torch.Tensor, Optional[torch.Tensor]]
```

#### Parameters

- is_causal: Whether to apply causal mask
- window_size: Size of window to keep
- layer_idx: Layer index for logging
- implementation: Implementation to use ("flash_sparse_attn" or None)

#### Returns

- A tuple of (attn_output, attn_weights)

The following example integrates `flash_sparse_attention_forward` into a custom attention module:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Callable
from transformers.cache_utils import Cache
from flash_sparse_attn.integrations.flash_sparse_attention import flash_sparse_attention_forward

class DynamicMaskAttention(nn.Module):
    def __init__(self, config, layer_idx: Optional[int] = None):
        ...

    def forward(self, hidden_states, attention_mask=None, **kwargs):  # signature condensed for this excerpt
        ...  # q/k/v projections and dt_states computation elided
        attn_bias = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2).to(hidden_states.dtype)

        # Choose attention implementation
        attention_interface: Callable = flash_sparse_attention_forward

        attn_output, attn_weights = attention_interface(
            self,
            ...
        )
        ...
```


## Common Issues and Solutions

First, verify that the package imports correctly and which backends are available:

```python
try:
    from flash_sparse_attn import flash_sparse_attn_func_auto, get_available_backends
    print("✅ Imported successfully", get_available_backends())
except ImportError as e:
    print(f"❌ Import failed: {e}")
```

If the attention output contains NaN values, enable anomaly detection and inspect the output:

```python
import torch
from flash_sparse_attn import flash_sparse_attn_func_auto

torch.autograd.set_detect_anomaly(True)
attn = flash_sparse_attn_func_auto()
output = attn(q, k, v, attn_mask=attn_mask, attn_bias=attn_bias, is_causal=True)
if torch.isnan(output).any():
    print("⚠️ NaN detected in attention output")
```

For memory issues, check CUDA memory usage around the attention call:

```python
import torch

def print_memory_stats():
    print(f"max alloc: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")

print_memory_stats()
attn = flash_sparse_attn_func_auto()
output = attn(q, k, v)
print_memory_stats()
```