@@ -291,7 +291,7 @@ def make_standard_config(BM, BN, s, w, subtile, vectmul, add2reduce):
291
291
configs = [
292
292
make_standard_config (BM , BN , s , w , subtile , vectmul , add2reduce )
293
293
for BM in [256 ]
294
- for BN in [128 ]
294
+ for BN in [64 , 128 ]
295
295
for s in NUM_STAGES_OPTIONS
296
296
for w in [4 ]
297
297
for subtile in [True ]
@@ -318,6 +318,13 @@ def prune_invalid_configs(configs, named_args, **kwargs):
318
318
return [conf for conf in configs if conf .kwargs .get ("BLOCK_M" , 0 ) <= N_CTX ]
319
319
320
320
321
def prune_persistent_configs(configs, named_args, **kwargs):
    """Keep only the autotune configs whose ``BLOCK_N`` matches the target.

    The target ``BLOCK_N`` is 64 when ``N_CTX`` is exactly 128 and 128 for
    every other ``N_CTX``.

    Args:
        configs: Candidate configs; each must expose a ``kwargs`` mapping.
        named_args: Unused here; accepted because the function is registered
            as an ``early_config_prune`` hook.
        **kwargs: Must contain ``"N_CTX"``.

    Returns:
        A new list with the configs whose ``BLOCK_N`` equals the target, in
        their original order. Configs without a ``BLOCK_N`` entry are treated
        as ``BLOCK_N == 0`` and are therefore always dropped.

    Raises:
        KeyError: If ``"N_CTX"`` is missing from ``kwargs``.
    """
    # Choose the desired BLOCK_N from the context length.
    wanted_block_n = 128
    if kwargs["N_CTX"] == 128:
        wanted_block_n = 64

    def has_wanted_block_n(candidate):
        # A missing BLOCK_N falls back to 0, which never equals the target.
        return candidate.kwargs.get("BLOCK_N", 0) == wanted_block_n

    return list(filter(has_wanted_block_n, configs))
326
+
327
+
321
328
@triton .jit
322
329
def _maybe_make_tensor_desc (desc_or_ptr , shape , strides , block_shape ):
323
330
if isinstance (desc_or_ptr , tl .tensor_descriptor ):
@@ -399,7 +406,7 @@ def _attn_fwd_tma_dp(
399
406
desc_o ,
400
407
pid ,
401
408
off_hz ,
402
- N_CTX , #
409
+ N_CTX : tl . constexpr , #
403
410
HEAD_DIM : tl .constexpr , #
404
411
BLOCK_M : tl .constexpr , #
405
412
BLOCK_N : tl .constexpr , #
@@ -543,7 +550,7 @@ def _attn_fwd(
543
550
desc_k ,
544
551
desc_v ,
545
552
desc_o ,
546
- N_CTX , #
553
+ N_CTX : tl . constexpr , #
547
554
HEAD_DIM : tl .constexpr , #
548
555
BLOCK_M : tl .constexpr , #
549
556
BLOCK_N : tl .constexpr , #
@@ -585,7 +592,7 @@ def _attn_fwd(
585
592
@triton .autotune (
586
593
configs = list (filter (keep , configs )),
587
594
key = ["N_CTX" , "HEAD_DIM" , "FP8_OUTPUT" , "warp_specialize" ],
588
- prune_configs_by = {"early_config_prune" : prune_invalid_configs },
595
+ prune_configs_by = {"early_config_prune" : prune_persistent_configs },
589
596
)
590
597
@triton .jit
591
598
def _attn_fwd_persist (
@@ -597,7 +604,7 @@ def _attn_fwd_persist(
597
604
desc_k ,
598
605
desc_v ,
599
606
desc_o ,
600
- N_CTX , # : tl.constexpr, #
607
+ N_CTX : tl .constexpr , #
601
608
HEAD_DIM : tl .constexpr , #
602
609
BLOCK_M : tl .constexpr , #
603
610
BLOCK_N : tl .constexpr , #
0 commit comments