@@ -76,7 +76,15 @@ def _attn_fwd_inner(acc, l_i, m_i, q,  #
         alpha = tl.math.exp2(m_i - m_ij)
         l_ij = tl.sum(p, 1)
         # -- update output accumulator --
-        acc = acc * alpha[:, None]
+        if warp_specialize and BLOCK_M == 128 and HEAD_DIM == 128:
+            BM: tl.constexpr = acc.shape[0]
+            BN: tl.constexpr = acc.shape[1]
+            acc0, acc1 = acc.reshape([BM, 2, BN // 2]).permute(0, 2, 1).split()
+            acc0 = acc0 * alpha[:, None]
+            acc1 = acc1 * alpha[:, None]
+            acc = tl.join(acc0, acc1).permute(0, 2, 1).reshape([BM, BN])
+        else:
+            acc = acc * alpha[:, None]
         # prepare p and v for the dot
         v = desc_v.load([offsetkv_y, 0])
         p = p.to(dtype)
@@ -119,7 +127,7 @@ def _host_descriptor_pre_hook(nargs):
 if "PYTEST_VERSION" in os.environ:
     # Use a single config in testing for reproducibility
     configs = [
-        triton.Config(dict(BLOCK_M=64, BLOCK_N=64), num_stages=2, num_warps=4, pre_hook=_host_descriptor_pre_hook),
+        triton.Config(dict(BLOCK_M=128, BLOCK_N=64), num_stages=2, num_warps=4, pre_hook=_host_descriptor_pre_hook),
     ]


@@ -505,7 +513,10 @@ def grid(META):

        ctx.grid = grid
        if is_cuda() and warp_specialize:
-            extra_kern_args["maxnreg"] = 80
+            if HEAD_DIM_K == 128 and q.dtype == torch.float16:
+                extra_kern_args["maxnreg"] = 168
+            else:
+                extra_kern_args["maxnreg"] = 80
        _attn_fwd[grid](
            sm_scale, M,  #
            q.shape[0], q.shape[1],  #
@@ -620,36 +631,35 @@ def test_op(Z, H, N_CTX, HEAD_DIM, causal, warp_specialize, dtype=torch.float16)
     HAS_FLASH = False

 TORCH_HAS_FP8 = hasattr(torch, 'float8_e5m2')
-BATCH, N_HEADS, HEAD_DIM = 4, 32, 64
+BATCH, N_HEADS = 4, 32
 # vary seq length for fixed head and batch=4
 configs = []
-for mode in ["fwd", "bwd"]:
-    for causal in [True, False]:
-        for warp_specialize in [False, True] if is_blackwell() else [False]:
-            if mode == "bwd" and not causal:
-                continue
-            configs.append(
-                triton.testing.Benchmark(
-                    x_names=["N_CTX"],
-                    x_vals=[2**i for i in range(10, 15)],
-                    line_arg="provider",
-                    line_vals=["triton-fp16"] + (["triton-fp8"] if TORCH_HAS_FP8 else []) +
-                    (["flash"] if HAS_FLASH else []),
-                    line_names=["Triton [FP16]"] + (["Triton [FP8]"] if TORCH_HAS_FP8 else []) +
-                    (["Flash-2"] if HAS_FLASH else []),
-                    styles=[("red", "-"), ("blue", "-"), ("green", "-")],
-                    ylabel="TFLOPS",
-                    plot_name=
-                    f"fused-attention-batch{BATCH}-head{N_HEADS}-d{HEAD_DIM}-{mode}-causal={causal}-warp_specialize={warp_specialize}",
-                    args={
-                        "H": N_HEADS,
-                        "BATCH": BATCH,
-                        "HEAD_DIM": HEAD_DIM,
-                        "mode": mode,
-                        "causal": causal,
-                        "warp_specialize": warp_specialize,
-                    },
-                ))
+for HEAD_DIM in [64, 128]:
+    for mode in ["fwd", "bwd"]:
+        for causal in [True, False]:
+            for warp_specialize in [False, True] if is_blackwell() else [False]:
+                configs.append(
+                    triton.testing.Benchmark(
+                        x_names=["N_CTX"],
+                        x_vals=[2**i for i in range(10, 15)],
+                        line_arg="provider",
+                        line_vals=["triton-fp16"] + (["triton-fp8"] if TORCH_HAS_FP8 else []) +
+                        (["flash"] if HAS_FLASH else []),
+                        line_names=["Triton [FP16]"] + (["Triton [FP8]"] if TORCH_HAS_FP8 else []) +
+                        (["Flash-2"] if HAS_FLASH else []),
+                        styles=[("red", "-"), ("blue", "-"), ("green", "-")],
+                        ylabel="TFLOPS",
+                        plot_name=
+                        f"fused-attention-batch{BATCH}-head{N_HEADS}-d{HEAD_DIM}-{mode}-causal={causal}-warp_specialize={warp_specialize}",
+                        args={
+                            "H": N_HEADS,
+                            "BATCH": BATCH,
+                            "HEAD_DIM": HEAD_DIM,
+                            "mode": mode,
+                            "causal": causal,
+                            "warp_specialize": warp_specialize,
+                        },
+                    ))


 @triton.testing.perf_report(configs)