
Commit a0ed87d

Aligns benchmarks with sparse attn imports
Updates benchmark integrations to load the flash_sparse_attn implementations so the renamed package continues to back the CUDA, Triton, and Flex runs. Renames the availability guards and status messages to keep diagnostic output aligned with the new module namespace.
1 parent ac95f25 commit a0ed87d
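
The rename keeps the benchmarks' guarded-import pattern intact, just under the new namespace: each backend is imported inside try/except and its entry point falls back to None when unavailable. Below is a minimal sketch of how those None guards are meant to be consumed; the pick_backend helper is hypothetical and not part of this commit.

    # Sketch only: mirrors the guarded imports in the benchmarks.
    # pick_backend is a hypothetical helper, not code from this commit.
    try:
        from flash_sparse_attn.flash_sparse_attn_interface import flash_sparse_attn_func
    except ImportError:
        flash_sparse_attn_func = None  # compiled CUDA extension not installed

    try:
        from flash_sparse_attn.flash_sparse_attn_triton import triton_sparse_attn_func
    except ImportError:
        triton_sparse_attn_func = None  # Triton implementation unavailable

    try:
        from flash_sparse_attn.flash_sparse_attn_flex import flex_sparse_attn_func
    except ImportError:
        flex_sparse_attn_func = None  # Flex Attention implementation unavailable

    def pick_backend():
        # Return the first available (name, callable) pair, or raise if none loaded.
        for name, fn in (
            ("cuda", flash_sparse_attn_func),
            ("triton", triton_sparse_attn_func),
            ("flex", flex_sparse_attn_func),
        ):
            if fn is not None:
                return name, fn
        raise RuntimeError("No flash_sparse_attn backend is available")
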

File tree

4 files changed: +80 -80 lines changed

benchmarks/backward_equivalence.py

Lines changed: 21 additions & 21 deletions
@@ -21,33 +21,33 @@
 
 # Import the compiled CUDA extension
 try:
-    from flash_dmattn.flash_dmattn_interface import flash_dmattn_func
-    print("✅ Successfully imported flash_dmattn interface")
+    from flash_sparse_attn.flash_sparse_attn_interface import flash_sparse_attn_func
+    print("✅ Successfully imported flash_sparse_attn interface")
 except ImportError as e:
-    print(f"❌ Failed to import flash_dmattn interface: {e}")
+    print(f"❌ Failed to import flash_sparse_attn interface: {e}")
     print("Please make sure the package is properly installed with: pip install .")
     # Don't exit here, just warn
-    flash_dmattn_func = None
+    flash_sparse_attn_func = None
 
 # Import the Triton implementation
 try:
-    from flash_dmattn.flash_dmattn_triton import triton_dmattn_func
-    print("✅ Successfully imported flash_dmattn_triton")
+    from flash_sparse_attn.flash_sparse_attn_triton import triton_sparse_attn_func
+    print("✅ Successfully imported flash_sparse_attn_triton")
 except ImportError as e:
-    print(f"❌ Failed to import flash_dmattn_triton: {e}")
+    print(f"❌ Failed to import flash_sparse_attn_triton: {e}")
     print("Please make sure the Triton implementation is available.")
     # Don't exit here, just warn
-    triton_dmattn_func = None
+    triton_sparse_attn_func = None
 
 # Import the Flex Attention implementation
 try:
-    from flash_dmattn.flash_dmattn_flex import flex_dmattn_func
-    print("✅ Successfully imported flash_dmattn_flex")
+    from flash_sparse_attn.flash_sparse_attn_flex import flex_sparse_attn_func
+    print("✅ Successfully imported flash_sparse_attn_flex")
 except ImportError as e:
-    print(f"❌ Failed to import flash_dmattn_flex: {e}")
+    print(f"❌ Failed to import flash_sparse_attn_flex: {e}")
     print("Please make sure the Flex Attention implementation is available.")
     # Don't exit here, just warn
-    flex_dmattn_func = None
+    flex_sparse_attn_func = None
 
 
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -189,7 +189,7 @@ def dynamic_mask_attention_cuda(
     Returns:
         tuple: (attn_outputs, dq, dk, dv, dbias)
     """
-    if flash_dmattn_func is None:
+    if flash_sparse_attn_func is None:
         raise ImportError("CUDA implementation not available")
 
     query_states_leaf = query_states
@@ -210,8 +210,8 @@ def dynamic_mask_attention_cuda(
     key_states = key_states.transpose(1, 2).contiguous()  # [batch, key_len, num_kv_heads, head_dim]
     value_states = value_states.transpose(1, 2).contiguous()  # [batch, key_len, num_kv_heads, head_dim]
 
-    # Call the flash_dmattn_func interface
-    attn_outputs = flash_dmattn_func(
+    # Call the flash_sparse_attn_func interface
+    attn_outputs = flash_sparse_attn_func(
         query=query_states,
         key=key_states,
         value=value_states,
@@ -256,7 +256,7 @@ def dynamic_mask_attention_triton(
     Returns:
         tuple: (attn_outputs, dq, dk, dv, dbias)
     """
-    if triton_dmattn_func is None:
+    if triton_sparse_attn_func is None:
         raise RuntimeError("Triton implementation not available")
 
     _, num_heads, _, _ = query_states.shape
@@ -288,7 +288,7 @@ def dynamic_mask_attention_triton(
     value_states = value_states.transpose(1, 2)  # [batch, key_len, num_heads, head_dim]
 
     # Call the Triton implementation
-    attn_outputs = triton_dmattn_func(
+    attn_outputs = triton_sparse_attn_func(
         query=query_states,
         key=key_states,
         value=value_states,
@@ -330,7 +330,7 @@ def dynamic_mask_attention_flex(
     Returns:
         tuple: (attn_outputs, dq, dk, dv, dbias)
     """
-    if flex_dmattn_func is None:
+    if flex_sparse_attn_func is None:
         raise RuntimeError("Flex Attention implementation not available")
 
     _, num_heads, _, _ = query_states.shape
@@ -359,7 +359,7 @@ def dynamic_mask_attention_flex(
     attn_bias = attn_bias.contiguous()  # [batch, num_heads, seqlen_q, seqlen_k]
 
     # Call the Flex Attention implementation
-    attn_outputs = flex_dmattn_func(
+    attn_outputs = flex_sparse_attn_func(
         query_states,
         key_states,
         value_states,
@@ -474,7 +474,7 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95):
     print("🚀" + "=" * 76 + "🚀")
 
     # Check if CUDA implementation is available
-    if flash_dmattn_func is None:
+    if flash_sparse_attn_func is None:
         print("❌ CUDA implementation not available, skipping test.")
         return False
 
@@ -734,7 +734,7 @@ def test_triton_backward_equivalence(accuracy_threshold=0.95):
     print("🚀" + "=" * 76 + "🚀")
 
    # Check if Triton implementation is available
-    if triton_dmattn_func is None:
+    if triton_sparse_attn_func is None:
         print("❌ Triton implementation not available, skipping test.")
         return False
 
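
Since backward_equivalence.py compares the gradients (dq, dk, dv, dbias) produced by each renamed entry point against a reference, a generic gradient comparison looks roughly like the sketch below. It is an illustration only: it assumes both callables take (query, key, value) and return a single tensor, whereas the actual benchmark also passes masks and bias and handles tuple outputs.

    # Sketch only: a generic gradient-equivalence check, not the benchmark's code.
    # Assumes both callables accept (query, key, value) and return one tensor.
    import torch

    def max_grad_diff(fn_ref, fn_test, q, k, v):
        grads = []
        for fn in (fn_ref, fn_test):
            # Fresh leaf tensors per callable so gradients do not mix.
            qi, ki, vi = (t.detach().clone().requires_grad_(True) for t in (q, k, v))
            fn(qi, ki, vi).sum().backward()
            grads.append((qi.grad, ki.grad, vi.grad))
        return max((a - b).abs().max().item() for a, b in zip(*grads))

    # Example: comparing PyTorch SDPA against itself trivially yields 0.0,
    # shown here only to illustrate the call pattern.
    sdpa = torch.nn.functional.scaled_dot_product_attention
    q = k = v = torch.randn(1, 4, 128, 64)
    print(max_grad_diff(sdpa, sdpa, q, k, v))
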

benchmarks/backward_performance.py

Lines changed: 18 additions & 18 deletions
@@ -28,33 +28,33 @@
 
 # Import the compiled CUDA extension
 try:
-    from flash_dmattn.flash_dmattn_interface import flash_dmattn_func
-    print("✅ Successfully imported flash_dmattn interface")
+    from flash_sparse_attn.flash_sparse_attn_interface import flash_sparse_attn_func
+    print("✅ Successfully imported flash_sparse_attn interface")
 except ImportError as e:
-    print(f"❌ Failed to import flash_dmattn interface: {e}")
+    print(f"❌ Failed to import flash_sparse_attn interface: {e}")
     print("Please make sure the package is properly installed with: pip install .")
     # Don't exit here, just warn
-    flash_dmattn_func = None
+    flash_sparse_attn_func = None
 
 # Import the Triton implementation
 try:
-    from flash_dmattn.flash_dmattn_triton import triton_dmattn_func
-    print("✅ Successfully imported flash_dmattn_triton")
+    from flash_sparse_attn.flash_sparse_attn_triton import triton_sparse_attn_func
+    print("✅ Successfully imported flash_sparse_attn_triton")
 except ImportError as e:
-    print(f"❌ Failed to import flash_dmattn_triton: {e}")
+    print(f"❌ Failed to import flash_sparse_attn_triton: {e}")
     print("Please make sure the Triton implementation is available.")
     # Don't exit here, just warn
-    triton_dmattn_func = None
+    triton_sparse_attn_func = None
 
 # Import the Flex Attention implementation
 try:
-    from flash_dmattn.flash_dmattn_flex import flex_dmattn_func
-    print("✅ Successfully imported flash_dmattn_flex")
+    from flash_sparse_attn.flash_sparse_attn_flex import flex_sparse_attn_func
+    print("✅ Successfully imported flash_sparse_attn_flex")
 except ImportError as e:
-    print(f"❌ Failed to import flash_dmattn_flex: {e}")
+    print(f"❌ Failed to import flash_sparse_attn_flex: {e}")
     print("Please make sure the Flex Attention implementation is available.")
     # Don't exit here, just warn
-    flex_dmattn_func = None
+    flex_sparse_attn_func = None
 
 
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -207,7 +207,7 @@ def dynamic_mask_attention_backward_cuda(
     Returns:
         tuple: (output_tensor, timing_ms) or ("OOM", 0) or ("Not Available", 0)
     """
-    if flash_dmattn_func is None:
+    if flash_sparse_attn_func is None:
         return "Not Available", 0
 
     attn_bias, attn_mask = prepare_mask(
@@ -223,7 +223,7 @@ def dynamic_mask_attention_backward_cuda(
     value_states = value_states.transpose(1, 2).contiguous()  # [batch, key_len, num_kv_heads, head_dim]
 
     try:
-        attn_outputs = flash_dmattn_func(
+        attn_outputs = flash_sparse_attn_func(
             query=query_states,
             key=key_states,
             value=value_states,
@@ -277,7 +277,7 @@ def dynamic_mask_attention_backward_triton(
     Returns:
         tuple: (output_tensor, timing_ms) or ("OOM", 0) or ("Not Available", 0)
     """
-    if triton_dmattn_func is None:
+    if triton_sparse_attn_func is None:
         return "Not Available", 0
 
     _, num_heads, _, _ = query_states.shape
@@ -305,7 +305,7 @@ def dynamic_mask_attention_backward_triton(
     attn_bias = attn_bias.contiguous()  # [batch, num_heads, seqlen_q, seqlen_k]
 
     try:
-        attn_outputs = triton_dmattn_func(
+        attn_outputs = triton_sparse_attn_func(
             query=query_states,
             key=key_states,
             value=value_states,
@@ -356,7 +356,7 @@ def dynamic_mask_attention_backward_flex(
     Returns:
         tuple: (output_tensor, timing_ms) or ("OOM", 0) or ("Not Available", 0)
     """
-    if flex_dmattn_func is None:
+    if flex_sparse_attn_func is None:
         return "Not Available", 0
 
     _, num_heads, _, _ = query_states.shape
@@ -384,7 +384,7 @@ def dynamic_mask_attention_backward_flex(
     attn_bias = attn_bias.contiguous()  # [batch, num_heads, seqlen_q, seqlen_k]
 
     try:
-        attn_outputs = flex_dmattn_func(
+        attn_outputs = flex_sparse_attn_func(
             query_states,
             key_states,
             value_states,
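
backward_performance.py reports (output_tensor, timing_ms), or ("OOM", 0) / ("Not Available", 0) when a backend cannot run. The helper below is a simplified, hypothetical stand-in for that timing loop, assuming an attention callable that takes (query, key, value) and returns a single tensor; it is not the benchmark's implementation.

    # Sketch only: simplified forward+backward timing using the benchmark's
    # ("OOM", 0) convention; attn_fn is assumed to return a single tensor.
    import time
    import torch

    def time_forward_backward_ms(attn_fn, q, k, v, warmup=3, iters=10):
        # Fresh leaves so backward() has gradients to produce.
        q, k, v = (t.detach().clone().requires_grad_(True) for t in (q, k, v))
        try:
            for _ in range(warmup):
                attn_fn(q, k, v).sum().backward()
            if q.is_cuda:
                torch.cuda.synchronize()
            start = time.perf_counter()
            for _ in range(iters):
                out = attn_fn(q, k, v)
                out.sum().backward()
            if q.is_cuda:
                torch.cuda.synchronize()
            return out, (time.perf_counter() - start) * 1000.0 / iters
        except torch.cuda.OutOfMemoryError:
            return "OOM", 0
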

benchmarks/forward_equivalence.py

Lines changed: 23 additions & 23 deletions
@@ -21,33 +21,33 @@
 
 # Import the compiled CUDA extension
 try:
-    from flash_dmattn.flash_dmattn_interface import flash_dmattn_func
-    print("✅ Successfully imported flash_dmattn interface")
+    from flash_sparse_attn.flash_sparse_attn_interface import flash_sparse_attn_func
+    print("✅ Successfully imported flash_sparse_attn interface")
 except ImportError as e:
-    print(f"❌ Failed to import flash_dmattn interface: {e}")
+    print(f"❌ Failed to import flash_sparse_attn interface: {e}")
     print("Please make sure the package is properly installed with: pip install .")
     # Don't exit here, just warn
-    flash_dmattn_func = None
+    flash_sparse_attn_func = None
 
 # Import the Triton implementation
 try:
-    from flash_dmattn.flash_dmattn_triton import triton_dmattn_func
-    print("✅ Successfully imported flash_dmattn_triton")
+    from flash_sparse_attn.flash_sparse_attn_triton import triton_sparse_attn_func
+    print("✅ Successfully imported flash_sparse_attn_triton")
 except ImportError as e:
-    print(f"❌ Failed to import flash_dmattn_triton: {e}")
+    print(f"❌ Failed to import flash_sparse_attn_triton: {e}")
     print("Please make sure the Triton implementation is available.")
     # Don't exit here, just warn
-    triton_dmattn_func = None
+    triton_sparse_attn_func = None
 
 # Import the Flex Attention implementation
 try:
-    from flash_dmattn.flash_dmattn_flex import flex_dmattn_func
-    print("✅ Successfully imported flash_dmattn_flex")
+    from flash_sparse_attn.flash_sparse_attn_flex import flex_sparse_attn_func
+    print("✅ Successfully imported flash_sparse_attn_flex")
 except ImportError as e:
-    print(f"❌ Failed to import flash_dmattn_flex: {e}")
+    print(f"❌ Failed to import flash_sparse_attn_flex: {e}")
     print("Please make sure the Flex Attention implementation is available.")
     # Don't exit here, just warn
-    flex_dmattn_func = None
+    flex_sparse_attn_func = None
 
 
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -181,8 +181,8 @@ def dynamic_mask_attention_cuda(
     Returns:
         attn_outputs: [batch_size, query_len, num_heads, head_dim]
     """
-    if flash_dmattn_func is None:
-        raise RuntimeError("flash_dmattn_func not available")
+    if flash_sparse_attn_func is None:
+        raise RuntimeError("flash_sparse_attn_func not available")
 
     attn_bias, attn_mask = prepare_mask(
         query_states,
@@ -196,8 +196,8 @@ def dynamic_mask_attention_cuda(
     key_states = key_states.transpose(1, 2)  # [batch, key_len, num_kv_heads, head_dim]
     value_states = value_states.transpose(1, 2)  # [batch, key_len, num_kv_heads, head_dim]
 
-    # Call the flash_dmattn_func interface
-    attn_outputs = flash_dmattn_func(
+    # Call the flash_sparse_attn_func interface
+    attn_outputs = flash_sparse_attn_func(
         query_states,
         key_states,
         value_states,
@@ -239,7 +239,7 @@ def dynamic_mask_attention_triton(
     Returns:
         attn_outputs: [batch_size, query_len, num_heads, head_dim]
     """
-    if triton_dmattn_func is None:
+    if triton_sparse_attn_func is None:
         raise RuntimeError("Triton implementation not available")
 
     _, num_heads, _, _ = query_states.shape
@@ -267,7 +267,7 @@ def dynamic_mask_attention_triton(
     attn_bias = attn_bias.contiguous()  # [batch, num_heads, seqlen_q, seqlen_k]
 
     # Call the Triton implementation
-    attn_outputs = triton_dmattn_func(
+    attn_outputs = triton_sparse_attn_func(
         query_states,
         key_states,
         value_states,
@@ -306,7 +306,7 @@ def dynamic_mask_attention_flex(
     Returns:
         attn_outputs: [batch_size, query_len, num_heads, head_dim]
     """
-    if flex_dmattn_func is None:
+    if flex_sparse_attn_func is None:
         raise RuntimeError("Flex Attention implementation not available")
 
     _, num_heads, _, _ = query_states.shape
@@ -334,7 +334,7 @@ def dynamic_mask_attention_flex(
     attn_bias = attn_bias.contiguous()  # [batch, num_heads, seqlen_q, seqlen_k]
 
     # Call the Flex Attention implementation
-    attn_outputs = flex_dmattn_func(
+    attn_outputs = flex_sparse_attn_func(
         query_states,
         key_states,
         value_states,
@@ -446,7 +446,7 @@ def test_cuda_forward_equivalence(accuracy_threshold=0.95):
     print("🚀" + "=" * 76 + "🚀")
 
     # Check if CUDA implementation is available
-    if flash_dmattn_func is None:
+    if flash_sparse_attn_func is None:
         print("❌ CUDA implementation not available, skipping test.")
         return False
 
@@ -653,7 +653,7 @@ def test_triton_forward_equivalence(accuracy_threshold=0.95):
     print("🔬 Testing Forward Pass Equivalence: Python vs Triton 🔬")
     print("🔥" + "=" * 76 + "🔥")
 
-    if triton_dmattn_func is None:
+    if triton_sparse_attn_func is None:
         print("❌ Triton implementation not available, skipping Triton tests")
         return False
 
@@ -859,7 +859,7 @@ def test_flex_forward_equivalence(accuracy_threshold=0.95):
     print("🔬 Testing Forward Pass Equivalence: Python vs Flex Attention 🔬")
     print("🌟" + "=" * 76 + "🌟")
 
-    if flex_dmattn_func is None:
+    if flex_sparse_attn_func is None:
         print("❌ Flex Attention implementation not available, skipping Flex Attention tests")
         return False
 
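
The forward-equivalence tests gate on accuracy_threshold=0.95, but the exact metric is not visible in this diff. One plausible reading, shown purely as an illustration and not the benchmark's implementation, treats the threshold as the required fraction of elementwise-close values between the renamed kernel's output and the Python reference.

    # Sketch only: one plausible interpretation of accuracy_threshold, not the
    # metric actually used by forward_equivalence.py.
    import torch

    def forward_outputs_match(out_ref, out_test, accuracy_threshold=0.95,
                              rtol=1e-3, atol=1e-3):
        close = torch.isclose(out_ref.float(), out_test.float(), rtol=rtol, atol=atol)
        return close.float().mean().item() >= accuracy_threshold
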
