27
27
WITH_COMPPIPE = os .getenv ("ENABLE_COMPPIPE" )
28
28
PEEL_LAST = os .getenv ("PEEL_LAST_ITER" )
29
29
WITH_TMA = os .getenv ("WITH_TMA" )
30
+ HAS_AUTO_WS = os .getenv ("ENABLE_AUTO_WS" )
30
31
31
32
if HAS_TMA_DESC :
32
33
print (
@@ -313,7 +314,6 @@ def _attn_fwd_inner_ws(
313
314
# We don't run auto-tuning every time to keep the tutorial fast. Uncommenting
314
315
# the code below and commenting out the equivalent parameters is convenient for
315
316
# re-tuning.
316
- EXPLICIT_WARP_SPEC = hasattr (tl , "async_task" )
317
317
HAS_NEW_TMA = hasattr (triton , "set_allocator" ) and hasattr (tl , "make_tensor_descriptor" )
318
318
schedList = ["default" , "FA_firstDot" , "FA_secondDot" ] if WITH_COMPPIPE else ["default" ]
319
319
# TODO: incorrect result with PEEL_LAST + FA_firstDot + WarpSpec + TMA
@@ -334,7 +334,7 @@ def _attn_fwd_inner_ws(
334
334
num_buffers_warp_spec = 0 ,
335
335
num_consumer_groups = 0 ,
336
336
)
337
- if EXPLICIT_WARP_SPEC
337
+ if HAS_AUTO_WS == "1"
338
338
else triton .Config (
339
339
{
340
340
"BLOCK_M" : BM ,
@@ -367,7 +367,7 @@ def _attn_fwd_inner_ws(
367
367
num_buffers_warp_spec = 0 ,
368
368
num_consumer_groups = 0 ,
369
369
)
370
- if EXPLICIT_WARP_SPEC
370
+ if HAS_AUTO_WS == "1"
371
371
else triton .Config (
372
372
{
373
373
"BLOCK_M" : BM ,
@@ -397,7 +397,7 @@ def _attn_fwd_inner_ws(
397
397
reg_dec_producer = dec ,
398
398
reg_inc_consumer = inc ,
399
399
)
400
- if EXPLICIT_WARP_SPEC
400
+ if HAS_AUTO_WS == "1"
401
401
else triton .Config (
402
402
{"BLOCK_M" : BM , "BLOCK_N" : BN , "ENABLE_TMA" : False , "LOOP_SCHEDULE" : sched },
403
403
num_stages = 2 if sched == "FA_firstDot" or sched == "FA_secondDot" else 0 ,
@@ -431,7 +431,7 @@ def _attn_fwd_inner_ws(
431
431
num_buffers_warp_spec = 0 ,
432
432
num_consumer_groups = 0 ,
433
433
)
434
- if EXPLICIT_WARP_SPEC
434
+ if HAS_AUTO_WS == "1"
435
435
else triton .Config (
436
436
{
437
437
"BLOCK_M" : BM ,
@@ -487,7 +487,7 @@ def _attn_fwd_inner_ws(
487
487
reg_dec_producer = dec ,
488
488
reg_inc_consumer = inc ,
489
489
)
490
- if EXPLICIT_WARP_SPEC
490
+ if HAS_AUTO_WS == "1"
491
491
else triton .Config (
492
492
{
493
493
"BLOCK_M" : BM ,
@@ -528,7 +528,7 @@ def _attn_fwd_inner_ws(
528
528
reg_dec_producer = dec ,
529
529
reg_inc_consumer = inc ,
530
530
)
531
- if EXPLICIT_WARP_SPEC
531
+ if HAS_AUTO_WS == "1"
532
532
else triton .Config (
533
533
{
534
534
"BLOCK_M" : BM ,
0 commit comments