@@ -150,7 +150,7 @@ def get_heuristic_config(M, N, K, num_sms=148):
150150 "EPILOGUE_SUBTILE" : 8 ,
151151 "NUM_CTAS" : 1 ,
152152 "SPLIT_K" : split_k ,
153- "INTERLEAVE_EPILOGUE" : 0 ,
153+ "INTERLEAVE_EPILOGUE" : 1 ,
154154 "ctas_per_cga" : None ,
155155 "pre_hook" : matmul_tma_set_block_size_hook ,
156156 }
@@ -167,7 +167,7 @@ def get_heuristic_config(M, N, K, num_sms=148):
167167 "EPILOGUE_SUBTILE" : 1 ,
168168 "NUM_CTAS" : 1 ,
169169 "SPLIT_K" : split_k ,
170- "INTERLEAVE_EPILOGUE" : 0 ,
170+ "INTERLEAVE_EPILOGUE" : 1 ,
171171 "ctas_per_cga" : None ,
172172 "pre_hook" : matmul_tma_set_block_size_hook ,
173173 }
@@ -447,8 +447,8 @@ def preprocess_configs(configs, named_args, **kwargs):
447447 if BLOCK_N % EPILOGUE_SUBTILE != 0 :
448448 continue
449449
450- # Interleaved epilogue requires NUM_MMA_GROUPS == 2 and SPLIT_K == 1
451- if INTERLEAVE_EPILOGUE and ( NUM_MMA_GROUPS != 2 or SPLIT_K != 1 ) :
450+ # Interleaved epilogue requires NUM_MMA_GROUPS == 2
451+ if INTERLEAVE_EPILOGUE and NUM_MMA_GROUPS != 2 :
452452 continue
453453
454454 num_tiles_m = math .ceil (M / BLOCK_M )
@@ -631,6 +631,7 @@ def _process_tile_epilogue_inner(
631631 BLOCK_M_SPLIT : tl .constexpr = BLOCK_SIZE_M // NUM_MMA_GROUPS
632632
633633 slice_size : tl .constexpr = BLOCK_SIZE_N // EPILOGUE_SUBTILE
634+ STORE_REDUCE : tl .constexpr = "add" if SPLIT_K > 1 else ""
634635
635636 if INTERLEAVE_EPILOGUE :
636637 # Interleaved TMA stores across two groups to improve memory throughput.
@@ -652,7 +653,13 @@ def _process_tile_epilogue_inner(
652653 c_smem = c_smem_buffers [0 ]
653654 tlx .local_store (c_smem , c )
654655 tlx .fence_async_shared ()
655- tlx .async_descriptor_store (c_desc , c_smem , [offs_am_0 , offs_bn + 0 * slice_size ])
656+ tlx .async_descriptor_store (
657+ c_desc ,
658+ c_smem ,
659+ [offs_am_0 , offs_bn + 0 * slice_size ],
660+ store_reduce = STORE_REDUCE ,
661+ eviction_policy = "evict_first" ,
662+ )
656663
657664 # --- Wait for group 1, store group 1 slice 0 ---
658665 tlx .barrier_wait (tmem_full_bars [buf_idx_1 ], tmem_read_phase )
@@ -663,7 +670,13 @@ def _process_tile_epilogue_inner(
663670 c_smem = c_smem_buffers [1 ]
664671 tlx .local_store (c_smem , c )
665672 tlx .fence_async_shared ()
666- tlx .async_descriptor_store (c_desc , c_smem , [offs_am_1 , offs_bn + 0 * slice_size ])
673+ tlx .async_descriptor_store (
674+ c_desc ,
675+ c_smem ,
676+ [offs_am_1 , offs_bn + 0 * slice_size ],
677+ store_reduce = STORE_REDUCE ,
678+ eviction_policy = "evict_first" ,
679+ )
667680
668681 # --- Remaining slices (1 .. EPILOGUE_SUBTILE - 1): alternate group 0, group 1 ---
669682 for slice_id in tl .static_range (1 , EPILOGUE_SUBTILE ):
@@ -676,7 +689,13 @@ def _process_tile_epilogue_inner(
676689 tlx .async_descriptor_store_wait (1 )
677690 tlx .local_store (c_smem , c )
678691 tlx .fence ("async_shared" )
679- tlx .async_descriptor_store (c_desc , c_smem , [offs_am_0 , offs_bn + slice_id * slice_size ])
692+ tlx .async_descriptor_store (
693+ c_desc ,
694+ c_smem ,
695+ [offs_am_0 , offs_bn + slice_id * slice_size ],
696+ store_reduce = STORE_REDUCE ,
697+ eviction_policy = "evict_first" ,
698+ )
680699
681700 # Group 1
682701 acc_sub = tlx .local_slice (acc_tmem_1 , [0 , slice_id * slice_size ], [BLOCK_M_SPLIT , slice_size ])
@@ -687,7 +706,13 @@ def _process_tile_epilogue_inner(
687706 tlx .async_descriptor_store_wait (1 )
688707 tlx .local_store (c_smem , c )
689708 tlx .fence ("async_shared" )
690- tlx .async_descriptor_store (c_desc , c_smem , [offs_am_1 , offs_bn + slice_id * slice_size ])
709+ tlx .async_descriptor_store (
710+ c_desc ,
711+ c_smem ,
712+ [offs_am_1 , offs_bn + slice_id * slice_size ],
713+ store_reduce = STORE_REDUCE ,
714+ eviction_policy = "evict_first" ,
715+ )
691716 else :
692717 for group_id in tl .static_range (NUM_MMA_GROUPS ):
693718 # Wait for TMEM to be filled
@@ -708,19 +733,17 @@ def _process_tile_epilogue_inner(
708733 # Signal MMA consumer after each slice
709734 tlx .barrier_arrive (tmem_empty_bars [buf_idx ], 1 )
710735 c = result .to (tlx .dtype_of (c_desc ))
711- if SPLIT_K == 1 :
712- # Store to SMEM then use async TMA store to global
713- c_smem = c_smem_buffers [group_id ]
714- tlx .async_descriptor_store_wait (0 )
715- tlx .local_store (c_smem , c )
716- tlx .fence_async_shared ()
717- tlx .async_descriptor_store (c_desc , c_smem , [offs_am , offs_bn + slice_id * slice_size ])
718- else :
719- c_desc .store (
720- [offs_am , offs_bn + slice_id * slice_size ],
721- c ,
722- store_reduce = "add" ,
723- )
736+ c_smem = c_smem_buffers [group_id ]
737+ tlx .async_descriptor_store_wait (0 )
738+ tlx .local_store (c_smem , c )
739+ tlx .fence_async_shared ()
740+ tlx .async_descriptor_store (
741+ c_desc ,
742+ c_smem ,
743+ [offs_am , offs_bn + slice_id * slice_size ],
744+ store_reduce = STORE_REDUCE ,
745+ eviction_policy = "evict_first" ,
746+ )
724747
725748 # Wait for all TMA stores to complete
726749 tlx .async_descriptor_store_wait (0 )
@@ -881,13 +904,15 @@ def _process_tile_producer_inner(
881904 tlx .barrier_wait (A_smem_empty_bars [a_buf ], phase ^ 1 )
882905 offs_am = pid_m * BLOCK_SIZE_M
883906 tlx .barrier_expect_bytes (A_smem_full_bars [a_buf ], dsize * BLOCK_M_SPLIT * BLOCK_SIZE_K )
884- tlx .async_descriptor_load (a_desc , buffers_A [a_buf ], [offs_am , offs_k ], A_smem_full_bars [a_buf ])
907+ tlx .async_descriptor_load (a_desc , buffers_A [a_buf ], [offs_am , offs_k ], A_smem_full_bars [a_buf ],
908+ eviction_policy = "evict_last" )
885909
886910 # Load B once per K iteration (shared across all subtiles)
887911 last_a_buf = (NUM_MMA_GROUPS - 1 ) * NUM_SMEM_BUFFERS + buf
888912 tlx .barrier_wait (A_smem_empty_bars [last_a_buf ], phase ^ 1 )
889913 tlx .barrier_expect_bytes (B_smem_full_bars [buf ], expected_bytes )
890- tlx .async_descriptor_load (b_desc , buffers_B [buf ], [offs_k , offs_bn ], B_smem_full_bars [buf ])
914+ tlx .async_descriptor_load (b_desc , buffers_B [buf ], [offs_k , offs_bn ], B_smem_full_bars [buf ],
915+ eviction_policy = "evict_last" )
891916
892917 # Load all remaining A subtiles for this K iteration
893918 for group_id in tl .static_range (1 , NUM_MMA_GROUPS ):
@@ -898,7 +923,8 @@ def _process_tile_producer_inner(
898923 offs_am2 = offs_am + group_id * BLOCK_M_SPLIT
899924
900925 tlx .barrier_expect_bytes (A_smem_full_bars [a_buf ], dsize * BLOCK_M_SPLIT * BLOCK_SIZE_K )
901- tlx .async_descriptor_load (a_desc , buffers_A [a_buf ], [offs_am2 , offs_k ], A_smem_full_bars [a_buf ])
926+ tlx .async_descriptor_load (a_desc , buffers_A [a_buf ], [offs_am2 , offs_k ], A_smem_full_bars [a_buf ],
927+ eviction_policy = "evict_last" )
902928
903929 smem_accum_cnt += 1
904930
0 commit comments