@@ -50,7 +50,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q,  #
                     offset_y, dtype: tl.constexpr, start_m, qk_scale,  #
                     BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr,  #
                     STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr,  #
-                    N_CTX: tl.constexpr, warp_specialize: tl.constexpr, SUBTILE_EPILOGUE: tl.constexpr):
+                    N_CTX: tl.constexpr, warp_specialize: tl.constexpr, IS_HOPPER: tl.constexpr):
     # range of values handled by this stage
     if STAGE == 1:
         lo, hi = 0, start_m * BLOCK_M
@@ -84,7 +84,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q,  #
         alpha = tl.math.exp2(m_i - m_ij)
         l_ij = tl.sum(p, 1)
         # -- update output accumulator --
-        if SUBTILE_EPILOGUE:
+        if not IS_HOPPER and warp_specialize and BLOCK_M == 128 and HEAD_DIM == 128:
             BM: tl.constexpr = acc.shape[0]
             BN: tl.constexpr = acc.shape[1]
             acc0, acc1 = acc.reshape([BM, 2, BN // 2]).permute(0, 2, 1).split()
@@ -179,7 +179,7 @@ def _attn_fwd(sm_scale, M,  #
               FP8_OUTPUT: tl.constexpr,  #
               STAGE: tl.constexpr,  #
               warp_specialize: tl.constexpr,  #
-              SUBTILE_EPILOGUE: tl.constexpr,  #
+              IS_HOPPER: tl.constexpr,  #
               ):
     dtype = tl.float8e5 if FP8_OUTPUT else tl.float16
     tl.static_assert(BLOCK_N <= HEAD_DIM)
@@ -225,15 +225,15 @@ def _attn_fwd(sm_scale, M,  #
                                         offset_y, dtype, start_m, qk_scale,  #
                                         BLOCK_M, HEAD_DIM, BLOCK_N,  #
                                         4 - STAGE, offs_m, offs_n, N_CTX,  #
-                                        warp_specialize, SUBTILE_EPILOGUE)
+                                        warp_specialize, IS_HOPPER)
     # stage 2: on-band
     if STAGE & 2:
         acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q,  #
                                         desc_k, desc_v,  #
                                         offset_y, dtype, start_m, qk_scale,  #
                                         BLOCK_M, HEAD_DIM, BLOCK_N,  #
                                         2, offs_m, offs_n, N_CTX,  #
-                                        warp_specialize, SUBTILE_EPILOGUE)
+                                        warp_specialize, IS_HOPPER)
     # epilogue
     m_i += tl.math.log2(l_i)
     acc = acc / l_i[:, None]
@@ -539,7 +539,6 @@ def grid(META):
             return (triton.cdiv(q.shape[2], META["BLOCK_M"]), q.shape[0] * q.shape[1], 1)
 
         ctx.grid = grid
-        SUBTILE_EPILOGUE = False if is_hopper() and warp_specialize else True
         if is_blackwell() and warp_specialize:
             if HEAD_DIM_K == 128 and q.dtype == torch.float16:
                 extra_kern_args["maxnreg"] = 168
@@ -554,7 +553,8 @@ def grid(META):
             FP8_OUTPUT=q.dtype == torch.float8_e5m2,  #
             STAGE=stage,  #
             warp_specialize=warp_specialize,  #
-            SUBTILE_EPILOGUE=SUBTILE_EPILOGUE, **extra_kern_args)
+            IS_HOPPER=is_hopper(),  #
+            **extra_kern_args)
 
         ctx.save_for_backward(q, k, v, o, M)
         ctx.sm_scale = sm_scale
@@ -692,8 +692,8 @@ def test_op(Z, H, N_CTX, HEAD_DIM, causal, warp_specialize, mode, provider, dtyp
 for mode in ["fwd", "bwd"]:
     for causal in [True, False]:
         # Enable warpspec for causal fwd on Hopper
-        for warp_specialize in [False, True] if (is_blackwell() or
-                                                 (is_hopper() and mode == "fwd" and not causal)) else [False]:
+        enable_ws = mode == "fwd" and (is_blackwell() or (is_hopper() and not causal))
+        for warp_specialize in [False, True] if enable_ws else [False]:
             configs.append(
                 triton.testing.Benchmark(
                     x_names=["N_CTX"],
x_names = ["N_CTX" ],
0 commit comments