@@ -2625,6 +2625,60 @@ def check_contigous_16B_alignment(dtype, is_mode0_major, tensor_shape):
 
         return is_valid
 
+    @staticmethod
+    def compute_barrier_flag_size(
+        m: int,
+        n: int,
+        l: int,
+        mma_tiler_mn: Tuple[int, int],
+        cluster_shape_mn: Tuple[int, int],
+        sm_count: int,
+    ) -> int:
+        """
+        Compute the required size for the barrier flag tensors used in all-reduce synchronization.
+
+        The barrier flags are used for:
+        1. Per-tile synchronization during the all-reduce phase
+        2. The final inter-GPU synchronization barrier
+
+        :param m: Number of rows in the output matrix
+        :type m: int
+        :param n: Number of columns in the output matrix
+        :type n: int
+        :param l: Batch size
+        :type l: int
+        :param mma_tiler_mn: Shape of the MMA tiler (M, N)
+        :type mma_tiler_mn: Tuple[int, int]
+        :param cluster_shape_mn: Cluster dimensions (M, N)
+        :type cluster_shape_mn: Tuple[int, int]
+        :param sm_count: Number of SMs available
+        :type sm_count: int
+
+        :return: Total number of barrier flags needed
+        :rtype: int
+        """
+        # CTA tile shape: an MMA tile with M == 256 uses 2-CTA instructions, so each CTA covers half of M
+        use_2cta_instrs = mma_tiler_mn[0] == 256
+        cta_tile_shape_m = mma_tiler_mn[0] // (2 if use_2cta_instrs else 1)
+        cta_tile_shape_n = mma_tiler_mn[1]
+
+        # Number of output tiles per batch (ceiling division in each dimension)
+        num_tiles_m = (m + cta_tile_shape_m - 1) // cta_tile_shape_m
+        num_tiles_n = (n + cta_tile_shape_n - 1) // cta_tile_shape_n
+        num_tiles_per_batch = num_tiles_m * num_tiles_n
+
+        # Each output tile is processed by one cluster, so each tile needs one flag per CTA
+        cluster_size = cluster_shape_mn[0] * cluster_shape_mn[1]
+        num_ctas_per_tile = cluster_size
+
+        # Per-tile flags across all batches and all CTAs
+        num_tiles = num_tiles_per_batch * l * num_ctas_per_tile
+
+        # Extra space for the final barrier (one flag per SM)
+        total_barrier_size = num_tiles + sm_count
+
+        return total_barrier_size
+
     @staticmethod
     def can_implement(
         ab_dtype: Type[cutlass.Numeric],
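
As a sanity check on the sizing formula, here is a worked instance with illustrative values; the problem shape, tiler, cluster, and SM count below are assumptions chosen for the example, not values taken from this PR:

    # Hypothetical config: 8192x8192 output, batch 1, (256, 256) MMA tiler,
    # (2, 1) cluster, 148 SMs.
    size = Sm100BlockScaledPersistentDenseGemmKernel.compute_barrier_flag_size(
        m=8192, n=8192, l=1,
        mma_tiler_mn=(256, 256),
        cluster_shape_mn=(2, 1),
        sm_count=148,
    )
    # M == 256 selects 2-CTA instructions, so the CTA tile is (128, 256):
    # ceil(8192/128) * ceil(8192/256) = 64 * 32 = 2048 tiles per batch,
    # * 1 batch * 2 CTAs per cluster = 4096 per-tile flags,
    # + 148 final-barrier flags = 4244.
    assert size == 4244
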
@@ -2898,6 +2952,18 @@ def __call__(
         barrier_flag_mc_ptr: Optional[cute.Pointer],
         current_stream: cuda.CUstream,
     ):
+        if cutlass.const_expr(self._all_reduce != "none"):
+            barrier_flag_size = Sm100BlockScaledPersistentDenseGemmKernel.compute_barrier_flag_size(
+                self._m,
+                self._n,
+                self._l,
+                self._mma_tiler_mn,
+                self._cluster_shape_mn,
+                self._max_active_clusters,
+            )
+        else:
+            barrier_flag_size = 1  # Dummy size when not used
+
         a_tensor = cute.make_tensor(
             a_ptr,
             layout=cute.make_ordered_layout(
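
Because `self._all_reduce` is fixed when the kernel is compiled, wrapping the test in `cutlass.const_expr` lets the branch be resolved at trace time instead of emitting dynamic control flow; the `barrier_flag_size = 1` fallback only exists so the layout construction below always receives a concrete extent. A rough sketch of the host-side equivalent (assuming `_all_reduce` is a plain string attribute):

    def resolved_barrier_flag_size(kernel) -> int:
        # Mirrors the traced branch: real size with all-reduce, dummy size without.
        if kernel._all_reduce != "none":
            return Sm100BlockScaledPersistentDenseGemmKernel.compute_barrier_flag_size(
                kernel._m, kernel._n, kernel._l,
                kernel._mma_tiler_mn, kernel._cluster_shape_mn,
                kernel._max_active_clusters,
            )
        return 1  # dummy size when all-reduce is disabled
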
@@ -2931,11 +2997,11 @@ def __call__(
         # on
         barrier_flag_tensor = cute.make_tensor(
             barrier_flag_ptr,
-            layout=cute.make_ordered_layout((404,), order=(0,)),
+            layout=cute.make_ordered_layout((barrier_flag_size,), order=(0,)),
         ) if barrier_flag_ptr is not None else None
         barrier_flag_mc_tensor = cute.make_tensor(
             barrier_flag_mc_ptr,
-            layout=cute.make_ordered_layout((404,), order=(0,)),
+            layout=cute.make_ordered_layout((barrier_flag_size,), order=(0,)),
         ) if barrier_flag_mc_ptr is not None else None
 
         # calculate sf_tensor shape and order
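
With the hard-coded (404,) extents replaced, the host-side allocation has to agree with the layout the kernel builds. A minimal sketch of a torch-based caller, assuming 32-bit flags and eliding the multicast (mc) buffer setup, since pointer construction depends on the caller's CuTe DSL bindings:

    import torch

    # Same helper the kernel uses, so buffer and layout extents cannot drift apart.
    num_flags = Sm100BlockScaledPersistentDenseGemmKernel.compute_barrier_flag_size(
        m, n, l, mma_tiler_mn, cluster_shape_mn, sm_count
    )
    # Flags must start zeroed so the first arrive/wait round is well defined.
    barrier_flags = torch.zeros(num_flags, dtype=torch.int32, device="cuda")
    # barrier_flag_ptr (and its multicast twin) are then derived from this buffer.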