@@ -241,14 +241,14 @@ def __init__(self, channel, phase, index):
241
241
242
242
@gluon .jit
243
243
def acquire (self ):
244
- smem , ready_bar = self .channel .acquire_producer (self .index , self .phase )
244
+ mem , ready_bar = self .channel .acquire_producer (self .index , self .phase )
245
245
self .index , self .phase = self .channel .increment (self .index , self .phase )
246
- return smem , ready_bar , self
246
+ return mem , ready_bar , self
247
247
248
248
@gluon .jit
249
249
def emplace (self , value ):
250
- smem , ready_bar , self = self .acquire ()
251
- smem .store (value )
250
+ mem , ready_bar , self = self .acquire ()
251
+ mem .store (value )
252
252
mbarrier .arrive (ready_bar , count = 1 )
253
253
return self
254
254
@@ -265,14 +265,14 @@ def __init__(self, channel, phase, index):
265
265
266
266
@gluon .jit
267
267
def acquire (self ):
268
- smem , empty_bar = self .channel .acquire_consumer (self .index , self .phase )
268
+ mem , empty_bar = self .channel .acquire_consumer (self .index , self .phase )
269
269
self .index , self .phase = self .channel .increment (self .index , self .phase )
270
- return smem , empty_bar , self
270
+ return mem , empty_bar , self
271
271
272
272
@gluon .jit
273
273
def get (self , layout : gl .constexpr ):
274
- smem , empty_bar , self = self .acquire ()
275
- value = smem .load (layout )
274
+ mem , empty_bar , self = self .acquire ()
275
+ value = mem .load (layout )
276
276
mbarrier .arrive (empty_bar , count = 1 )
277
277
return value , self
278
278
@@ -399,9 +399,9 @@ class AttentionConfig:
399
399
dtype : gl .constexpr
400
400
num_warps : gl .constexpr
401
401
402
- SPLIT_N_FACTOR : gl .constexpr
402
+ SPLIT_D_FACTOR : gl .constexpr
403
403
SPLIT_M : gl .constexpr
404
- SPLIT_N : gl .constexpr
404
+ SPLIT_D : gl .constexpr
405
405
406
406
q_shape : gl .constexpr
407
407
k_shape : gl .constexpr
@@ -416,8 +416,11 @@ class AttentionConfig:
416
416
qk_layout : gl .constexpr
417
417
o_layout : gl .constexpr
418
418
o_splitn_layout : gl .constexpr
419
+ mi_2d_layout : gl .constexpr
419
420
420
- def __init__ (self , qk_scale , Z , H , N_CTX , BLOCK_M , BLOCK_N , HEAD_DIM , dtype , num_warps , SPLIT_N_FACTOR ):
421
+ mi_use_tmem : gl .constexpr
422
+
423
+ def __init__ (self , qk_scale , Z , H , N_CTX , BLOCK_M , BLOCK_N , HEAD_DIM , dtype , num_warps , SPLIT_D_FACTOR ):
421
424
self .qk_scale = qk_scale
422
425
self .Z = Z
423
426
self .H = H
@@ -428,9 +431,9 @@ def __init__(self, qk_scale, Z, H, N_CTX, BLOCK_M, BLOCK_N, HEAD_DIM, dtype, num
428
431
self .dtype = gl .constexpr (dtype )
429
432
self .num_warps = gl .constexpr (num_warps )
430
433
431
- self .SPLIT_N_FACTOR = SPLIT_N_FACTOR
434
+ self .SPLIT_D_FACTOR = SPLIT_D_FACTOR
432
435
self .SPLIT_M = self .BLOCK_M // 2
433
- self .SPLIT_N = self .BLOCK_N // self .SPLIT_N_FACTOR
436
+ self .SPLIT_D = self .HEAD_DIM // self .SPLIT_D_FACTOR
434
437
435
438
self .q_shape = gl .constexpr ([self .SPLIT_M , self .HEAD_DIM ])
436
439
self .k_shape = gl .constexpr ([self .BLOCK_N , self .HEAD_DIM ])
@@ -447,8 +450,11 @@ def __init__(self, qk_scale, Z, H, N_CTX, BLOCK_M, BLOCK_N, HEAD_DIM, dtype, num
447
450
self .qk_layout = gl .constexpr (get_tmem_32x32b_reg_layout (qk_instr_shape , self .qk_shape , self .num_warps ))
448
451
self .o_layout = gl .constexpr (get_tmem_32x32b_reg_layout (o_instr_shape , self .o_shape , self .num_warps ))
449
452
self .o_splitn_layout = gl .constexpr (
450
- get_tmem_32x32b_reg_layout ((o_instr_shape [0 ], o_instr_shape [1 ] // self .SPLIT_N_FACTOR , o_instr_shape [2 ]),
451
- (self .o_shape [0 ], self .o_shape [1 ] // self .SPLIT_N_FACTOR ), self .num_warps ))
453
+ get_tmem_32x32b_reg_layout ((o_instr_shape [0 ], o_instr_shape [1 ] // self .SPLIT_D_FACTOR , o_instr_shape [2 ]),
454
+ (self .o_shape [0 ], self .o_shape [1 ] // self .SPLIT_D_FACTOR ), self .num_warps ))
455
+ self .mi_2d_layout = gl .constexpr (gl .BlockedLayout ([1 , 1 ], [32 , 1 ], [4 , 1 ], [0 , 1 ]))
456
+
457
+ self .mi_use_tmem = gl .constexpr (True )
452
458
453
459
@gluon .jit
454
460
def get_program (self ):
@@ -539,7 +545,7 @@ class InnerLoopInfo:
539
545
qk_mma_ctx : MMAContext
540
546
o_mma_ctx : MMAContext
541
547
p_chnl : TensorMemoryChannel
542
- mi_chnl : SharedMemoryChannel
548
+ mi_chnl : TensorMemoryChannel
543
549
li_smem : gl .shared_memory_descriptor
544
550
q_smem : gl .shared_memory_descriptor
545
551
@@ -552,14 +558,25 @@ def create(config, tile):
552
558
o_mma_ctx .channel .initialize_for_consumer ()
553
559
o_mma_ctx .channel .mem .index (0 ).store (tile .acc )
554
560
555
- p_chnl = TensorMemoryChannel ._borrow (qk_mma_ctx .channel .mem , config .qk_shape , config .dtype ,
556
- config .p_tmem_layout , num_buffers = 1 , num_consumers = 1 )
561
+ # QK and PV MMAs are serialized, which enables borrowing QK's memory.
562
+ borrow_tmem = qk_mma_ctx .channel .mem .index (0 )
563
+ p_tmem = borrow_tmem .slice (0 , config .BLOCK_N // 2 )
564
+ mi_tmem = borrow_tmem .slice (config .BLOCK_N // 2 , 1 )
565
+ mi_layout : gl .constexpr = TensorMemoryLayout ([config .SPLIT_M , 1 ], unpacked = False )
566
+
567
+ p_chnl = TensorMemoryChannel ._borrow (p_tmem , config .qk_shape , config .dtype , config .p_tmem_layout , num_buffers = 1 ,
568
+ num_consumers = 1 )
557
569
p_chnl .initialize_for_producer ()
558
570
559
- mi_chnl = SharedMemoryChannel .create ([config .SPLIT_M ], gl .float32 , gl .constexpr (mbarrier .MBarrierLayout ()),
560
- num_buffers = 1 )
571
+ if config .mi_use_tmem :
572
+ mi_chnl = TensorMemoryChannel ._borrow (mi_tmem , [config .SPLIT_M , 1 ], gl .float32 , mi_layout , num_buffers = 1 )
573
+ m_i = gl .convert_layout (tile .m_i .expand_dims (1 ), config .mi_2d_layout )
574
+ else :
575
+ mi_chnl = SharedMemoryChannel .create ([config .SPLIT_M ], gl .float32 , gl .constexpr (mbarrier .MBarrierLayout ()),
576
+ num_buffers = 1 )
577
+ m_i = tile .m_i
578
+ mi_chnl .mem .index (0 ).store (m_i )
561
579
mi_chnl .initialize_for_producer ()
562
- mi_chnl .mem .index (0 ).store (tile .m_i )
563
580
564
581
li_smem = gl .allocate_shared_memory (gl .float32 , [config .SPLIT_M ], gl .constexpr (mbarrier .MBarrierLayout ()))
565
582
li_smem .store (tile .l_i )
@@ -662,21 +679,66 @@ def _attn_fwd_mma(config, #
662
679
mbarrier .invalidate (qk_p_bar )
663
680
664
681
682
+ @gluon .jit
683
+ def _add_f32x2 (a , b ):
684
+ return gl .inline_asm_elementwise (
685
+ """
686
+ {
687
+ .reg .b64 ra, rb, rc;
688
+ mov.b64 ra, { $2, $3 };
689
+ mov.b64 rb, { $4, $5 };
690
+ add.f32x2 rc, ra, rb;
691
+ mov.b64 { $0, $1 }, rc;
692
+ }
693
+ """ ,
694
+ "=r,=r,r,r,r,r" ,
695
+ [a , b ],
696
+ dtype = gl .float32 ,
697
+ is_pure = True ,
698
+ pack = 2 ,
699
+ )
700
+
701
+
702
+ @gluon .jit
703
+ def _mul_f32x2 (a , b ):
704
+ return gl .inline_asm_elementwise (
705
+ """
706
+ {
707
+ .reg .b64 ra, rb, rc;
708
+ mov.b64 ra, { $2, $3 };
709
+ mov.b64 rb, { $4, $5 };
710
+ mul.f32x2 rc, ra, rb;
711
+ mov.b64 { $0, $1 }, rc;
712
+ }
713
+ """ ,
714
+ "=r,=r,r,r,r,r" ,
715
+ [a , b ],
716
+ dtype = gl .float32 ,
717
+ is_pure = True ,
718
+ pack = 2 ,
719
+ )
720
+
721
+
665
722
@gluon .jit
666
723
def _attn_fwd_correction_compute (config , mi_consumer , o_consumer , m_i ):
667
- m_ij , mi_consumer = mi_consumer .get (gl .constexpr (gl .SliceLayout (1 , config .o_splitn_layout )))
724
+ mi_layout : gl .constexpr = gl .SliceLayout (1 , config .o_splitn_layout )
725
+ if config .mi_use_tmem :
726
+ m_ij , mi_consumer = mi_consumer .get (config .mi_2d_layout )
727
+ m_ij = gl .convert_layout (m_ij .reshape ([config .SPLIT_M ]), mi_layout )
728
+ else :
729
+ m_ij , mi_consumer = mi_consumer .get (mi_layout )
668
730
alpha = gl .exp2 (m_i - m_ij )
669
731
670
732
o_tmem , o_bar , o_consumer = o_consumer .acquire ()
671
- if config .SPLIT_N_FACTOR == 1 :
733
+ if config .SPLIT_D_FACTOR == 1 :
672
734
o = o_tmem .load (config .o_layout )
673
- o = o * alpha [:, None ]
735
+ o = _mul_f32x2 ( o , alpha [:, None ])
674
736
o_tmem .store (o )
675
737
else :
676
- for i in tl .static_range (config .SPLIT_N_FACTOR ):
677
- o_ref = o_tmem .slice (i * config .SPLIT_N , config .SPLIT_N )
738
+ for i in tl .static_range (config .SPLIT_D_FACTOR ):
739
+ o_ref = o_tmem .slice (i * config .SPLIT_D , config .SPLIT_D )
678
740
o = o_ref .load (config .o_splitn_layout )
679
- o = o * alpha [:, None ]
741
+ o = _mul_f32x2 ( o , alpha [:, None ])
680
742
o_ref .store (o )
681
743
mbarrier .arrive (o_bar , count = 1 )
682
744
return mi_consumer , o_consumer , m_ij
@@ -723,31 +785,48 @@ def _softmax_tile(tile_id: gl.constexpr, config, info, STAGE: gl.constexpr):
723
785
p_producer = info .p_chnl .create_producer ()
724
786
mi_producer = info .mi_chnl .create_producer ()
725
787
726
- m_i = info .mi_chnl .mem .index (0 ).load (qk_slice_dim1 )
788
+ if config .mi_use_tmem :
789
+ m_i = info .mi_chnl .mem .index (0 ).load (config .mi_2d_layout )
790
+ m_i = gl .convert_layout (m_i .reshape ([config .SPLIT_M ]), qk_slice_dim1 )
791
+ else :
792
+ m_i = info .mi_chnl .mem .index (0 ).load (qk_slice_dim1 )
727
793
l_i = info .li_smem .load (qk_slice_dim1 )
728
794
729
795
for start_n in range (lo , hi , config .BLOCK_N ):
730
796
qk , qk_consumer = qk_consumer .get (config .qk_layout )
797
+ if config .HEAD_DIM == 128 :
798
+ p_tmem , p_bar , p_producer = p_producer .acquire ()
799
+
731
800
if STAGE == 2 :
732
801
# Prevent LLVM from hoisting the partial sums, which triggers spilling.
733
802
offs_n = gl .inline_asm_elementwise ("mov.b32 $0, $0;" , "=r,r" , [offs_n ], dtype = gl .int32 , is_pure = True ,
734
803
pack = 1 )
735
804
mask = offs_m [:, None ] >= (start_n + offs_n [None , :])
736
- qk = qk * config . qk_scale + gl .where (mask , 0 , - 1.0e6 )
737
- m_ij = gl .maximum (m_i , gl .max (qk , 1 ))
738
- mi_producer = mi_producer . emplace ( m_ij )
739
- qk -= m_ij [:, None ]
805
+ qk = gl .where (mask , qk , - 1.0e8 )
806
+ m_ij = gl .maximum (m_i , gl .max (qk , 1 ) * config . qk_scale )
807
+ if config . mi_use_tmem :
808
+ mi_producer = mi_producer . emplace ( gl . convert_layout ( m_ij . expand_dims ( 1 ), config . mi_2d_layout ))
740
809
else :
741
- m_ij = gl .maximum (m_i , gl .max (qk , 1 ) * config .qk_scale )
742
810
mi_producer = mi_producer .emplace (m_ij )
743
- qk = qk * config .qk_scale - m_ij [:, None ]
744
-
745
- p = gl .exp2 (qk )
811
+ qk = qk * config .qk_scale - m_ij [:, None ]
746
812
747
- l_ij = gl .sum (p , 1 )
748
- alpha = gl .exp2 (m_i - m_ij )
749
-
750
- p_producer = p_producer .emplace (p .to (config .dtype ))
813
+ if config .HEAD_DIM == 64 :
814
+ p = gl .exp2 (qk )
815
+ l_ij = gl .sum (p , 1 )
816
+ alpha = gl .exp2 (m_i - m_ij )
817
+ p_producer = p_producer .emplace (p .to (config .dtype ))
818
+ else :
819
+ qk0 , qk1 , = qk .reshape ([config .SPLIT_M , 2 , config .BLOCK_N // 2 ]).permute (0 , 2 , 1 ).split ()
820
+ p0 = gl .exp2 (qk0 )
821
+ p_tmem .slice (0 , config .BLOCK_N // 2 ).store (p0 .to (config .dtype ))
822
+ p1 = gl .exp2 (qk1 )
823
+ p_tmem .slice (config .BLOCK_N // 2 , config .BLOCK_N // 2 ).store (p1 .to (config .dtype ))
824
+ mbarrier .arrive (p_bar , count = 1 )
825
+ p = gl .join (p0 , p1 ).permute (0 , 2 , 1 ).reshape ([config .SPLIT_M , config .BLOCK_N ])
826
+ p = gl .convert_layout (p , config .qk_layout )
827
+
828
+ l_ij = gl .sum (p , 1 )
829
+ alpha = gl .exp2 (m_i - m_ij )
751
830
752
831
l_i = l_i * alpha + l_ij
753
832
m_i = m_ij
@@ -773,7 +852,7 @@ def _attn_fwd_softmax1(config, #
773
852
def _attn_fwd_inner (config , info0 , info1 , m_i0 , m_i1 , #
774
853
desc_k , desc_v , #
775
854
STAGE : gl .constexpr ):
776
- num_buffers : gl .constexpr = 2 if config .HEAD_DIM > = 128 else 3
855
+ num_buffers : gl .constexpr = 2 if config .HEAD_DIM = = 128 else 3
777
856
k_load_ctx = LoadContext .create (desc_k , num_buffers = num_buffers , num_consumers = 2 )
778
857
v_load_ctx = LoadContext .create (desc_v , num_buffers = num_buffers , num_consumers = 2 )
779
858
@@ -793,7 +872,7 @@ def _attn_fwd_inner(config, info0, info1, m_i0, m_i1, #
793
872
_attn_fwd_softmax1 ,
794
873
_attn_fwd_mma ,
795
874
_attn_fwd_load ,
796
- ], [4 , 4 , 1 , 1 ], [192 , 200 , 32 , 32 ])
875
+ ], [4 , 4 , 1 , 1 ], [192 , 192 , 32 , 32 ])
797
876
798
877
k_load_ctx .release ()
799
878
v_load_ctx .release ()
@@ -809,8 +888,7 @@ def _gluon_attn(sm_scale, M, Z, H, N_CTX, #
809
888
num_warps : gl .constexpr ):
810
889
qk_scale = sm_scale
811
890
qk_scale *= 1.44269504
812
- config = AttentionConfig (qk_scale , Z , H , N_CTX , BLOCK_M , BLOCK_N , HEAD_DIM , dtype , num_warps ,
813
- SPLIT_N_FACTOR = triton .cdiv (HEAD_DIM , 64 ))
891
+ config = AttentionConfig (qk_scale , Z , H , N_CTX , BLOCK_M , BLOCK_N , HEAD_DIM , dtype , num_warps , SPLIT_D_FACTOR = 2 )
814
892
815
893
prog = config .get_program ()
816
894
@@ -909,7 +987,7 @@ def is_blackwell():
909
987
@pytest .mark .parametrize ("H" , [2 , 48 ])
910
988
@pytest .mark .parametrize ("N_CTX" , [256 , 1024 , 4 * 1024 ])
911
989
@pytest .mark .parametrize ("HEAD_DIM" , [64 , 128 ])
912
- @pytest .mark .parametrize ("causal" , [True ])
990
+ @pytest .mark .parametrize ("causal" , [False , True ])
913
991
@pytest .mark .parametrize ("dtype" , [torch .float16 ])
914
992
@pytest .mark .skipif (not is_blackwell (), reason = "Gluon attention is only supported on Blackwell GPUs" )
915
993
def test_op (Z , H , N_CTX , HEAD_DIM , causal , dtype ):
0 commit comments