@@ -76,6 +76,7 @@ def get_heuristic_config(M, N, K, num_sms=148):
7676 "EPILOGUE_SUBTILE" : 1 ,
7777 "NUM_CTAS" : 2 ,
7878 "SPLIT_K" : 1 ,
79+ "INTERLEAVE_EPILOGUE" : 1 ,
7980 "ctas_per_cga" : (2 , 1 , 1 ),
8081 "pre_hook" : matmul_tma_set_block_size_hook ,
8182 }
@@ -95,6 +96,7 @@ def get_heuristic_config(M, N, K, num_sms=148):
9596 "EPILOGUE_SUBTILE" : 4 ,
9697 "NUM_CTAS" : 2 ,
9798 "SPLIT_K" : 1 ,
99+ "INTERLEAVE_EPILOGUE" : 0 ,
98100 "ctas_per_cga" : (2 , 1 , 1 ),
99101 "pre_hook" : matmul_tma_set_block_size_hook ,
100102 }
@@ -110,6 +112,7 @@ def get_heuristic_config(M, N, K, num_sms=148):
110112 "EPILOGUE_SUBTILE" : 4 ,
111113 "NUM_CTAS" : 2 ,
112114 "SPLIT_K" : 1 ,
115+ "INTERLEAVE_EPILOGUE" : 1 ,
113116 "ctas_per_cga" : (2 , 1 , 1 ),
114117 "pre_hook" : matmul_tma_set_block_size_hook ,
115118 }
@@ -147,6 +150,7 @@ def get_heuristic_config(M, N, K, num_sms=148):
147150 "EPILOGUE_SUBTILE" : 8 ,
148151 "NUM_CTAS" : 1 ,
149152 "SPLIT_K" : split_k ,
153+ "INTERLEAVE_EPILOGUE" : 0 ,
150154 "ctas_per_cga" : None ,
151155 "pre_hook" : matmul_tma_set_block_size_hook ,
152156 }
@@ -163,6 +167,7 @@ def get_heuristic_config(M, N, K, num_sms=148):
163167 "EPILOGUE_SUBTILE" : 1 ,
164168 "NUM_CTAS" : 1 ,
165169 "SPLIT_K" : split_k ,
170+ "INTERLEAVE_EPILOGUE" : 0 ,
166171 "ctas_per_cga" : None ,
167172 "pre_hook" : matmul_tma_set_block_size_hook ,
168173 }
@@ -180,6 +185,7 @@ def get_heuristic_config(M, N, K, num_sms=148):
180185 "EPILOGUE_SUBTILE" : 4 ,
181186 "NUM_CTAS" : 1 ,
182187 "SPLIT_K" : 1 ,
188+ "INTERLEAVE_EPILOGUE" : 1 ,
183189 "ctas_per_cga" : None ,
184190 "pre_hook" : matmul_tma_set_block_size_hook ,
185191 }
@@ -311,6 +317,7 @@ def compute_wave_score(bm, bn, num_ctas, split_k=1):
311317 "EPILOGUE_SUBTILE" : epilogue_subtile ,
312318 "NUM_CTAS" : num_ctas ,
313319 "SPLIT_K" : split_k ,
320+ "INTERLEAVE_EPILOGUE" : 0 ,
314321 "ctas_per_cga" : (num_ctas , 1 , 1 ) if num_ctas > 1 else None ,
315322 "pre_hook" : matmul_tma_set_block_size_hook ,
316323 }
@@ -359,6 +366,7 @@ def get_cuda_autotune_config():
359366 "EPILOGUE_SUBTILE" : subtile ,
360367 "NUM_CTAS" : num_ctas ,
361368 "SPLIT_K" : split_k ,
369+ "INTERLEAVE_EPILOGUE" : interleave ,
362370 },
363371 num_warps = 4 ,
364372 num_stages = 1 ,
@@ -374,6 +382,7 @@ def get_cuda_autotune_config():
374382 for subtile in [1 , 2 , 4 , 8 ]
375383 for num_ctas in [1 , 2 ]
376384 for split_k in [1 , 4 ]
385+ for interleave in [0 , 1 ]
377386 for g in [1 , 8 , 64 ]
378387 ]
379388
@@ -428,6 +437,7 @@ def preprocess_configs(configs, named_args, **kwargs):
428437 NUM_MMA_GROUPS = conf .kwargs ["NUM_MMA_GROUPS" ]
429438 SPLIT_K = conf .kwargs .get ("SPLIT_K" , 1 )
430439 EPILOGUE_SUBTILE = conf .kwargs ["EPILOGUE_SUBTILE" ]
440+ INTERLEAVE_EPILOGUE = conf .kwargs .get ("INTERLEAVE_EPILOGUE" , 0 )
431441
432442 # Filter out invalid config that causes wrong hardware MMA
433443 if BLOCK_M // NUM_MMA_GROUPS > 128 :
@@ -437,6 +447,10 @@ def preprocess_configs(configs, named_args, **kwargs):
437447 if BLOCK_N % EPILOGUE_SUBTILE != 0 :
438448 continue
439449
450+ # Interleaved epilogue requires NUM_MMA_GROUPS == 2 and SPLIT_K == 1
451+ if INTERLEAVE_EPILOGUE and (NUM_MMA_GROUPS != 2 or SPLIT_K != 1 ):
452+ continue
453+
440454 num_tiles_m = math .ceil (M / BLOCK_M )
441455 num_tiles_n = math .ceil (N / BLOCK_N )
442456 num_mn_tiles = num_tiles_m * num_tiles_n
@@ -527,6 +541,7 @@ def _group_key(c):
527541 c .kwargs ["EPILOGUE_SUBTILE" ],
528542 c .kwargs ["NUM_CTAS" ],
529543 c .kwargs .get ("SPLIT_K" , 1 ),
544+ c .kwargs .get ("INTERLEAVE_EPILOGUE" , 0 ),
530545 )
531546
532547 def _val (c ):
@@ -600,6 +615,7 @@ def _process_tile_epilogue_inner(
600615 NUM_MMA_GROUPS ,
601616 NUM_TMEM_BUFFERS ,
602617 SPLIT_K ,
618+ INTERLEAVE_EPILOGUE ,
603619 c_desc ,
604620 c_smem_buffers ,
605621 tmem_buffers ,
@@ -616,38 +632,96 @@ def _process_tile_epilogue_inner(
616632
617633 slice_size : tl .constexpr = BLOCK_SIZE_N // EPILOGUE_SUBTILE
618634
619- for group_id in tl .static_range (NUM_MMA_GROUPS ):
620- # Wait for TMEM to be filled
621- buf_idx = group_id * NUM_TMEM_BUFFERS + cur_tmem_buf
622-
623- tlx .barrier_wait (tmem_full_bars [buf_idx ], tmem_read_phase )
624-
625- # load the result from TMEM to registers
626- acc_tmem = tmem_buffers [buf_idx ]
627- offs_am = pid_m * BLOCK_SIZE_M + group_id * BLOCK_M_SPLIT
628- for slice_id in tl .static_range (EPILOGUE_SUBTILE ):
629- acc_tmem_subslice = tlx .local_slice (
630- acc_tmem ,
631- [0 , slice_id * slice_size ],
632- [BLOCK_M_SPLIT , slice_size ],
633- )
634- result = tlx .local_load (acc_tmem_subslice )
635- # Signal MMA consumer after each slice
636- tlx .barrier_arrive (tmem_empty_bars [buf_idx ], 1 )
635+ if INTERLEAVE_EPILOGUE :
636+ # Interleaved TMA stores across two groups to improve memory throughput.
637+ # Pattern: wait g0, store g0s0, wait g1, store g1s0,
638+        # then alternate g0/g1 for the remaining slices (1..EPILOGUE_SUBTILE-1).
639+ buf_idx_0 = 0 * NUM_TMEM_BUFFERS + cur_tmem_buf
640+ buf_idx_1 = 1 * NUM_TMEM_BUFFERS + cur_tmem_buf
641+ acc_tmem_0 = tmem_buffers [buf_idx_0 ]
642+ acc_tmem_1 = tmem_buffers [buf_idx_1 ]
643+ offs_am_0 = pid_m * BLOCK_SIZE_M + 0 * BLOCK_M_SPLIT
644+ offs_am_1 = pid_m * BLOCK_SIZE_M + 1 * BLOCK_M_SPLIT
645+
646+ # --- Wait for group 0, store group 0 slice 0 ---
647+ tlx .barrier_wait (tmem_full_bars [buf_idx_0 ], tmem_read_phase )
648+ acc_sub = tlx .local_slice (acc_tmem_0 , [0 , 0 * slice_size ], [BLOCK_M_SPLIT , slice_size ])
649+ result = tlx .local_load (acc_sub )
650+ tlx .barrier_arrive (tmem_empty_bars [buf_idx_0 ], 1 )
651+ c = result .to (tlx .dtype_of (c_desc ))
652+ c_smem = c_smem_buffers [0 ]
653+ tlx .local_store (c_smem , c )
654+ tlx .fence_async_shared ()
655+ tlx .async_descriptor_store (c_desc , c_smem , [offs_am_0 , offs_bn + 0 * slice_size ])
656+
657+ # --- Wait for group 1, store group 1 slice 0 ---
658+ tlx .barrier_wait (tmem_full_bars [buf_idx_1 ], tmem_read_phase )
659+ acc_sub = tlx .local_slice (acc_tmem_1 , [0 , 0 * slice_size ], [BLOCK_M_SPLIT , slice_size ])
660+ result = tlx .local_load (acc_sub )
661+ tlx .barrier_arrive (tmem_empty_bars [buf_idx_1 ], 1 )
662+ c = result .to (tlx .dtype_of (c_desc ))
663+ c_smem = c_smem_buffers [1 ]
664+ tlx .local_store (c_smem , c )
665+ tlx .fence_async_shared ()
666+ tlx .async_descriptor_store (c_desc , c_smem , [offs_am_1 , offs_bn + 0 * slice_size ])
667+
668+ # --- Slices 1-3: alternate group 0, group 1 ---
669+ for slice_id in tl .static_range (1 , EPILOGUE_SUBTILE ):
670+ # Group 0
671+ acc_sub = tlx .local_slice (acc_tmem_0 , [0 , slice_id * slice_size ], [BLOCK_M_SPLIT , slice_size ])
672+ result = tlx .local_load (acc_sub )
673+ tlx .barrier_arrive (tmem_empty_bars [buf_idx_0 ], 1 )
637674 c = result .to (tlx .dtype_of (c_desc ))
638- if SPLIT_K == 1 :
639- # Store to SMEM then use async TMA store to global
640- c_smem = c_smem_buffers [group_id ]
641- tlx .async_descriptor_store_wait (0 )
642- tlx .local_store (c_smem , c )
643- tlx .fence_async_shared ()
644- tlx .async_descriptor_store (c_desc , c_smem , [offs_am , offs_bn + slice_id * slice_size ])
645- else :
646- c_desc .store (
647- [offs_am , offs_bn + slice_id * slice_size ],
648- c ,
649- store_reduce = "add" ,
675+ c_smem = c_smem_buffers [0 ]
676+ tlx .async_descriptor_store_wait (1 )
677+ tlx .local_store (c_smem , c )
678+ tlx .fence_async_shared ()
679+ tlx .async_descriptor_store (c_desc , c_smem , [offs_am_0 , offs_bn + slice_id * slice_size ])
680+
681+ # Group 1
682+ acc_sub = tlx .local_slice (acc_tmem_1 , [0 , slice_id * slice_size ], [BLOCK_M_SPLIT , slice_size ])
683+ result = tlx .local_load (acc_sub )
684+ tlx .barrier_arrive (tmem_empty_bars [buf_idx_1 ], 1 )
685+ c = result .to (tlx .dtype_of (c_desc ))
686+ c_smem = c_smem_buffers [1 ]
687+ tlx .async_descriptor_store_wait (1 )
688+ tlx .local_store (c_smem , c )
689+ tlx .fence_async_shared ()
690+ tlx .async_descriptor_store (c_desc , c_smem , [offs_am_1 , offs_bn + slice_id * slice_size ])
691+ else :
692+ for group_id in tl .static_range (NUM_MMA_GROUPS ):
693+ # Wait for TMEM to be filled
694+ buf_idx = group_id * NUM_TMEM_BUFFERS + cur_tmem_buf
695+
696+ tlx .barrier_wait (tmem_full_bars [buf_idx ], tmem_read_phase )
697+
698+ # load the result from TMEM to registers
699+ acc_tmem = tmem_buffers [buf_idx ]
700+ offs_am = pid_m * BLOCK_SIZE_M + group_id * BLOCK_M_SPLIT
701+ for slice_id in tl .static_range (EPILOGUE_SUBTILE ):
702+ acc_tmem_subslice = tlx .local_slice (
703+ acc_tmem ,
704+ [0 , slice_id * slice_size ],
705+ [BLOCK_M_SPLIT , slice_size ],
650706 )
707+ result = tlx .local_load (acc_tmem_subslice )
708+ # Signal MMA consumer after each slice
709+ tlx .barrier_arrive (tmem_empty_bars [buf_idx ], 1 )
710+ c = result .to (tlx .dtype_of (c_desc ))
711+ if SPLIT_K == 1 :
712+ # Store to SMEM then use async TMA store to global
713+ c_smem = c_smem_buffers [group_id ]
714+ tlx .async_descriptor_store_wait (0 )
715+ tlx .local_store (c_smem , c )
716+ tlx .fence_async_shared ()
717+ tlx .async_descriptor_store (c_desc , c_smem , [offs_am , offs_bn + slice_id * slice_size ])
718+ else :
719+ c_desc .store (
720+ [offs_am , offs_bn + slice_id * slice_size ],
721+ c ,
722+ store_reduce = "add" ,
723+ )
724+
651725 # Wait for all TMA stores to complete
652726 tlx .async_descriptor_store_wait (0 )
653727
@@ -854,6 +928,7 @@ def matmul_kernel_tma_ws_blackwell(
854928 EPILOGUE_SUBTILE : tl .constexpr ,
855929 NUM_CTAS : tl .constexpr ,
856930 SPLIT_K : tl .constexpr ,
931+ INTERLEAVE_EPILOGUE : tl .constexpr ,
857932 NUM_SMS : tl .constexpr ,
858933):
859934 # allocate NUM_SMEM_BUFFERS buffers
@@ -943,6 +1018,7 @@ def matmul_kernel_tma_ws_blackwell(
9431018 NUM_MMA_GROUPS = NUM_MMA_GROUPS ,
9441019 NUM_TMEM_BUFFERS = NUM_TMEM_BUFFERS ,
9451020 SPLIT_K = SPLIT_K ,
1021+ INTERLEAVE_EPILOGUE = INTERLEAVE_EPILOGUE ,
9461022 c_desc = c_desc ,
9471023 c_smem_buffers = c_smem_buffers ,
9481024 tmem_buffers = tmem_buffers ,
0 commit comments