@@ -49,6 +49,22 @@ def _supports_reg_auto_ws():
4949HAS_REG_AUTO_WS = _supports_reg_auto_ws ()
5050
5151
52+ # Check if Triton version supports pingpongAutoWS
53+ # These parameters are only available in triton/tree/ws-3.5
54+ def _supports_pingpong_auto_ws ():
55+ """Check if the current Triton version supports pingpongAutoWS"""
56+ try :
57+ # Try to create a Config with minRegAutoWS to test support
58+ test_config = triton .Config ({}, pingpongAutoWS = True )
59+ return True
60+ except (TypeError , AttributeError ):
61+ # Parameter not supported in this Triton version
62+ return False
63+
64+
65+ HAS_PINGPONG_AUTO_WS = _supports_pingpong_auto_ws ()
66+
67+
5268@triton .jit
5369def _attn_fwd_subtile (
5470 q ,
@@ -65,6 +81,7 @@ def _attn_fwd_subtile(
6581 dtype : tl .constexpr ,
6682 STAGE : tl .constexpr ,
6783 SUBTILING : tl .constexpr ,
84+ SUBTILING_P : tl .constexpr ,
6885 VECT_MUL : tl .constexpr ,
6986 FADD2_REDUCE : tl .constexpr ,
7087):
@@ -80,7 +97,23 @@ def _attn_fwd_subtile(
8097 qk = _fma_f32x2 (qk , qk_scale , - m_ij [:, None ])
8198 else :
8299 qk = qk * qk_scale - m_ij [:, None ]
83- p = tl .math .exp2 (qk )
100+
101+ PM : tl .constexpr = qk .shape [0 ]
102+ PN : tl .constexpr = qk .shape [1 ]
103+
104+ if SUBTILING_P :
105+ qk0 , qk1 = qk .reshape ([PM , 2 , PN // 2 ]).permute (0 , 2 , 1 ).split ()
106+
107+ p0 = tl .math .exp2 (qk0 )
108+ p0_bf16 = p0 .to (dtype )
109+ p1 = tl .math .exp2 (qk1 )
110+ p1_bf16 = p1 .to (dtype )
111+
112+ p = tl .join (p0 , p1 ).permute (0 , 2 , 1 ).reshape ([PM , PN ])
113+ p_bf16 = tl .join (p0_bf16 , p1_bf16 ).permute (0 , 2 , 1 ).reshape ([PM , PN ])
114+ else :
115+ p = tl .math .exp2 (qk )
116+
84117 # -- compute correction factor
85118 alpha = tl .math .exp2 (m_i - m_ij )
86119 if not FADD2_REDUCE :
@@ -104,8 +137,6 @@ def _attn_fwd_subtile(
104137
105138 # update m_i and l_i
106139 # place this at the end of the loop to reduce register pressure
107- PM : tl .constexpr = p .shape [0 ]
108- PN : tl .constexpr = p .shape [1 ]
109140 if FADD2_REDUCE :
110141 p0 , p1 = p .reshape ([PM , 2 , PN // 2 ]).permute (0 , 2 , 1 ).split ()
111142 l_ij0 , l_ij1 = tl .reduce ((p0 , p1 ), axis = 1 , combine_fn = _reduce_fadd2 )
@@ -115,9 +146,10 @@ def _attn_fwd_subtile(
115146 # We can potentially move these to be before updating l_ij, so the dot
116147 # is not blocked.
117148 # prepare p and v for the dot
118- p = p .to (dtype )
149+ if not SUBTILING_P :
150+ p_bf16 = p .to (dtype )
119151 # note that this non transposed v for FP8 is only supported on Blackwell
120- acc = tl .dot (p , v , acc )
152+ acc = tl .dot (p_bf16 , v , acc )
121153 if not FADD2_REDUCE :
122154 l_i0 = l_i0 * alpha + l_ij
123155 m_i = m_ij
@@ -153,6 +185,7 @@ def _attn_fwd_inner_oss_dp(
153185 N_CTX : tl .constexpr ,
154186 warp_specialize : tl .constexpr ,
155187 SUBTILING : tl .constexpr ,
188+ SUBTILING_P : tl .constexpr ,
156189 VECT_MUL : tl .constexpr ,
157190 FADD2_REDUCE : tl .constexpr ,
158191):
@@ -191,6 +224,7 @@ def _attn_fwd_inner_oss_dp(
191224 dtype ,
192225 STAGE ,
193226 SUBTILING ,
227+ SUBTILING_P ,
194228 VECT_MUL ,
195229 FADD2_REDUCE ,
196230 )
@@ -209,6 +243,7 @@ def _attn_fwd_inner_oss_dp(
209243 dtype ,
210244 STAGE ,
211245 SUBTILING ,
246+ SUBTILING_P ,
212247 VECT_MUL ,
213248 FADD2_REDUCE ,
214249 )
@@ -242,12 +277,13 @@ def _host_descriptor_pre_hook(nargs):
242277
243278if is_tile_enabled ():
244279 # Helper to build config with optional minRegAutoWS/maxRegAutoWS
245- def make_tile_config (BM , BN , occ , subtile , vectmul , add2reduce ):
280+ def make_tile_config (BM , BN , occ , subtile , subtile_p , vectmul , add2reduce ):
246281 config_kwargs = {
247282 "BLOCK_M" : BM ,
248283 "BLOCK_N" : BN ,
249284 "occupancy" : occ ,
250285 "SUBTILING" : subtile ,
286+ "SUBTILING_P" : subtile_p ,
251287 "VECT_MUL" : vectmul ,
252288 "FADD2_REDUCE" : add2reduce ,
253289 }
@@ -258,24 +294,31 @@ def make_tile_config(BM, BN, occ, subtile, vectmul, add2reduce):
258294 extra_kwargs ["minRegAutoWS" ] = 24
259295 extra_kwargs ["maxRegAutoWS" ] = 152
260296
297+ if HAS_PINGPONG_AUTO_WS :
298+ extra_kwargs ["pingpongAutoWS" ] = True
299+
261300 return triton .Config (config_kwargs , ** extra_kwargs )
262301
263302 configs = [
264- make_tile_config (BM , BN , occ , subtile , vectmul , add2reduce )
303+ make_tile_config (BM , BN , occ , subtile , subtile_p , vectmul , add2reduce )
265304 for BM in [64 , 128 , 256 ]
266305 for BN in [64 , 128 ]
267306 for occ in [1 , 2 ]
268307 for subtile in [True ]
308+ for subtile_p in [True ]
269309 for vectmul in [0 ]
270310 for add2reduce in [False ]
271311 ]
272312else :
273313 # Helper to build config with optional minRegAutoWS/maxRegAutoWS
274- def make_standard_config (BM , BN , s , w , subtile , vectmul , add2reduce , maxreg ):
314+ def make_standard_config (
315+ BM , BN , s , w , subtile , subtile_p , vectmul , add2reduce , maxreg
316+ ):
275317 config_kwargs = {
276318 "BLOCK_M" : BM ,
277319 "BLOCK_N" : BN ,
278320 "SUBTILING" : subtile ,
321+ "SUBTILING_P" : subtile_p ,
279322 "VECT_MUL" : vectmul ,
280323 "FADD2_REDUCE" : add2reduce ,
281324 }
@@ -290,15 +333,21 @@ def make_standard_config(BM, BN, s, w, subtile, vectmul, add2reduce, maxreg):
290333 extra_kwargs ["minRegAutoWS" ] = 24
291334 extra_kwargs ["maxRegAutoWS" ] = maxreg
292335
336+ if HAS_PINGPONG_AUTO_WS :
337+ extra_kwargs ["pingpongAutoWS" ] = True
338+
293339 return triton .Config (config_kwargs , ** extra_kwargs )
294340
295341 configs = [
296- make_standard_config (BM , BN , s , w , subtile , vectmul , add2reduce , maxreg )
342+ make_standard_config (
343+ BM , BN , s , w , subtile , subtile_p , vectmul , add2reduce , maxreg
344+ )
297345 for BM in [256 ]
298346 for BN in [64 , 128 ]
299347 for s in NUM_STAGES_OPTIONS
300348 for w in [4 ]
301349 for subtile in [True ]
350+ for subtile_p in [True ]
302351 for vectmul in [1 ]
303352 for add2reduce in [False ]
304353 for maxreg in [152 , 192 ]
@@ -420,6 +469,7 @@ def _attn_fwd_tma_dp(
420469 warp_specialize : tl .constexpr , #
421470 dtype : tl .constexpr ,
422471 SUBTILING : tl .constexpr ,
472+ SUBTILING_P : tl .constexpr ,
423473 VECT_MUL : tl .constexpr ,
424474 FADD2_REDUCE : tl .constexpr ,
425475):
@@ -485,6 +535,7 @@ def _attn_fwd_tma_dp(
485535 N_CTX , #
486536 warp_specialize ,
487537 SUBTILING ,
538+ SUBTILING_P ,
488539 VECT_MUL ,
489540 FADD2_REDUCE ,
490541 )
@@ -516,6 +567,7 @@ def _attn_fwd_tma_dp(
516567 N_CTX , #
517568 warp_specialize ,
518569 SUBTILING ,
570+ SUBTILING_P ,
519571 VECT_MUL ,
520572 FADD2_REDUCE ,
521573 )
@@ -564,6 +616,7 @@ def _attn_fwd(
564616 warp_specialize : tl .constexpr , #
565617 dtype : tl .constexpr ,
566618 SUBTILING : tl .constexpr ,
619+ SUBTILING_P : tl .constexpr ,
567620 VECT_MUL : tl .constexpr ,
568621 FADD2_REDUCE : tl .constexpr ,
569622):
@@ -589,6 +642,7 @@ def _attn_fwd(
589642 warp_specialize ,
590643 dtype ,
591644 SUBTILING ,
645+ SUBTILING_P ,
592646 VECT_MUL ,
593647 FADD2_REDUCE ,
594648 )
@@ -619,6 +673,7 @@ def _attn_fwd_persist(
619673 OUTER_LOOP : tl .constexpr ,
620674 dtype : tl .constexpr ,
621675 SUBTILING : tl .constexpr ,
676+ SUBTILING_P : tl .constexpr ,
622677 VECT_MUL : tl .constexpr ,
623678 FADD2_REDUCE : tl .constexpr ,
624679):
@@ -682,6 +737,7 @@ def _attn_fwd_persist(
682737 warp_specialize and not OUTER_LOOP ,
683738 dtype ,
684739 SUBTILING ,
740+ SUBTILING_P ,
685741 VECT_MUL ,
686742 FADD2_REDUCE ,
687743 )
0 commit comments