Skip to content

Commit 8d6e8b5

Browse files
ardaunalfacebook-github-bot
authored andcommitted
Replace has_warp_spec with HAS_AUTO_WS env variable check (#221)
Summary: Triton release/3.3.x branch only supports AutoWS which makes the check has_warp_spec = hasattr(tl, "async_task") incorrect. Instead this PR adds an environment variable check HAS_AUTO_WS = os.getenv("ENABLE_AUTO_WS") to replace that. Pull Request resolved: #221 Reviewed By: xuzhao9 Differential Revision: D74487395 Pulled By: ardaunal fbshipit-source-id: c6b40ac2692df590937bda7d05e668fe3e556f6c
1 parent a3611b6 commit 8d6e8b5

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

tritonbench/kernels/triton_fused_attention.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
WITH_COMPPIPE = os.getenv("ENABLE_COMPPIPE")
2828
PEEL_LAST = os.getenv("PEEL_LAST_ITER")
2929
WITH_TMA = os.getenv("WITH_TMA")
30+
HAS_AUTO_WS = os.getenv("ENABLE_AUTO_WS")
3031

3132
if HAS_TMA_DESC:
3233
print(
@@ -313,7 +314,6 @@ def _attn_fwd_inner_ws(
313314
# We don't run auto-tuning every time to keep the tutorial fast. Uncommenting
314315
# the code below and commenting out the equivalent parameters is convenient for
315316
# re-tuning.
316-
EXPLICIT_WARP_SPEC = hasattr(tl, "async_task")
317317
HAS_NEW_TMA = hasattr(triton, "set_allocator") and hasattr(tl, "make_tensor_descriptor")
318318
schedList = ["default", "FA_firstDot", "FA_secondDot"] if WITH_COMPPIPE else ["default"]
319319
# TODO: incorrect result with PEEL_LAST + FA_firstDot + WarpSpec + TMA
@@ -334,7 +334,7 @@ def _attn_fwd_inner_ws(
334334
num_buffers_warp_spec=0,
335335
num_consumer_groups=0,
336336
)
337-
if EXPLICIT_WARP_SPEC
337+
if HAS_AUTO_WS == "1"
338338
else triton.Config(
339339
{
340340
"BLOCK_M": BM,
@@ -367,7 +367,7 @@ def _attn_fwd_inner_ws(
367367
num_buffers_warp_spec=0,
368368
num_consumer_groups=0,
369369
)
370-
if EXPLICIT_WARP_SPEC
370+
if HAS_AUTO_WS == "1"
371371
else triton.Config(
372372
{
373373
"BLOCK_M": BM,
@@ -397,7 +397,7 @@ def _attn_fwd_inner_ws(
397397
reg_dec_producer=dec,
398398
reg_inc_consumer=inc,
399399
)
400-
if EXPLICIT_WARP_SPEC
400+
if HAS_AUTO_WS == "1"
401401
else triton.Config(
402402
{"BLOCK_M": BM, "BLOCK_N": BN, "ENABLE_TMA": False, "LOOP_SCHEDULE": sched},
403403
num_stages=2 if sched == "FA_firstDot" or sched == "FA_secondDot" else 0,
@@ -431,7 +431,7 @@ def _attn_fwd_inner_ws(
431431
num_buffers_warp_spec=0,
432432
num_consumer_groups=0,
433433
)
434-
if EXPLICIT_WARP_SPEC
434+
if HAS_AUTO_WS == "1"
435435
else triton.Config(
436436
{
437437
"BLOCK_M": BM,
@@ -487,7 +487,7 @@ def _attn_fwd_inner_ws(
487487
reg_dec_producer=dec,
488488
reg_inc_consumer=inc,
489489
)
490-
if EXPLICIT_WARP_SPEC
490+
if HAS_AUTO_WS == "1"
491491
else triton.Config(
492492
{
493493
"BLOCK_M": BM,
@@ -528,7 +528,7 @@ def _attn_fwd_inner_ws(
528528
reg_dec_producer=dec,
529529
reg_inc_consumer=inc,
530530
)
531-
if EXPLICIT_WARP_SPEC
531+
if HAS_AUTO_WS == "1"
532532
else triton.Config(
533533
{
534534
"BLOCK_M": BM,

0 commit comments

Comments
 (0)