@@ -241,14 +241,14 @@ def __init__(self, channel, phase, index):

     @gluon.jit
     def acquire(self):
-        smem, ready_bar = self.channel.acquire_producer(self.index, self.phase)
+        mem, ready_bar = self.channel.acquire_producer(self.index, self.phase)
         self.index, self.phase = self.channel.increment(self.index, self.phase)
-        return smem, ready_bar, self
+        return mem, ready_bar, self

     @gluon.jit
     def emplace(self, value):
-        smem, ready_bar, self = self.acquire()
-        smem.store(value)
+        mem, ready_bar, self = self.acquire()
+        mem.store(value)
         mbarrier.arrive(ready_bar, count=1)
         return self

@@ -265,14 +265,14 @@ def __init__(self, channel, phase, index):

     @gluon.jit
     def acquire(self):
-        smem, empty_bar = self.channel.acquire_consumer(self.index, self.phase)
+        mem, empty_bar = self.channel.acquire_consumer(self.index, self.phase)
         self.index, self.phase = self.channel.increment(self.index, self.phase)
-        return smem, empty_bar, self
+        return mem, empty_bar, self

     @gluon.jit
     def get(self, layout: gl.constexpr):
-        smem, empty_bar, self = self.acquire()
-        value = smem.load(layout)
+        mem, empty_bar, self = self.acquire()
+        value = mem.load(layout)
         mbarrier.arrive(empty_bar, count=1)
         return value, self

@@ -399,9 +399,9 @@ class AttentionConfig:
     dtype: gl.constexpr
     num_warps: gl.constexpr

-    SPLIT_N_FACTOR: gl.constexpr
+    SPLIT_D_FACTOR: gl.constexpr
     SPLIT_M: gl.constexpr
-    SPLIT_N: gl.constexpr
+    SPLIT_D: gl.constexpr

     q_shape: gl.constexpr
     k_shape: gl.constexpr
@@ -416,8 +416,11 @@ class AttentionConfig:
     qk_layout: gl.constexpr
     o_layout: gl.constexpr
     o_splitn_layout: gl.constexpr
+    mi_2d_layout: gl.constexpr

-    def __init__(self, qk_scale, Z, H, N_CTX, BLOCK_M, BLOCK_N, HEAD_DIM, dtype, num_warps, SPLIT_N_FACTOR):
+    mi_use_tmem: gl.constexpr
+
+    def __init__(self, qk_scale, Z, H, N_CTX, BLOCK_M, BLOCK_N, HEAD_DIM, dtype, num_warps, SPLIT_D_FACTOR):
         self.qk_scale = qk_scale
         self.Z = Z
         self.H = H
@@ -428,9 +431,9 @@ def __init__(self, qk_scale, Z, H, N_CTX, BLOCK_M, BLOCK_N, HEAD_DIM, dtype, num
         self.dtype = gl.constexpr(dtype)
         self.num_warps = gl.constexpr(num_warps)

-        self.SPLIT_N_FACTOR = SPLIT_N_FACTOR
+        self.SPLIT_D_FACTOR = SPLIT_D_FACTOR
         self.SPLIT_M = self.BLOCK_M // 2
-        self.SPLIT_N = self.BLOCK_N // self.SPLIT_N_FACTOR
+        self.SPLIT_D = self.HEAD_DIM // self.SPLIT_D_FACTOR

         self.q_shape = gl.constexpr([self.SPLIT_M, self.HEAD_DIM])
         self.k_shape = gl.constexpr([self.BLOCK_N, self.HEAD_DIM])
@@ -447,8 +450,11 @@ def __init__(self, qk_scale, Z, H, N_CTX, BLOCK_M, BLOCK_N, HEAD_DIM, dtype, num
         self.qk_layout = gl.constexpr(get_tmem_32x32b_reg_layout(qk_instr_shape, self.qk_shape, self.num_warps))
         self.o_layout = gl.constexpr(get_tmem_32x32b_reg_layout(o_instr_shape, self.o_shape, self.num_warps))
         self.o_splitn_layout = gl.constexpr(
-            get_tmem_32x32b_reg_layout((o_instr_shape[0], o_instr_shape[1] // self.SPLIT_N_FACTOR, o_instr_shape[2]),
-                                       (self.o_shape[0], self.o_shape[1] // self.SPLIT_N_FACTOR), self.num_warps))
+            get_tmem_32x32b_reg_layout((o_instr_shape[0], o_instr_shape[1] // self.SPLIT_D_FACTOR, o_instr_shape[2]),
+                                       (self.o_shape[0], self.o_shape[1] // self.SPLIT_D_FACTOR), self.num_warps))
+        self.mi_2d_layout = gl.constexpr(gl.BlockedLayout([1, 1], [32, 1], [4, 1], [0, 1]))
+
+        self.mi_use_tmem = gl.constexpr(True)

     @gluon.jit
     def get_program(self):
@@ -539,7 +545,7 @@ class InnerLoopInfo:
     qk_mma_ctx: MMAContext
     o_mma_ctx: MMAContext
     p_chnl: TensorMemoryChannel
-    mi_chnl: SharedMemoryChannel
+    mi_chnl: TensorMemoryChannel
     li_smem: gl.shared_memory_descriptor
     q_smem: gl.shared_memory_descriptor

@@ -552,14 +558,25 @@ def create(config, tile):
         o_mma_ctx.channel.initialize_for_consumer()
         o_mma_ctx.channel.mem.index(0).store(tile.acc)

-        p_chnl = TensorMemoryChannel._borrow(qk_mma_ctx.channel.mem, config.qk_shape, config.dtype,
-                                             config.p_tmem_layout, num_buffers=1, num_consumers=1)
+        # QK and PV MMAs are serialized, which enables borrowing QK's memory.
+        borrow_tmem = qk_mma_ctx.channel.mem.index(0)
+        p_tmem = borrow_tmem.slice(0, config.BLOCK_N // 2)
+        mi_tmem = borrow_tmem.slice(config.BLOCK_N // 2, 1)
+        mi_layout: gl.constexpr = TensorMemoryLayout([config.SPLIT_M, 1], unpacked=False)
+
+        p_chnl = TensorMemoryChannel._borrow(p_tmem, config.qk_shape, config.dtype, config.p_tmem_layout, num_buffers=1,
+                                             num_consumers=1)
         p_chnl.initialize_for_producer()

-        mi_chnl = SharedMemoryChannel.create([config.SPLIT_M], gl.float32, gl.constexpr(mbarrier.MBarrierLayout()),
-                                             num_buffers=1)
+        if config.mi_use_tmem:
+            mi_chnl = TensorMemoryChannel._borrow(mi_tmem, [config.SPLIT_M, 1], gl.float32, mi_layout, num_buffers=1)
+            m_i = gl.convert_layout(tile.m_i.expand_dims(1), config.mi_2d_layout)
+        else:
+            mi_chnl = SharedMemoryChannel.create([config.SPLIT_M], gl.float32, gl.constexpr(mbarrier.MBarrierLayout()),
+                                                 num_buffers=1)
+            m_i = tile.m_i
+        mi_chnl.mem.index(0).store(m_i)
         mi_chnl.initialize_for_producer()
-        mi_chnl.mem.index(0).store(tile.m_i)

         li_smem = gl.allocate_shared_memory(gl.float32, [config.SPLIT_M], gl.constexpr(mbarrier.MBarrierLayout()))
         li_smem.store(tile.l_i)
@@ -662,21 +679,66 @@ def _attn_fwd_mma(config, #
     mbarrier.invalidate(qk_p_bar)


+@gluon.jit
+def _add_f32x2(a, b):
+    return gl.inline_asm_elementwise(
+        """
+        {
+            .reg .b64 ra, rb, rc;
+            mov.b64 ra, { $2, $3 };
+            mov.b64 rb, { $4, $5 };
+            add.f32x2 rc, ra, rb;
+            mov.b64 { $0, $1 }, rc;
+        }
+        """,
+        "=r,=r,r,r,r,r",
+        [a, b],
+        dtype=gl.float32,
+        is_pure=True,
+        pack=2,
+    )
+
+
+@gluon.jit
+def _mul_f32x2(a, b):
+    return gl.inline_asm_elementwise(
+        """
+        {
+            .reg .b64 ra, rb, rc;
+            mov.b64 ra, { $2, $3 };
+            mov.b64 rb, { $4, $5 };
+            mul.f32x2 rc, ra, rb;
+            mov.b64 { $0, $1 }, rc;
+        }
+        """,
+        "=r,=r,r,r,r,r",
+        [a, b],
+        dtype=gl.float32,
+        is_pure=True,
+        pack=2,
+    )
+
+
 @gluon.jit
 def _attn_fwd_correction_compute(config, mi_consumer, o_consumer, m_i):
-    m_ij, mi_consumer = mi_consumer.get(gl.constexpr(gl.SliceLayout(1, config.o_splitn_layout)))
+    mi_layout: gl.constexpr = gl.SliceLayout(1, config.o_splitn_layout)
+    if config.mi_use_tmem:
+        m_ij, mi_consumer = mi_consumer.get(config.mi_2d_layout)
+        m_ij = gl.convert_layout(m_ij.reshape([config.SPLIT_M]), mi_layout)
+    else:
+        m_ij, mi_consumer = mi_consumer.get(mi_layout)
     alpha = gl.exp2(m_i - m_ij)

     o_tmem, o_bar, o_consumer = o_consumer.acquire()
-    if config.SPLIT_N_FACTOR == 1:
+    if config.SPLIT_D_FACTOR == 1:
         o = o_tmem.load(config.o_layout)
-        o = o * alpha[:, None]
+        o = _mul_f32x2(o, alpha[:, None])
         o_tmem.store(o)
     else:
-        for i in tl.static_range(config.SPLIT_N_FACTOR):
-            o_ref = o_tmem.slice(i * config.SPLIT_N, config.SPLIT_N)
+        for i in tl.static_range(config.SPLIT_D_FACTOR):
+            o_ref = o_tmem.slice(i * config.SPLIT_D, config.SPLIT_D)
             o = o_ref.load(config.o_splitn_layout)
-            o = o * alpha[:, None]
+            o = _mul_f32x2(o, alpha[:, None])
             o_ref.store(o)
     mbarrier.arrive(o_bar, count=1)
     return mi_consumer, o_consumer, m_ij
@@ -723,31 +785,48 @@ def _softmax_tile(tile_id: gl.constexpr, config, info, STAGE: gl.constexpr):
     p_producer = info.p_chnl.create_producer()
     mi_producer = info.mi_chnl.create_producer()

-    m_i = info.mi_chnl.mem.index(0).load(qk_slice_dim1)
+    if config.mi_use_tmem:
+        m_i = info.mi_chnl.mem.index(0).load(config.mi_2d_layout)
+        m_i = gl.convert_layout(m_i.reshape([config.SPLIT_M]), qk_slice_dim1)
+    else:
+        m_i = info.mi_chnl.mem.index(0).load(qk_slice_dim1)
     l_i = info.li_smem.load(qk_slice_dim1)

     for start_n in range(lo, hi, config.BLOCK_N):
         qk, qk_consumer = qk_consumer.get(config.qk_layout)
+        if config.HEAD_DIM == 128:
+            p_tmem, p_bar, p_producer = p_producer.acquire()
+
         if STAGE == 2:
             # Prevent LLVM from hoisting the partial sums, which triggers spilling.
             offs_n = gl.inline_asm_elementwise("mov.b32 $0, $0;", "=r,r", [offs_n], dtype=gl.int32, is_pure=True,
                                                pack=1)
             mask = offs_m[:, None] >= (start_n + offs_n[None, :])
-            qk = qk * config.qk_scale + gl.where(mask, 0, -1.0e6)
-            m_ij = gl.maximum(m_i, gl.max(qk, 1))
-            mi_producer = mi_producer.emplace(m_ij)
-            qk -= m_ij[:, None]
+            qk = gl.where(mask, qk, -1.0e8)
+        m_ij = gl.maximum(m_i, gl.max(qk, 1) * config.qk_scale)
+        if config.mi_use_tmem:
+            mi_producer = mi_producer.emplace(gl.convert_layout(m_ij.expand_dims(1), config.mi_2d_layout))
         else:
-            m_ij = gl.maximum(m_i, gl.max(qk, 1) * config.qk_scale)
             mi_producer = mi_producer.emplace(m_ij)
-            qk = qk * config.qk_scale - m_ij[:, None]
-
-        p = gl.exp2(qk)
+        qk = qk * config.qk_scale - m_ij[:, None]

-        l_ij = gl.sum(p, 1)
-        alpha = gl.exp2(m_i - m_ij)
-
-        p_producer = p_producer.emplace(p.to(config.dtype))
+        if config.HEAD_DIM == 64:
+            p = gl.exp2(qk)
+            l_ij = gl.sum(p, 1)
+            alpha = gl.exp2(m_i - m_ij)
+            p_producer = p_producer.emplace(p.to(config.dtype))
+        else:
+            qk0, qk1 = qk.reshape([config.SPLIT_M, 2, config.BLOCK_N // 2]).permute(0, 2, 1).split()
+            p0 = gl.exp2(qk0)
+            p_tmem.slice(0, config.BLOCK_N // 2).store(p0.to(config.dtype))
+            p1 = gl.exp2(qk1)
+            p_tmem.slice(config.BLOCK_N // 2, config.BLOCK_N // 2).store(p1.to(config.dtype))
+            mbarrier.arrive(p_bar, count=1)
+            p = gl.join(p0, p1).permute(0, 2, 1).reshape([config.SPLIT_M, config.BLOCK_N])
+            p = gl.convert_layout(p, config.qk_layout)
+
+            l_ij = gl.sum(p, 1)
+            alpha = gl.exp2(m_i - m_ij)

         l_i = l_i * alpha + l_ij
         m_i = m_ij
@@ -773,7 +852,7 @@ def _attn_fwd_softmax1(config, #
 def _attn_fwd_inner(config, info0, info1, m_i0, m_i1,  #
                     desc_k, desc_v,  #
                     STAGE: gl.constexpr):
-    num_buffers: gl.constexpr = 2 if config.HEAD_DIM >= 128 else 3
+    num_buffers: gl.constexpr = 2 if config.HEAD_DIM == 128 else 3
     k_load_ctx = LoadContext.create(desc_k, num_buffers=num_buffers, num_consumers=2)
     v_load_ctx = LoadContext.create(desc_v, num_buffers=num_buffers, num_consumers=2)

@@ -793,7 +872,7 @@ def _attn_fwd_inner(config, info0, info1, m_i0, m_i1, #
         _attn_fwd_softmax1,
         _attn_fwd_mma,
         _attn_fwd_load,
-    ], [4, 4, 1, 1], [192, 200, 32, 32])
+    ], [4, 4, 1, 1], [192, 192, 32, 32])

     k_load_ctx.release()
     v_load_ctx.release()
@@ -809,8 +888,7 @@ def _gluon_attn(sm_scale, M, Z, H, N_CTX, #
                 num_warps: gl.constexpr):
     qk_scale = sm_scale
     qk_scale *= 1.44269504
-    config = AttentionConfig(qk_scale, Z, H, N_CTX, BLOCK_M, BLOCK_N, HEAD_DIM, dtype, num_warps,
-                             SPLIT_N_FACTOR=triton.cdiv(HEAD_DIM, 64))
+    config = AttentionConfig(qk_scale, Z, H, N_CTX, BLOCK_M, BLOCK_N, HEAD_DIM, dtype, num_warps, SPLIT_D_FACTOR=2)

     prog = config.get_program()

@@ -909,7 +987,7 @@ def is_blackwell():
 @pytest.mark.parametrize("H", [2, 48])
 @pytest.mark.parametrize("N_CTX", [256, 1024, 4 * 1024])
 @pytest.mark.parametrize("HEAD_DIM", [64, 128])
-@pytest.mark.parametrize("causal", [True])
+@pytest.mark.parametrize("causal", [False, True])
 @pytest.mark.parametrize("dtype", [torch.float16])
 @pytest.mark.skipif(not is_blackwell(), reason="Gluon attention is only supported on Blackwell GPUs")
 def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype):
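
Note: the correction partition in this diff applies the standard online-softmax rescaling. When the running row maximum moves from `m_i` to `m_ij`, the accumulated output is scaled by `exp2(m_i - m_ij)` (the kernel works in the log2 domain, which is why `qk_scale` is pre-multiplied by `1.44269504`); the change only swaps the plain multiply for the packed `_mul_f32x2` helper. A minimal PyTorch sketch of that rescaling step, for reference only; the function name and shapes below are illustrative and not part of the kernel:

```python
import torch

def rescale_partial_output(o_acc: torch.Tensor, m_i: torch.Tensor, m_ij: torch.Tensor) -> torch.Tensor:
    """Online-softmax correction: rescale the partial output accumulator when
    the running row maximum moves from m_i to m_ij (both in the log2 domain)."""
    alpha = torch.exp2(m_i - m_ij)   # correction factor, always <= 1
    return o_acc * alpha[:, None]    # same role as _mul_f32x2(o, alpha[:, None]) in the kernel
```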