
Commit b03df33

[Blackwell_Attention] [Triton] Make N_CTX const in DP FA kernel
1 parent e7c435c · commit b03df33

3 files changed (+15, -8 lines)

tritonbench/kernels/attention_utils.py

Lines changed: 1 addition & 1 deletion

@@ -16,7 +16,7 @@
 PEEL_LAST = os.getenv("PEEL_LAST_ITER")
 WITH_TMA = os.getenv("WITH_TMA")
 HAS_EXPLICIT_WS = os.getenv("ENABLE_EXPLICIT_WS")
-SUPPORT_GLUON = os.getenv("WITH_GLUON")
+SUPPORT_GLUON = os.getenv("WITH_GLUON") == "1"
 WITH_MAXNREG = os.getenv("WITH_MAXNREG")


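Side note (not part of the commit): os.getenv returns the variable's raw string value, or None if it is unset, so the old assignment made SUPPORT_GLUON truthy for any non-empty setting, including WITH_GLUON=0. The new comparison yields an actual boolean that is True only for WITH_GLUON=1. A minimal illustration:

import os

os.environ["WITH_GLUON"] = "0"
print(bool(os.getenv("WITH_GLUON")))    # True  -- the old truthiness check would have enabled Gluon
print(os.getenv("WITH_GLUON") == "1")   # False -- the new check requires WITH_GLUON=1
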

tritonbench/kernels/blackwell_triton_fused_attention_dp.py

Lines changed: 12 additions & 5 deletions

@@ -291,7 +291,7 @@ def make_standard_config(BM, BN, s, w, subtile, vectmul, add2reduce):
 configs = [
     make_standard_config(BM, BN, s, w, subtile, vectmul, add2reduce)
     for BM in [256]
-    for BN in [128]
+    for BN in [64, 128]
     for s in NUM_STAGES_OPTIONS
     for w in [4]
     for subtile in [True]
@@ -318,6 +318,13 @@ def prune_invalid_configs(configs, named_args, **kwargs):
     return [conf for conf in configs if conf.kwargs.get("BLOCK_M", 0) <= N_CTX]


+def prune_persistent_configs(configs, named_args, **kwargs):
+    N_CTX = kwargs["N_CTX"]
+    # Filter out configs based on the desired BLOCK_N
+    TARGET_BLOCK_N = 64 if N_CTX == 128 else 128
+    return [conf for conf in configs if conf.kwargs.get("BLOCK_N", 0) == TARGET_BLOCK_N]
+
+
 @triton.jit
 def _maybe_make_tensor_desc(desc_or_ptr, shape, strides, block_shape):
     if isinstance(desc_or_ptr, tl.tensor_descriptor):
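
Background (not part of the diff): the early_config_prune hook passed via prune_configs_by is called by Triton's autotuner with the candidate configs and the kernel's launch arguments before any benchmarking, and only the configs it returns are timed. A standalone sketch of how the new pruner behaves; the config values below are made up for illustration and are not the kernel's real sweep:

import triton

def prune_persistent_configs(configs, named_args, **kwargs):
    N_CTX = kwargs["N_CTX"]
    # Keep only the BLOCK_N that matches the sequence length.
    TARGET_BLOCK_N = 64 if N_CTX == 128 else 128
    return [conf for conf in configs if conf.kwargs.get("BLOCK_N", 0) == TARGET_BLOCK_N]

# Hypothetical sweep shaped like the one above: BLOCK_M=256, BLOCK_N in {64, 128}.
configs = [
    triton.Config({"BLOCK_M": 256, "BLOCK_N": bn}, num_stages=s, num_warps=4)
    for bn in (64, 128)
    for s in (2, 3)
]

print([c.kwargs["BLOCK_N"] for c in prune_persistent_configs(configs, {}, N_CTX=128)])   # [64, 64]
print([c.kwargs["BLOCK_N"] for c in prune_persistent_configs(configs, {}, N_CTX=4096)])  # [128, 128]
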
@@ -399,7 +406,7 @@ def _attn_fwd_tma_dp(
     desc_o,
     pid,
     off_hz,
-    N_CTX, #
+    N_CTX: tl.constexpr, #
     HEAD_DIM: tl.constexpr, #
     BLOCK_M: tl.constexpr, #
     BLOCK_N: tl.constexpr, #
@@ -543,7 +550,7 @@ def _attn_fwd(
     desc_k,
     desc_v,
     desc_o,
-    N_CTX, #
+    N_CTX: tl.constexpr, #
     HEAD_DIM: tl.constexpr, #
     BLOCK_M: tl.constexpr, #
     BLOCK_N: tl.constexpr, #
@@ -585,7 +592,7 @@ def _attn_fwd(
 @triton.autotune(
     configs=list(filter(keep, configs)),
     key=["N_CTX", "HEAD_DIM", "FP8_OUTPUT", "warp_specialize"],
-    prune_configs_by={"early_config_prune": prune_invalid_configs},
+    prune_configs_by={"early_config_prune": prune_persistent_configs},
 )
 @triton.jit
 def _attn_fwd_persist(
@@ -597,7 +604,7 @@ def _attn_fwd_persist(
     desc_k,
     desc_v,
     desc_o,
-    N_CTX, #: tl.constexpr, #
+    N_CTX: tl.constexpr, #
     HEAD_DIM: tl.constexpr, #
     BLOCK_M: tl.constexpr, #
     BLOCK_N: tl.constexpr, #
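
Background (not part of the diff): declaring N_CTX as tl.constexpr makes the sequence length a compile-time constant, so Triton specializes the kernel per distinct N_CTX value (it already appears in the autotune key above), and loop bounds and masks derived from it are known at compile time. A minimal sketch of the mechanism using a toy kernel, not the flash-attention kernel itself:

import torch
import triton
import triton.language as tl

@triton.jit
def scale_kernel(x_ptr, out_ptr, N_CTX: tl.constexpr, BLOCK: tl.constexpr):
    # N_CTX is baked into the compiled kernel; a different value triggers a recompile.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N_CTX
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x * 2.0, mask=mask)

x = torch.randn(1024, device="cuda")
out = torch.empty_like(x)
scale_kernel[(triton.cdiv(1024, 256),)](x, out, N_CTX=1024, BLOCK=256)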

tritonbench/operators/blackwell_attentions/operator.py

Lines changed: 2 additions & 2 deletions

@@ -28,7 +28,7 @@
     )

     HAS_BLACKWELL_AUTOWS = True
-except (ImportError, IOError, AttributeError):
+except (ImportError, IOError, AttributeError, TypeError):
     # Needs compiler that supports autoWS
     HAS_BLACKWELL_AUTOWS = False

@@ -492,7 +492,7 @@ def gluon_blackwell_tutorial_fwd(
         return lambda: gluon_blackwell_fwd(q, k, v, self.causal, self.sm_scale)

     # Only works with triton main, forward only.
-    @register_benchmark(enabled=False)
+    @register_benchmark(enabled=SUPPORT_GLUON)
     def gluon_blackwell_tutorial_persistent_fwd(
         self,
         q: torch.Tensor,
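
Taken together with the attention_utils.py change above, gluon_blackwell_tutorial_persistent_fwd is now registered as enabled only when WITH_GLUON=1 is set, rather than being unconditionally disabled.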
