@@ -40,13 +40,17 @@ def is_blackwell():
     return is_cuda() and torch.cuda.get_device_capability()[0] == 10


+def is_hopper():
+    return is_cuda() and torch.cuda.get_device_capability()[0] == 9
+
+
 @triton.jit
 def _attn_fwd_inner(acc, l_i, m_i, q,  #
                     desc_k, desc_v,  #
                     offset_y, dtype: tl.constexpr, start_m, qk_scale,  #
                     BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr,  #
                     STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr,  #
-                    N_CTX: tl.constexpr, warp_specialize: tl.constexpr):
+                    N_CTX: tl.constexpr, warp_specialize: tl.constexpr, SUBTILE_EPILOGUE: tl.constexpr):
     # range of values handled by this stage
     if STAGE == 1:
         lo, hi = 0, start_m * BLOCK_M
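For readers unfamiliar with the capability numbers behind `is_hopper()` and `is_blackwell()`: Hopper-class GPUs report CUDA compute capability major version 9 and Blackwell reports 10. A minimal standalone sketch of the same check outside the tutorial (the helper name `cc_major` is hypothetical, not part of this patch):

```python
import torch

def cc_major() -> int:
    """CUDA compute capability major version, or -1 if no GPU is present."""
    return torch.cuda.get_device_capability()[0] if torch.cuda.is_available() else -1

# Hopper (sm_90*) reports 9, Blackwell reports 10 -- the values the new
# is_hopper() helper and the existing is_blackwell() helper test against.
print("Hopper-like device:", cc_major() == 9)
print("Blackwell-like device:", cc_major() == 10)
```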
@@ -80,7 +84,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q,  #
         alpha = tl.math.exp2(m_i - m_ij)
         l_ij = tl.sum(p, 1)
         # -- update output accumulator --
-        if warp_specialize and BLOCK_M == 128 and HEAD_DIM == 128:
+        if SUBTILE_EPILOGUE:
             BM: tl.constexpr = acc.shape[0]
             BN: tl.constexpr = acc.shape[1]
             acc0, acc1 = acc.reshape([BM, 2, BN // 2]).permute(0, 2, 1).split()
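The reshape/permute/split idiom guarded by `SUBTILE_EPILOGUE` splits the accumulator into its left and right column halves so each half can be rescaled independently. A plain-PyTorch sketch of the same index arithmetic (not Triton; Triton's `split()` on a trailing dimension of size 2 is mimicked here with `unbind`):

```python
import torch

# Sketch only: shows that reshape([BM, 2, BN // 2]).permute(0, 2, 1).split()
# is equivalent to slicing the accumulator into its left and right halves.
BM, BN = 4, 8
acc = torch.arange(BM * BN, dtype=torch.float32).reshape(BM, BN)

acc0, acc1 = acc.reshape(BM, 2, BN // 2).permute(0, 2, 1).unbind(dim=-1)

assert torch.equal(acc0, acc[:, :BN // 2])  # left half of the columns
assert torch.equal(acc1, acc[:, BN // 2:])  # right half of the columns
```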
@@ -175,6 +179,7 @@ def _attn_fwd(sm_scale, M,  #
               FP8_OUTPUT: tl.constexpr,  #
               STAGE: tl.constexpr,  #
               warp_specialize: tl.constexpr,  #
+              SUBTILE_EPILOGUE: tl.constexpr,  #
               ):
     dtype = tl.float8e5 if FP8_OUTPUT else tl.float16
     tl.static_assert(BLOCK_N <= HEAD_DIM)
@@ -220,15 +225,15 @@ def _attn_fwd(sm_scale, M,  #
                                         offset_y, dtype, start_m, qk_scale,  #
                                         BLOCK_M, HEAD_DIM, BLOCK_N,  #
                                         4 - STAGE, offs_m, offs_n, N_CTX,  #
-                                        warp_specialize)
+                                        warp_specialize, SUBTILE_EPILOGUE)
     # stage 2: on-band
     if STAGE & 2:
         acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q,  #
                                         desc_k, desc_v,  #
                                         offset_y, dtype, start_m, qk_scale,  #
                                         BLOCK_M, HEAD_DIM, BLOCK_N,  #
                                         2, offs_m, offs_n, N_CTX,  #
-                                        warp_specialize)
+                                        warp_specialize, SUBTILE_EPILOGUE)
     # epilogue
     m_i += tl.math.log2(l_i)
     acc = acc / l_i[:, None]
@@ -504,7 +509,8 @@ def forward(ctx, q, k, v, causal, sm_scale, warp_specialize=True):
             extra_kern_args = {"waves_per_eu": waves_per_eu, "allow_flush_denorm": True}

         M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
-        if supports_host_descriptor():
+        # Use device_descriptor for Hopper + warpspec.
+        if supports_host_descriptor() and not (is_hopper() and warp_specialize):
             # Note that on Hopper we cannot perform a FP8 dot with a non-transposed second tensor
             y_dim = q.shape[0] * q.shape[1] * q.shape[2]

@@ -533,7 +539,8 @@ def grid(META):
             return (triton.cdiv(q.shape[2], META["BLOCK_M"]), q.shape[0] * q.shape[1], 1)

         ctx.grid = grid
-        if is_cuda() and warp_specialize:
+        SUBTILE_EPILOGUE = False if is_hopper() and warp_specialize else True
+        if is_blackwell() and warp_specialize:
             if HEAD_DIM_K == 128 and q.dtype == torch.float16:
                 extra_kern_args["maxnreg"] = 168
             else:
@@ -547,7 +554,7 @@ def grid(META):
             FP8_OUTPUT=q.dtype == torch.float8_e5m2,  #
             STAGE=stage,  #
             warp_specialize=warp_specialize,  #
-            **extra_kern_args)
+            SUBTILE_EPILOGUE=SUBTILE_EPILOGUE, **extra_kern_args)

         ctx.save_for_backward(q, k, v, o, M)
         ctx.sm_scale = sm_scale
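Taken together, the forward() changes let warp specialization be requested on Hopper as well: the host then skips host-side descriptors and disables epilogue subtiling. A hedged usage sketch, assuming the tutorial's `attention` autograd wrapper (`attention = _attention.apply`) and an fp16-capable CUDA device; the shapes and scale value are illustrative only:

```python
import torch

# Illustrative shapes: (batch, heads, sequence, head_dim).
q = torch.randn(1, 4, 1024, 128, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

sm_scale = 1.0 / (q.shape[-1] ** 0.5)

# forward(ctx, q, k, v, causal, sm_scale, warp_specialize=True): with this patch,
# warp_specialize=True is honored on Hopper (device descriptors, no epilogue
# subtiling) as well as on Blackwell.
o = attention(q, k, v, False, sm_scale, True)
```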
@@ -684,7 +691,9 @@ def test_op(Z, H, N_CTX, HEAD_DIM, causal, warp_specialize, mode, provider, dtyp
 for HEAD_DIM in [64, 128]:
     for mode in ["fwd", "bwd"]:
         for causal in [True, False]:
-            for warp_specialize in [False, True] if is_blackwell() else [False]:
+            # Enable warpspec for non-causal fwd on Hopper
+            for warp_specialize in [False, True] if (is_blackwell() or
+                                                     (is_hopper() and mode == "fwd" and not causal)) else [False]:
                 configs.append(
                     triton.testing.Benchmark(
                         x_names=["N_CTX"],