@@ -530,6 +530,25 @@ def _compute_offsets_persistent(tile_idx, n_tile_num, H, N_CTX, BLOCK_M):
530
530
return start_m , off_hz , lo , hi , qo_offset_y , kv_offset_y
531
531
532
532
533
@triton.jit
def _split_n(x, SPLIT_FACTOR: tl.constexpr):
    """Split a 2-D tensor along its second axis into a tuple of equal slices.

    The tensor is halved recursively: each level reshapes ``x`` from
    ``[rows, cols]`` to ``[rows, 2, cols // 2]``, moves the axis of size two
    to the end, and calls ``split()`` to separate the two halves. The first
    half is expanded before the second, so (with an order-preserving
    reshape) slices come back in left-to-right column order.

    Args:
        x: Tensor indexed as ``[rows, cols]``; only ``x.shape[0]`` and
            ``x.shape[1]`` are read.
        SPLIT_FACTOR: Compile-time number of slices. Must be a positive
            power of two because every recursion level halves it.

    Returns:
        A tuple of ``SPLIT_FACTOR`` tensors, each of shape
        ``[rows, cols // SPLIT_FACTOR]``.
    """
    # Reject factors the halving recursion cannot honour: 0 never reaches the
    # base case, and a non-power-of-two (e.g. 3) would silently produce fewer
    # slices than requested because of the integer division below.
    tl.static_assert(SPLIT_FACTOR >= 1, "SPLIT_FACTOR must be at least 1")
    tl.static_assert(
        (SPLIT_FACTOR & (SPLIT_FACTOR - 1)) == 0,
        "SPLIT_FACTOR must be a power of two",
    )
    if SPLIT_FACTOR == 1:
        return (x, )
    else:
        # [rows, cols] -> [rows, 2, cols // 2] -> [rows, cols // 2, 2];
        # split() then separates the trailing axis of size two.
        x0, x1 = x.reshape([x.shape[0], 2, x.shape[1] // 2]).permute(0, 2, 1).split()
        return _split_n(x0, SPLIT_FACTOR // 2) + _split_n(x1, SPLIT_FACTOR // 2)
540
+
541
@triton.jit
def _join_n(xs):
    """Concatenate a tuple of 2-D tensors along their second axis.

    The tuple is divided in the middle, each half is joined recursively, and
    the two results are merged so that the columns of the first half precede
    those of the second.

    Args:
        xs: Tuple of tensors indexed as ``[rows, cols]``. A one-element tuple
            is returned unchanged.
            NOTE(review): the recursion presumably expects a non-empty tuple
            whose length is a power of two and whose elements share one
            shape -- confirm against callers.

    Returns:
        A single tensor of shape ``[rows, cols * len(xs)]``.
    """
    if len(xs) == 1:
        return xs[0]
    else:
        left = _join_n(xs[:len(xs) // 2])
        right = _join_n(xs[len(xs) // 2:])
        # tl.join stacks both halves on a new trailing axis: [rows, cols, 2].
        stacked = tl.join(left, right)
        # Bring that axis in front of the columns and flatten it away so the
        # result is [rows, 2 * cols] with `left` occupying the leading columns.
        return stacked.permute(0, 2, 1).reshape([left.shape[0], left.shape[1] * 2])
550
+
551
+
533
552
@triton .autotune (configs = configs_persistent , key = ["N_CTX" , "HEAD_DIM" , "FP8_OUTPUT" ])
534
553
@triton .jit
535
554
def _attn_fwd_ws_persistent (
@@ -711,12 +730,12 @@ def _attn_fwd_ws_persistent(
711
730
for cid in tl .static_range (0 , NUM_MMA_GROUPS ):
712
731
tlx .barrier_wait (o_fulls [cid ], phase )
713
732
tlx .fence_async_shared ()
714
- tlx .barrier_arrive (o_empties [cid ])
715
733
qo_offset_y_split = qo_offset_y + cid * BLOCK_M_SPLIT
716
734
tlx .async_descriptor_store (
717
735
desc_o , o_tiles [cid ], [qo_offset_y_split , 0 ]
718
736
)
719
737
tlx .async_descriptor_store_wait (0 )
738
+ tlx .barrier_arrive (o_empties [cid ])
720
739
721
740
tile_idx += num_progs
722
741
@@ -751,17 +770,27 @@ def _attn_fwd_ws_persistent(
751
770
tlx .local_store (alpha_tiles [cid * HEAD_DIM ], alpha [:, None ])
752
771
tlx .barrier_arrive (alpha_fulls [cid ])
753
772
754
- qk = _fma_f32x2 (qk , qk_scale , - m_ij [:, None ])
755
- p = tl .math .exp2 (qk )
756
- l_ij = tl .sum (p , 1 )
757
- p = p .to (tlx .dtype_of (desc_v ))
758
773
759
774
# prepare p for the v dot
760
775
# Use p[1] for cid=0, and p[3] for cid=1
761
776
p_bufIdx = 1 + cid * NUM_MMA_GROUPS
762
- tlx .local_store (p_tiles [p_bufIdx ], p )
763
- tlx .barrier_arrive (p_fulls [cid ])
764
777
778
+ qk = _fma_f32x2 (qk , qk_scale , - m_ij [:, None ])
779
+ qks = _split_n (qk , NUM_MMA_SLICES )
780
+ ps = ()
781
+ for slice_id in tl .static_range (0 , NUM_MMA_SLICES ):
782
+ p_i = tl .math .exp2 (qks [slice_id ])
783
+ p_slice = tlx .subslice (
784
+ p_tiles [p_bufIdx ],
785
+ HEAD_DIM * slice_id // NUM_MMA_SLICES ,
786
+ HEAD_DIM // NUM_MMA_SLICES ,
787
+ )
788
+ tlx .local_store (p_slice , p_i .to (tlx .dtype_of (desc_v )))
789
+ ps = ps + (p_i , )
790
+
791
+ tlx .barrier_arrive (p_fulls [cid ])
792
+ p = _join_n (ps )
793
+ l_ij = tl .sum (p , 1 )
765
794
l_i = l_i * alpha + l_ij
766
795
m_i = m_ij
767
796
accum_cnt_qk += 1
0 commit comments