diff --git a/flashinfer/cute_dsl/blockscaled_gemm.py b/flashinfer/cute_dsl/blockscaled_gemm.py index d69eda2743..adad756b19 100644 --- a/flashinfer/cute_dsl/blockscaled_gemm.py +++ b/flashinfer/cute_dsl/blockscaled_gemm.py @@ -26,7 +26,8 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -from typing import Optional, Tuple, Type, Union +import functools +from typing import Callable, Optional, Tuple, Type, Union, List import cuda.bindings.driver as cuda import cutlass @@ -36,28 +37,35 @@ import cutlass.utils as utils import cutlass.utils.blackwell_helpers as sm100_utils import cutlass.utils.blockscaled_layout as blockscaled_utils -import torch -import functools +import cutlass.utils.distributed_helpers as distributed_helpers from cutlass._mlir import ir from cutlass.cute.nvgpu import cpasync, tcgen05 from cutlass.cute.runtime import from_dlpack - from cutlass.cutlass_dsl import ( - Int32, - Int64, - Uint8, - Uint64, T, Integer, dsl_user_op, extract_mlir_values, new_from_mlir_values, ) +from cutlass.cute.typing import ( + Int32, + Int64, + Uint8, + Uint64, + Float16, + BFloat16, + Float32, + Float8E4M3FN, + Float8E5M2, + Tensor, +) from cutlass._mlir.dialects import llvm -from flashinfer.utils import get_compute_capability from cutlass.utils.static_persistent_tile_scheduler import WorkTileInfo +import torch + +from flashinfer.utils import get_compute_capability from .utils import get_cutlass_dtype, cutlass_to_torch_dtype, get_num_sm, make_ptr -from typing import Callable, List sizeof_i32 = 4 @@ -395,6 +403,7 @@ def num_tiles_executed(self) -> Int32: - Matrix A is MxKxL, L is batch dimension, A can be row-major("K") or column-major("M") for MXF8 input type and can only be row-major("K") for MXF4/NVF4 input type - Matrix B is NxKxL, L is batch dimension, B can be row-major("N") or column-major("K") for MXF8 input type and can only be row-major("K") for MXF4/NVF4 input type - Matrix C is MxNxL, L is batch dimension, C can be row-major("N") or column-major("M") +- Matrix C_mc is a multicast C matrix that changes can be broadcasted to all GPUs by multimem instructions (needed only for all-reduce epilogue) - Matrix SFA layout is filled internally according to A shape and BlockScaledBasicChunk, which has M×ceil_div(K, sf_vec_size)×L elements respectively - Matrix SFB layout is filled internally according to B shape and BlockScaledBasicChunk, which has N×ceil_div(K, sf_vec_size)×L elements respectively @@ -404,6 +413,7 @@ def num_tiles_executed(self) -> Int32: - Implements TMA multicast with cluster to reduce L2 memory traffic - Support persistent tile scheduling to better overlap memory load/store with mma between tiles - Support warp specialization to avoid explicit pipelining between mainloop load and mma + - Support all-reduce epilogue with multimem instructions to distribute the workload to all GPUs This GEMM works as follows: 1. DMA warp: Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using TMA operations. @@ -415,6 +425,13 @@ def num_tiles_executed(self) -> Int32: - Type convert C matrix to output type. - Optionally store C matrix from registers (RMEM) to shared memory (SMEM) to global memory (GMEM) with TMA operations, or directly store C matrix from registers (RMEM) to global memory (GMEM) without TMA operations. +4. All reduce epilogue: + - Load and reduce the 128bit data from all ranks by multimem instructions. 
+ - Broadcast the reduced data to all ranks by multimem instructions. + - current implementation only supports two_shot all-reduce which means each rank only computes a portion of + the output tensor and broadcast the result to all ranks. + - the all-reduce epilogue is only supported when use_tma_store is True. + - the all-reduce epilogue is only supported when c_dtype is Float16, Float32, BFloat16, Float8E4M3FN, Float8E5M2. SM100 tcgen05.mma.kind.block_scale instructions operate as follows: - Read matrix A from SMEM @@ -434,6 +451,16 @@ def num_tiles_executed(self) -> Int32: --mma_tiler_mn 256,128 --cluster_shape_mn 2,1 \ --mnkl 8192,8192,1024,1 +to run with all reduce epilogue (only on multi-GPU with NVLink): + +.. code-block:: bash + + torchrun --nproc-per-node 8 examples/blackwell/dense_blockscaled_gemm_persistent.py \ + --ab_dtype Float4E2M1FN --sf_dtype Float8E8M0FNU --sf_vec_size 16 \ + --c_dtype Float16 \ + --mma_tiler_mn 256,128 --cluster_shape_mn 2,1 \ + --mnkl 8192,8192,1024,1 --all_reduce two_shot + To collect performance with NCU profiler: .. code-block:: bash @@ -443,7 +470,8 @@ def num_tiles_executed(self) -> Int32: --c_dtype Float16 \ --mma_tiler_mn 256,128 --cluster_shape_mn 2,1 \ --mnkl 8192,8192,1024,1 \ - --warmup_iterations 1 --iterations 10 --skip_ref_check + --warmup_iterations 1 --iterations 10 --skip_ref_check \ + --all_reduce two_shot Constraints: @@ -456,6 +484,7 @@ def num_tiles_executed(self) -> Int32: * Cluster shape M must be multiple of 2 if Mma tiler M is 256(use_2cta_instrs) * The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned, i.e, number of elements is a multiple of 16 and 32 for Float8 and Float4, respectively. +* when all_reduce is not "none", M and N must be multiple of 128, world_size must be 8 """ @@ -469,6 +498,8 @@ class Sm100BlockScaledPersistentDenseGemmKernel: :type mma_tiler_mn: Tuple[int, int] :param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing :type cluster_shape_mn: Tuple[int, int] + :param all_reduce: All-reduce mode, can be "none", "two_shot" + :type all_reduce: str :note: In current version, A and B tensor must have the same data type - i.e., Float8E4M3FN for A and Float8E5M2 for B is not supported @@ -507,6 +538,7 @@ def __init__( mma_tiler_mn: Tuple[int, int], cluster_shape_mn: Tuple[int, int], sm_version: str, + all_reduce="none", ): """Initializes the configuration for a Blackwell dense GEMM kernel. @@ -542,6 +574,8 @@ def __init__( tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE ) + self.all_reduce = all_reduce + self.occupancy = 1 # Set specialized warp ids self.epilog_warp_id = ( @@ -552,17 +586,34 @@ def __init__( ) self.mma_warp_id = 4 self.tma_warp_id = 5 + self.all_reduce_warp_id: Tuple[int, ...] 
= () + self.all_reduce = "none" + if all_reduce != "none": + self.all_reduce = all_reduce + self.all_reduce_warp_id = (6, 7, 8, 9) self.threads_per_cta = 32 * len( - (self.mma_warp_id, self.tma_warp_id, *self.epilog_warp_id) + ( + self.mma_warp_id, + self.tma_warp_id, + *self.epilog_warp_id, + *self.all_reduce_warp_id, + ) ) # Set barrier id for cta sync, epilogue sync and tmem ptr sync self.cta_sync_bar_id = 0 self.epilog_sync_bar_id = 1 self.tmem_ptr_sync_bar_id = 2 + self.all_reduce_sync_bar_id = 3 self.smem_capacity = utils.get_smem_capacity_in_bytes(sm_version) SM100_TMEM_CAPACITY_COLUMNS = 512 self.num_tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS + self.num_ranks = 1 + self.rank_id = 0 + if all_reduce != "none": + self.num_ranks = torch.distributed.get_world_size() + self.rank_id = torch.distributed.get_rank() + def _setup_attributes(self): """Set up configurations that are dependent on GEMM inputs @@ -718,6 +769,9 @@ def __call__( alpha_tensor: Optional[cute.Tensor], max_active_clusters: cutlass.Constexpr, stream: cuda.CUstream, + c_mc: cute.Tensor = None, + barrier_flag: cute.Tensor = None, + barrier_flag_mc: cute.Tensor = None, ): """Execute the GEMM operation in steps: - Setup static attributes before smem/grid/tma computation @@ -743,6 +797,12 @@ def __call__( :param stream: CUDA stream for asynchronous execution :type stream: cuda.CUstream :param alpha_tensor: Optional 1D tensor of shape (l,) containing per-batch scaling factors. + :param c_mc: Output symmetric tensor C_mc, any write or read to a multicast tensor will be broadcasted to all GPUs + :type c_mc: cute.Tensor + :param barrier_flag: Barrier flag to sync between peers + :type barrier_flag: cute.Tensor + :param barrier_flag_mc: Multicast barrier flag to sync between peers + :type barrier_flag_mc: cute.Tensor :type alpha_tensor: cute.Tensor :raises TypeError: If input data types are incompatible with the MMA instruction. """ @@ -959,6 +1019,9 @@ class SharedStorage: self.c_smem_layout_staged, self.epi_tile, self.tile_sched_params, + c_mc, + barrier_flag, + barrier_flag_mc, ).launch( grid=grid, block=[self.threads_per_cta, 1, 1], @@ -994,6 +1057,9 @@ def kernel( c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None], epi_tile: cute.Tile, tile_sched_params: MaskedSchedulerParams, + c_mc: cute.Tensor, + barrier_flag: cute.Tensor, + barrier_flag_mc: cute.Tensor, ): """ GPU device kernel performing the Persistent batched GEMM computation. 
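The hunks above thread three new tensors into the device kernel: c_mc (a multicast view of C), barrier_flag, and barrier_flag_mc. The flags are laid out one per scheduled output tile, followed by an extra region (sized by sm_count, per the comment in compute_barrier_flag_size) for the final inter-GPU barrier. The illustrative Python below is a minimal host-side model of that indexing and of the sizing returned by compute_barrier_flag_size; the helper names here are placeholders for explanation only, not kernel symbols.

.. code-block:: python

    # Host-side model of the barrier-flag layout used by the two_shot path.
    def tile_flag_index(linear_work_idx: int, cluster_size: int,
                        cta_rank_in_cluster: int) -> int:
        # Both the epilogue producer and the all-reduce warps index the same flag:
        #   tile_id = work_index * cluster_size + cta_rank_in_cluster
        return linear_work_idx * cluster_size + cta_rank_in_cluster

    def final_flag_index(num_clusters_mnl: int, cluster_size: int) -> int:
        # The final inter-GPU barrier starts right after the per-tile flags.
        return num_clusters_mnl * cluster_size

    # Sizing example mirroring compute_barrier_flag_size for
    # m = n = 1024, l = 1, mma_tiler = (128, 128), cluster = (1, 1), sm_count = 148:
    num_tiles = (1024 // 128) * (1024 // 128) * 1 * 1   # 64 per-tile flags
    total_flags = num_tiles + 148                        # plus one slot per SM -> 212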
@@ -1012,7 +1078,6 @@ def kernel( cpasync.prefetch_descriptor(tma_atom_c) use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2 - # # Setup cta/thread coordinates # @@ -1793,6 +1858,24 @@ def kernel( acc_pipeline.consumer_release(acc_consumer_state) acc_consumer_state.advance() + # + # Allreduce + # + if cutlass.const_expr(self.all_reduce == "two_shot"): + tile_id = Int32( + tile_sched._current_work_linear_idx + * cute.size(self.cluster_shape_mn) + + cute.arch.block_idx_in_cluster() + ) + if warp_idx == self.epilog_warp_id[0]: + cute.arch.cp_async_bulk_wait_group(0, read=False) + # System barrier to make sure that data from each GPU is in memory before allreduce + with cute.arch.elect_one(): + flag = barrier_flag_mc.iterator + tile_id + cute.arch.fence_acq_rel_gpu() + distributed_helpers.spin_lock_multimem_arrive(flag) + cute.arch.fence_proxy(cute.arch.ProxyKind.alias) + # # Advance to next tile # @@ -1848,6 +1931,147 @@ def kernel( else: c_pipeline.producer_tail() + # /////////////////////////////////////////////////////////////////////////////// + # Allreduce warps + # /////////////////////////////////////////////////////////////////////////////// + if cutlass.const_expr(self.all_reduce == "two_shot"): + if warp_idx >= self.all_reduce_warp_id[0]: + # /////////////////////////////////////////////////////////////////////////////// + # Add persistent tile loop + # /////////////////////////////////////////////////////////////////////////////// + + rank_id = self.rank_id + num_ranks = Int32(self.num_ranks) + lane_id = cute.arch.lane_idx() # noqa + + # tile_sched = utils.StaticPersistentTileScheduler.create( + # tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + # ) + tile_sched = MaskedScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + # we want 128bit ld/st for better performance + atom_val = 128 // c_mc.element_type.width + atom_thr_n = self.mma_tiler[1] // atom_val + atom_thr_m = len(self.all_reduce_warp_id) * ( + cute.arch.WARP_SIZE // atom_thr_n + ) + thr_layout = cute.make_layout( + (atom_thr_m, atom_thr_n), stride=(atom_thr_n, 1) + ) + val_layout = cute.make_layout((1, atom_val), stride=(atom_val, 1)) + + copy_atom_load = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), c_mc.element_type + ) + tiled_copy_fake = cute.make_tiled_copy_tv( + copy_atom_load, thr_layout, val_layout + ) + thr_copy_fake = tiled_copy_fake.get_slice( + tidx - self.all_reduce_warp_id[0] * 32 + ) + + while work_tile.is_valid_tile: + cur_tile_coord = work_tile.tile_idx + tile_id = Int32( + tile_sched._current_work_linear_idx + * cute.size(self.cluster_shape_mn) + + cute.arch.block_idx_in_cluster() + ) + mma_tile_coord_mnl = ( + cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape), + cur_tile_coord[1], + cur_tile_coord[2], + ) + + # System barrier to make sure that data from each GPU is in memory before allreduce + if warp_idx == self.all_reduce_warp_id[0]: + with cute.arch.elect_one(): + flag = barrier_flag.iterator + tile_id + # TODO: we may use LDG+STG for spin lock instead of ATOMIC_CAS for better performance. 
+ distributed_helpers.spin_lock_wait(flag, num_ranks) + + cute.arch.barrier( + barrier_id=self.all_reduce_sync_bar_id, + number_of_threads=32 * len(self.all_reduce_warp_id), + ) + # partition and slice at tile level + gC_mc = cute.local_tile( + c_mc, + cute.slice_(self.mma_tiler, (None, None, 0)), + (None, None, None), + ) + tCgC_mc = thr_mma.partition_C(gC_mc) + tCgC_mc_slice = tCgC_mc[((None, None), 0, 0, *mma_tile_coord_mnl)] + + # partition based on the number of GPUs + cta_mma_tile_m = self.mma_tiler[0] // cute.size( + tiled_mma.thr_id.shape + ) + m_local_rank = int(cta_mma_tile_m / self.num_ranks) + tCgC_mc_slice_partitioned = cute.zipped_divide( + tCgC_mc_slice, (m_local_rank, self.mma_tiler[1]) + ) + tCgC_mc_local_rank = cute.slice_( + tCgC_mc_slice_partitioned, ((None, None), (rank_id, 0)) + ) + + # partition at thread level + frgC_mc = thr_copy_fake.partition_S(tCgC_mc_local_rank) + atom, loop_m, loop_n = frgC_mc.shape + for i in cutlass.range_constexpr(loop_m): + for j in cutlass.range_constexpr(loop_n): + mc_ptr = frgC_mc[None, i, j].iterator + x, y, z, w = 0, 0, 0, 0 + if cutlass.const_expr(self.c_dtype == Float16): + x, y, z, w = ( + distributed_helpers.multimem_ld_reduce_8xf16(mc_ptr) + ) + elif cutlass.const_expr(self.c_dtype == Float32): + x, y, z, w = ( + distributed_helpers.multimem_ld_reduce_4xf32(mc_ptr) + ) + elif cutlass.const_expr(self.c_dtype == BFloat16): + x, y, z, w = ( + distributed_helpers.multimem_ld_reduce_8xbf16( + mc_ptr + ) + ) + elif cutlass.const_expr(self.c_dtype == Float8E4M3FN): + x, y, z, w = ( + distributed_helpers.multimem_ld_reduce_16xe4m3( + mc_ptr + ) + ) + elif cutlass.const_expr(self.c_dtype == Float8E5M2): + x, y, z, w = ( + distributed_helpers.multimem_ld_reduce_16xe5m2( + mc_ptr + ) + ) + distributed_helpers.multimem_st_4xb32(mc_ptr, x, y, z, w) + # Advance to next tile + tile_sched.advance_to_next_work() + work_tile, _ = tile_sched.get_current_work() + + cute.arch.barrier( + barrier_id=self.all_reduce_sync_bar_id, + number_of_threads=32 * len(self.all_reduce_warp_id), + ) + # System barrier to make sure all the peer memory transfers are completed. 
+ last_flag_idx = cute.size( + tile_sched.params.problem_layout_ncluster_mnl + ) * cute.size(self.cluster_shape_mn) + if warp_idx == self.all_reduce_warp_id[0]: + with cute.arch.elect_one(): + distributed_helpers.sm_wise_inter_gpu_multimem_barrier( + barrier_flag.iterator + last_flag_idx, + barrier_flag_mc.iterator + last_flag_idx, + self.num_ranks, + ) + def mainloop_s2t_copy_and_partition( self, sSF: cute.Tensor, @@ -2196,6 +2420,7 @@ def is_valid_dtypes_and_scale_factor_vec_size( sf_dtype: Type[cutlass.Numeric], sf_vec_size: int, c_dtype: Type[cutlass.Numeric], + all_reduce: str = "none", ) -> bool: """ Check if the dtypes and sf_vec_size are valid combinations @@ -2246,6 +2471,20 @@ def is_valid_dtypes_and_scale_factor_vec_size( }: is_valid = False + # check if c_dtype is supported by multimem all-reduce + if cutlass.const_expr( + all_reduce != "none" + and c_dtype + not in { + cutlass.Float16, + cutlass.Float32, + cutlass.BFloat16, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + } + ): + is_valid = False + return is_valid @staticmethod @@ -2331,6 +2570,7 @@ def is_valid_tensor_alignment( a_major: str, b_major: str, c_major: str, + all_reduce: str = "none", ) -> bool: """ Check if the tensor alignment is valid @@ -2371,8 +2611,65 @@ def check_contigous_16B_alignment(dtype, is_mode0_major, tensor_shape): or not check_contigous_16B_alignment(c_dtype, c_major == "m", (m, n, l)) ): is_valid = False + if all_reduce != "none" and m % 128 != 0 and n % 128 != 0: + is_valid = False + return is_valid + @staticmethod + def compute_barrier_flag_size( + m: int, + n: int, + l: int, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + sm_count: int, + ) -> int: + """ + Compute the required size for barrier flag tensors used in all-reduce synchronization. + + The barrier flags are used for: + 1. Per-tile synchronization during the all-reduce phase + 2. 
Final inter-GPU synchronization barrier + + :param m: Number of rows in the output matrix + :type m: int + :param n: Number of columns in the output matrix + :type n: int + :param l: Batch size + :type l: int + :param mma_tiler_mn: Shape of the MMA tiler (M, N) + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: Cluster dimensions (M, N) + :type cluster_shape_mn: Tuple[int, int] + :param sm_count: Number of SMs available + :type sm_count: int + + :return: Total number of barrier flags needed + :rtype: int + """ + # Calculate CTA tile shape accounting for 2-CTA instructions + use_2cta_instrs = mma_tiler_mn[0] == 256 + cta_tile_shape_m = mma_tiler_mn[0] // (2 if use_2cta_instrs else 1) + cta_tile_shape_n = mma_tiler_mn[1] + + # Calculate number of tiles per batch + num_tiles_m = (m + cta_tile_shape_m - 1) // cta_tile_shape_m + num_tiles_n = (n + cta_tile_shape_n - 1) // cta_tile_shape_n + num_tiles_per_batch = num_tiles_m * num_tiles_n + + # Calculate number of clusters per batch + cluster_size = cluster_shape_mn[0] * cluster_shape_mn[1] + num_ctas_per_tile = cluster_size + + # Total tiles across all batches and clusters + num_tiles = num_tiles_per_batch * l * num_ctas_per_tile + + # Add extra space for final barrier (one per SM) + total_barrier_size = num_tiles + sm_count + + return total_barrier_size + @staticmethod def can_implement( ab_dtype: Type[cutlass.Numeric], @@ -2388,6 +2685,8 @@ def can_implement( a_major: str, b_major: str, c_major: str, + all_reduce: str = "none", + process_group: Optional[torch.distributed.ProcessGroup] = None, ) -> bool: """ Check if the gemm can be implemented @@ -2418,6 +2717,8 @@ def can_implement( :type b_major: str :param c_major: The major axis of the C tensor :type c_major: str + :param all_reduce: All-reduce mode, can be "none", "two_shot" + :type all_reduce: str :return: True if the gemm can be implemented, False otherwise :rtype: bool @@ -2425,7 +2726,7 @@ def can_implement( can_implement = True # Skip unsupported types if not Sm100BlockScaledPersistentDenseGemmKernel.is_valid_dtypes_and_scale_factor_vec_size( - ab_dtype, sf_dtype, sf_vec_size, c_dtype + ab_dtype, sf_dtype, sf_vec_size, c_dtype, all_reduce ): can_implement = False # Skip unsupported layouts @@ -2440,9 +2741,15 @@ def can_implement( can_implement = False # Skip illegal problem shape for load/store alignment if not Sm100BlockScaledPersistentDenseGemmKernel.is_valid_tensor_alignment( - m, n, k, l, ab_dtype, c_dtype, a_major, b_major, c_major + m, n, k, l, ab_dtype, c_dtype, a_major, b_major, c_major, all_reduce ): can_implement = False + + # Check for all reduce constraints + if all_reduce != "none": + # TODO(asamani): expand the logic for mnnvl support + if torch.distributed.get_world_size(process_group) not in [2, 4, 8]: + can_implement = False return can_implement @@ -2574,6 +2881,7 @@ def __init__( cluster_shape_mn: Tuple[int, int], sm_count: int, sm_version: str, + all_reduce: str = "none", ): self._m = m self._n = n @@ -2589,6 +2897,7 @@ def __init__( self._sf_vec_size = sf_vec_size self._mma_tiler_mn = mma_tiler_mn self._cluster_shape_mn = cluster_shape_mn + self._all_reduce = all_reduce if not Sm100BlockScaledPersistentDenseGemmKernel.can_implement( ab_dtype, @@ -2604,9 +2913,10 @@ def __init__( a_major, b_major, c_major, + all_reduce, ): raise TypeError( - f"MaskedBatchedMatmulCuteDSL: Unsupported with {ab_dtype}, {sf_dtype}, {sf_vec_size}, {c_dtype}, {mma_tiler_mn}, {cluster_shape_mn}, {m}, {n}, {k}, {l}, {a_major}, {b_major}, {c_major}" + 
f"MaskedBatchedMatmulCuteDSL: Unsupported with {ab_dtype}, {sf_dtype}, {sf_vec_size}, {c_dtype}, {mma_tiler_mn}, {cluster_shape_mn}, {m}, {n}, {k}, {l}, {a_major}, {b_major}, {c_major}, {all_reduce}" ) # Compute max active clusters on current device @@ -2630,8 +2940,25 @@ def __call__( masked_m_ptr: cute.Pointer, dst_signals_ptr: Optional[cute.Pointer], alpha_ptr: cute.Pointer, + c_mc_ptr: Optional[cute.Pointer], + barrier_flag_ptr: Optional[cute.Pointer], + barrier_flag_mc_ptr: Optional[cute.Pointer], current_stream: cuda.CUstream, ): + if cutlass.const_expr(self._all_reduce != "none"): + barrier_flag_size = ( + Sm100BlockScaledPersistentDenseGemmKernel.compute_barrier_flag_size( + self._m, + self._n, + self._l, + self._mma_tiler_mn, + self._cluster_shape_mn, + self._max_active_clusters, + ) + ) + else: + barrier_flag_size = 1 # Dummy size when not used + a_tensor = cute.make_tensor( a_ptr, layout=cute.make_ordered_layout( @@ -2653,6 +2980,33 @@ def __call__( order=(0, 1, 2) if self._c_major == "m" else (1, 0, 2), ), ) + c_mc_tensor = ( + cute.make_tensor( + c_mc_ptr, + layout=cute.make_ordered_layout( + (self._m, self._n, self._l), + order=(0, 1, 2) if self._c_major == "m" else (1, 0, 2), + ), + ) + if c_mc_ptr is not None + else None + ) + barrier_flag_tensor = ( + cute.make_tensor( + barrier_flag_ptr, + layout=cute.make_ordered_layout((barrier_flag_size,), order=(0,)), + ) + if barrier_flag_ptr is not None + else None + ) + barrier_flag_mc_tensor = ( + cute.make_tensor( + barrier_flag_mc_ptr, + layout=cute.make_ordered_layout((barrier_flag_size,), order=(0,)), + ) + if barrier_flag_mc_ptr is not None + else None + ) # calculate sf_tensor shape and order def ceil_div(a, b): @@ -2717,6 +3071,7 @@ def ceil_div(a, b): mma_tiler_mn=self._mma_tiler_mn, cluster_shape_mn=self._cluster_shape_mn, sm_version=self._sm_version, + all_reduce=self._all_reduce, )( a_tensor, b_tensor, @@ -2728,6 +3083,9 @@ def ceil_div(a, b): alpha_tensor, self._max_active_clusters, current_stream, + c_mc_tensor, + barrier_flag_tensor, + barrier_flag_mc_tensor, ) @@ -2750,6 +3108,7 @@ def get_cute_dsl_compiled_masked_gemm_kernel( sm_count: int, sm_version: str, enable_dst_signals: bool, + all_reduce: str = "none", ) -> Callable: def get_cute_pointers( input_tensors: Optional[List[torch.tensor]], @@ -2764,10 +3123,17 @@ def get_cute_pointers( masked_m_data_ptr, dst_signals_data_ptr, alpha_data_ptr, - ) = [16 for _ in range(8)] + c_mc_data_ptr, + barrier_flag_data_ptr, + barrier_flag_mc_data_ptr, + ) = [16 for _ in range(11)] if not enable_dst_signals: dst_signals_data_ptr = None + if all_reduce == "none": + c_mc_data_ptr = None + barrier_flag_data_ptr = None + barrier_flag_mc_data_ptr = None else: ( @@ -2779,6 +3145,9 @@ def get_cute_pointers( masked_m_tensor_gpu, dst_signals_tensor_gpu, alpha_tensor_gpu, + c_mc_gpu, + barrier_flag_gpu, + barrier_flag_mc_gpu, ) = input_tensors assert enable_dst_signals == (dst_signals_tensor_gpu is not None) @@ -2792,6 +3161,9 @@ def get_cute_pointers( masked_m_data_ptr, dst_signals_data_ptr, alpha_data_ptr, + c_mc_data_ptr, + barrier_flag_data_ptr, + barrier_flag_mc_data_ptr, ) = ( a_tensor_gpu.data_ptr(), b_tensor_gpu.data_ptr(), @@ -2803,6 +3175,11 @@ def get_cute_pointers( if dst_signals_tensor_gpu is not None else None, alpha_tensor_gpu.data_ptr() if alpha_tensor_gpu is not None else None, + c_mc_gpu.data_ptr() if c_mc_gpu is not None else None, + barrier_flag_gpu.data_ptr() if barrier_flag_gpu is not None else None, + barrier_flag_mc_gpu.data_ptr() + if barrier_flag_mc_gpu is not 
None + else None, ) a_ptr = make_ptr( @@ -2861,6 +3238,36 @@ def get_cute_pointers( if alpha_data_ptr is not None and alpha_dtype is not None else None ) + c_mc_ptr = ( + make_ptr( + c_dtype, + c_mc_data_ptr, + cute.AddressSpace.gmem, + assumed_align=16, + ) + if c_mc_data_ptr is not None + else None + ) + barrier_flag_ptr = ( + make_ptr( + cutlass.Int32, + barrier_flag_data_ptr, + cute.AddressSpace.gmem, + assumed_align=16, + ) + if barrier_flag_data_ptr is not None + else None + ) + barrier_flag_mc_ptr = ( + make_ptr( + cutlass.Int32, + barrier_flag_mc_data_ptr, + cute.AddressSpace.gmem, + assumed_align=16, + ) + if barrier_flag_mc_data_ptr is not None + else None + ) return [ a_ptr, @@ -2871,6 +3278,9 @@ def get_cute_pointers( masked_m_ptr, dst_signals_ptr, alpha_ptr, + c_mc_ptr, + barrier_flag_ptr, + barrier_flag_mc_ptr, ] kernel = cute.compile( @@ -2891,6 +3301,7 @@ def get_cute_pointers( cluster_shape_mn=cluster_shape_mn, sm_count=sm_count, sm_version=sm_version, + all_reduce=all_reduce, ), *get_cute_pointers(None), cutlass_torch.current_stream(), @@ -2905,6 +3316,12 @@ def tensor_api( dst_signals_tensor_gpu: torch.Tensor, c_tensor_gpu: Optional[torch.Tensor] = None, alpha_tensor_gpu: Optional[torch.Tensor] = None, + c_mc_gpu: Optional[Tensor] = None, + c_mc_torch: Optional[torch.Tensor] = None, + barrier_flag_gpu: Optional[Tensor] = None, + barrier_flag_torch: Optional[torch.Tensor] = None, + barrier_flag_mc_gpu: Optional[Tensor] = None, + barrier_flag_mc_torch: Optional[torch.Tensor] = None, ): if c_tensor_gpu is None: # fp4 gemm output is not supported @@ -2929,6 +3346,9 @@ def tensor_api( masked_m_tensor_gpu, dst_signals_tensor_gpu, alpha_tensor_gpu, + c_mc_torch, + barrier_flag_torch, + barrier_flag_mc_torch, ] ), current_stream, @@ -2951,6 +3371,13 @@ def grouped_gemm_nt_masked( sf_vec_size: int, dst_signals: Optional[torch.Tensor] = None, sm_count: Optional[int] = None, + all_reduce: str = "none", + out_mc: Optional[Tensor] = None, + out_mc_torch: Optional[torch.Tensor] = None, + barrier_flag: Optional[Tensor] = None, + barrier_flag_mc: Optional[Tensor] = None, + barrier_flag_torch: Optional[torch.Tensor] = None, + barrier_flag_mc_torch: Optional[torch.Tensor] = None, **kwargs, ): """ @@ -3031,6 +3458,7 @@ def grouped_gemm_nt_masked( sm_count=sm_count, sm_version=f"sm_{major}{minor}", enable_dst_signals=dst_signals is not None, + all_reduce=all_reduce, )( a_tensor_gpu=a_torch, b_tensor_gpu=b_torch, @@ -3040,4 +3468,10 @@ def grouped_gemm_nt_masked( masked_m_tensor_gpu=masked_m, dst_signals_tensor_gpu=dst_signals, alpha_tensor_gpu=alpha, + c_mc_gpu=out_mc, + c_mc_torch=out_mc_torch, + barrier_flag_gpu=barrier_flag, + barrier_flag_torch=barrier_flag_torch, + barrier_flag_mc_gpu=barrier_flag_mc, + barrier_flag_mc_torch=barrier_flag_mc_torch, ) diff --git a/tests/test_cute_dsl_blockscaled_gemm_allreduce_two_shot.py b/tests/test_cute_dsl_blockscaled_gemm_allreduce_two_shot.py new file mode 100644 index 0000000000..e3448ecfe0 --- /dev/null +++ b/tests/test_cute_dsl_blockscaled_gemm_allreduce_two_shot.py @@ -0,0 +1,559 @@ +import multiprocessing as mp +import pytest +import socket +from typing import Any, Tuple + +import cutlass +import cutlass.cute as cute +from cutlass.cute.runtime import from_dlpack +import cutlass.torch as cutlass_torch + +import torch +import torch.distributed as dist +import torch.distributed._symmetric_memory as symm_mem + +from flashinfer.cute_dsl.blockscaled_gemm import ( + Sm100BlockScaledPersistentDenseGemmKernel, # not used in python interface + 
grouped_gemm_nt_masked, # deepgemm-like python interface for DLFW integration + create_scale_factor_tensor, +) +from flashinfer.cute_dsl.utils import ( + get_cutlass_dtype, + is_cute_dsl_available, +) + + +def create_mc_tensor(torch_tensor_cpu, dtype, is_dynamic_layout=True): + m, n, l = torch_tensor_cpu.shape + + # Create flat symm_mem buffer + total_elements = m * n * l + torch_symm_flat = symm_mem.empty( + (total_elements,), device="cuda", dtype=torch_tensor_cpu.dtype + ) + + # Reshape to match input's stride pattern using as_strided + torch_symm_tensor = torch_symm_flat.as_strided( + size=torch_tensor_cpu.shape, stride=torch_tensor_cpu.stride() + ) + torch_symm_tensor.copy_(torch_tensor_cpu) + + symm = symm_mem.rendezvous(torch_symm_flat, group=dist.group.WORLD.group_name) + mc_ptr = symm.multicast_ptr + + # Create MC tensor with same stride + torch_tensor_mc_flat = cutlass_torch.as_tensor( + mc_ptr, (total_elements,), torch_tensor_cpu.dtype + ) + torch_tensor_mc = torch_tensor_mc_flat.as_strided( + size=torch_tensor_cpu.shape, stride=torch_tensor_cpu.stride() + ) + + cute_tensor_mc = from_dlpack(torch_tensor_mc, assumed_align=16) + + if is_dynamic_layout: + for i, stride in enumerate(torch_tensor_mc.stride()): + if stride == 1: + leading_dim = i + break + cute_tensor_mc = cute_tensor_mc.mark_layout_dynamic(leading_dim=leading_dim) + + torch_tensor_gpu = torch_symm_tensor + cute_tensor = from_dlpack(torch_tensor_gpu, assumed_align=16) + cute_tensor.element_type = dtype + + if is_dynamic_layout: + for i, stride in enumerate(torch_tensor_gpu.stride()): + if stride == 1: + leading_dim = i + break + cute_tensor = cute_tensor.mark_layout_dynamic(leading_dim=leading_dim) + + cute_tensor = cutlass_torch.convert_cute_tensor( + torch_tensor_gpu, + cute_tensor, + dtype, + is_dynamic_layout=is_dynamic_layout, + ) + return cute_tensor, cute_tensor_mc, torch_tensor_gpu, torch_tensor_mc + + +def create_barrier_flags(m, n, l, mma_tiler_mn, cluster_shape_mn, sm_count): + barrier_size = Sm100BlockScaledPersistentDenseGemmKernel.compute_barrier_flag_size( + m, n, l, mma_tiler_mn, cluster_shape_mn, sm_count + ) + barrier_flag = symm_mem.empty((barrier_size,), device="cuda", dtype=torch.int32) + + barrier_flag.fill_(0) + symm = symm_mem.rendezvous(barrier_flag, group=dist.group.WORLD.group_name) + barrier_flag_mc_ptr = symm.multicast_ptr + + barrier_flag_memref = from_dlpack(barrier_flag) + barrier_flag_memref = barrier_flag_memref.mark_layout_dynamic() + barrier_flag_mc_torch = cutlass_torch.as_tensor( + barrier_flag_mc_ptr, barrier_flag.shape, barrier_flag.dtype + ) + barrier_flag_mc_memref = from_dlpack( + barrier_flag_mc_torch, + ) + barrier_flag_mc_memref = barrier_flag_mc_memref.mark_layout_dynamic() + barrier_flag_torch = barrier_flag + return ( + barrier_flag_memref, + barrier_flag_mc_memref, + barrier_flag_torch, + barrier_flag_mc_torch, + ) + + +def run_blockscaled_gemm_all_reduce_python_interface( + lm: Tuple[int, int], + kn: Tuple[int, int], + ab_dtype: cutlass.dtype, + sf_dtype: cutlass.dtype, + sf_vec_size: int, + c_dtype: cutlass.dtype, + a_major: str, + b_major: str, + c_major: str, + fuse_alpha: bool, + alpha_dtype: cutlass.dtype, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + sm_count: int, + tolerance: float, + iterations: int, + enable_dst_signals: int, + all_reduce: str, + rank: int, + world_size: int, +): + torch.manual_seed(42) + device = torch.device("cuda", rank) + major, minor = torch.cuda.get_device_capability(device) + + if not (major == 10 and minor == 
0): + pytest.skip("Cute-dsl backend is only supported on SM100.") + if enable_dst_signals and (sm_count is None): + pytest.skip("dst_signals require sm_count") + + l, m = lm + k, n = kn + + if not Sm100BlockScaledPersistentDenseGemmKernel.can_implement( + get_cutlass_dtype(ab_dtype), + get_cutlass_dtype(sf_dtype), + sf_vec_size, + get_cutlass_dtype(c_dtype), + mma_tiler_mn, + cluster_shape_mn, + m, + n, + k, + l, + a_major, + b_major, + c_major, + ): + pytest.skip( + f"Unsupported testcase {ab_dtype}, {sf_dtype}, {sf_vec_size}, {c_dtype}, {mma_tiler_mn}, {cluster_shape_mn}, {m}, {n}, {k}, {l}, {a_major}, {b_major}, {c_major}" + ) + + if not (a_major == "k" and b_major == "k" and c_major == "n"): + # not supported since we try to align deepgemm for now + pytest.skip( + f"Skip non deepgemm-like cases {a_major}, {b_major}, {c_major}. Might be added later" + ) + + a_ref = cutlass_torch.matrix( + l, m, k, a_major == "m", cutlass.Float32, device=device + ) + b_ref = cutlass_torch.matrix( + l, n, k, b_major == "n", cutlass.Float32, device=device + ) + c_ref = cutlass_torch.matrix( + l, + m, + n, + c_major == "m", + cutlass.Float32, + device=device, + init_type=cutlass_torch.TensorInitType.SCALAR, + init_config=cutlass_torch.ScalarInitConfig(value=0.0), + ) + a_tensor, a_torch = cutlass_torch.cute_tensor_like( + a_ref, + get_cutlass_dtype(ab_dtype), + is_dynamic_layout=True, + assumed_align=16, + ) + b_tensor, b_torch = cutlass_torch.cute_tensor_like( + b_ref, + get_cutlass_dtype(ab_dtype), + is_dynamic_layout=True, + assumed_align=16, + ) + c_tensor, c_tensor_mc, c_torch, c_torch_mc = create_mc_tensor( + c_ref, + get_cutlass_dtype(c_dtype), + # (1 if c_major == "n" else 0), + is_dynamic_layout=True, + ) + alpha_tensor = ( + torch.randn(l, dtype=torch.float32, device=device) if fuse_alpha else None + ) + ( + barrier_flag_memref, + barrier_flag_mc_memref, + barrier_flag_torch, + barrier_flag_mc_torch, + ) = create_barrier_flags( + m, + n, + l, + mma_tiler_mn, + cluster_shape_mn, + sm_count, + ) + # for deepgemm-like python interface + if ab_dtype == "float4_e2m1fn": + m, k, l = a_torch.shape + n, k, l = b_torch.shape + # slice into half after flatten + half_len_a = a_torch.numel() // 2 + half_len_b = b_torch.numel() // 2 + a_torch = ( + a_torch.permute(2, 0, 1) + .flatten()[:half_len_a] + .reshape(l, m, k // 2) + .permute(1, 2, 0) + ) + b_torch = ( + b_torch.permute(2, 0, 1) + .flatten()[:half_len_b] + .reshape(l, n, k // 2) + .permute(1, 2, 0) + ) + + sfa_ref, sfa_tensor, sfa_torch = create_scale_factor_tensor( + l, m, k, sf_vec_size, get_cutlass_dtype(sf_dtype), device + ) + sfb_ref, sfb_tensor, sfb_torch = create_scale_factor_tensor( + l, n, k, sf_vec_size, get_cutlass_dtype(sf_dtype), device + ) + if rank == 0: + masked_m_tensor = torch.randint(0, m, (l,), dtype=torch.int32, device=device) + else: + masked_m_tensor = torch.empty((l,), dtype=torch.int32, device=device) + torch.distributed.broadcast(masked_m_tensor, src=0) + for _ in range(iterations): + dst_signals = ( + torch.zeros((l,), dtype=torch.uint32, device="cuda") + if enable_dst_signals + else None + ) + + # deepgemm-like python interface: fp4 packed, for DLFW integration + grouped_gemm_nt_masked( + (a_torch, sfa_torch), + (b_torch, sfb_torch), + c_torch, + masked_m_tensor, + ab_dtype=ab_dtype, + sf_dtype=sf_dtype, + c_dtype=c_dtype, + sf_vec_size=sf_vec_size, + mma_tiler_mn=mma_tiler_mn, + cluster_shape_mn=cluster_shape_mn, + alpha=alpha_tensor, + alpha_dtype=alpha_dtype, + sm_count=sm_count, + dst_signals=dst_signals, + 
all_reduce=all_reduce, + out_mc=c_tensor_mc, + out_mc_torch=c_torch_mc, + barrier_flag=barrier_flag_memref, + barrier_flag_mc=barrier_flag_mc_memref, + barrier_flag_torch=barrier_flag_torch, + barrier_flag_mc_torch=barrier_flag_mc_torch, + ) + + if enable_dst_signals: + assert torch.all(dst_signals == sm_count), f"{dst_signals}" + + # compute ref output + if not fuse_alpha: + alpha_tensor = torch.ones(l, dtype=torch.float32, device=device) + res_a = torch.einsum("mkl,mkl->mkl", a_ref, sfa_ref) + res_b = torch.einsum("nkl,nkl->nkl", b_ref, sfb_ref) + ref = torch.einsum("mkl,nkl->mnl", res_a, res_b) + ref = torch.einsum("mnl,l->mnl", ref, alpha_tensor) + ref = ref.contiguous() + torch.distributed.all_reduce( + ref, op=torch.distributed.ReduceOp.SUM, group=dist.group.WORLD + ) + # Convert c back to f32 for comparison. + ref = ref.permute(2, 0, 1).contiguous().permute(1, 2, 0) + cute.testing.convert( + c_tensor, + from_dlpack(c_ref, assumed_align=16).mark_layout_dynamic( + leading_dim=(1 if c_major == "n" else 0) + ), + ) + if c_dtype in ("float32", "float16", "bfloat16"): + for i in range(l): + # skip testing c_ref & ref + torch.testing.assert_close( + c_ref[: masked_m_tensor[i].item(), :, i], + ref[: masked_m_tensor[i].item(), :, i], + atol=tolerance, + rtol=1e-02, + ) + elif c_dtype in ("float8_e5m2", "float8_e4m3fn"): + # Convert ref : f32 -> f8 -> f32 + ref_f8_ = torch.empty(*(l, m, n), dtype=torch.uint8, device=device).permute( + 1, 2, 0 + ) + ref_f8 = from_dlpack(ref_f8_, assumed_align=16).mark_layout_dynamic( + leading_dim=1 + ) + ref_f8.element_type = get_cutlass_dtype(c_dtype) + ref = ref.permute(2, 0, 1).contiguous().permute(1, 2, 0) + ref_tensor = from_dlpack(ref, assumed_align=16).mark_layout_dynamic( + leading_dim=1 + ) + cute.testing.convert(ref_tensor, ref_f8) + cute.testing.convert(ref_f8, ref_tensor) + for i in range(l): + # skip testing c_ref & ref + torch.testing.assert_close( + c_ref[: masked_m_tensor[i].item(), :, i], + ref[: masked_m_tensor[i].item(), :, i], + atol=tolerance, + rtol=1e-02, + ) + + +def _run_correctness_worker( + world_size, + rank, + distributed_init_port, + lm, + kn, + ab_dtype, + sf_dtype, + sf_vec_size, + c_dtype, + a_major, + b_major, + c_major, + fuse_alpha, + alpha_dtype, + mma_tiler_mn, + cluster_shape_mn, + sm_count, + tolerance, + iterations, + enable_dst_signals, + all_reduce, +): + assert rank >= 0 + torch.cuda.set_device(rank) + device = torch.device("cuda", rank) + distributed_init_method = f"tcp://localhost:{distributed_init_port}" + dist.init_process_group( + backend="cpu:gloo,cuda:nccl", + rank=rank, + world_size=world_size, + device_id=device, + init_method=distributed_init_method, + ) + group = dist.group.WORLD + rank_id = torch.distributed.get_rank() + + try: + run_blockscaled_gemm_all_reduce_python_interface( + lm=lm, + kn=kn, + ab_dtype=ab_dtype, + sf_dtype=sf_dtype, + sf_vec_size=sf_vec_size, + c_dtype=c_dtype, + a_major=a_major, + b_major=b_major, + c_major=c_major, + fuse_alpha=fuse_alpha, + alpha_dtype=alpha_dtype, + mma_tiler_mn=mma_tiler_mn, + cluster_shape_mn=cluster_shape_mn, + tolerance=tolerance, + iterations=iterations, + sm_count=sm_count, + enable_dst_signals=enable_dst_signals, + all_reduce=all_reduce, + rank=rank, + world_size=world_size, + ) + except Exception as e: + print(f"Rank {rank_id}: Exception during test: {e}") + raise + finally: + torch.distributed.barrier(group) + torch.distributed.destroy_process_group(group) + + +def get_open_port() -> int: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) 
as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + except OSError: + with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s: + s.bind(("::1", 0)) + return s.getsockname()[1] + + +def multi_process_parallel( + world_size: int, test_target: Any, target_args: tuple = () +) -> None: + mp.set_start_method("spawn", force=True) + + procs = [] + distributed_init_port = get_open_port() + for i in range(world_size): + proc_args = (world_size, i, distributed_init_port) + target_args + proc = mp.Process(target=test_target, args=proc_args, name=f"Worker-{i}") + proc.start() + procs.append(proc) + + for i in range(world_size): + procs[i].join() + assert procs[i].exitcode == 0, ( + f"Process {i} failed with exit code {procs[i].exitcode}" + ) + + +@pytest.mark.skipif( + not is_cute_dsl_available(), reason="Please `pip install nvidia-cutlass-dsl`" +) +@pytest.mark.parametrize("world_size", [8]) +@pytest.mark.parametrize("lm", [(1, 1024), (2, 512), (4, 256)]) +@pytest.mark.parametrize("kn", [(7168, 4096), (2048, 7168)]) +@pytest.mark.parametrize( + "ab_dtype,sf_dtype,c_dtype,sf_vec_size", + [ + ("float8_e5m2", "float8_e8m0fnu", "bfloat16", 32), + ("float4_e2m1fn", "float8_e8m0fnu", "float16", 16), + ("float4_e2m1fn", "float8_e8m0fnu", "bfloat16", 16), + ("float4_e2m1fn", "float8_e8m0fnu", "float32", 16), + ("float4_e2m1fn", "float8_e4m3fn", "float16", 16), + ("float4_e2m1fn", "float8_e4m3fn", "bfloat16", 16), + ("float4_e2m1fn", "float8_e4m3fn", "float32", 16), + ("float8_e4m3fn", "float8_e8m0fnu", "bfloat16", 32), + ("float8_e4m3fn", "float8_e8m0fnu", "float16", 32), + ("float8_e4m3fn", "float8_e8m0fnu", "float32", 32), + # ("float8_e4m3fn", "float8_e8m0fnu", "float8_e4m3fn", 32), + ("float8_e4m3fn", "float8_e8m0fnu", "float8_e5m2", 32), + ("float8_e5m2", "float8_e8m0fnu", "bfloat16", 32), + ("float8_e5m2", "float8_e8m0fnu", "float16", 32), + ("float8_e5m2", "float8_e8m0fnu", "float32", 32), + # ("float8_e5m2", "float8_e8m0fnu", "float8_e4m3fn", 32), + ("float8_e5m2", "float8_e8m0fnu", "float8_e5m2", 32), + ], +) +@pytest.mark.parametrize("a_major", ["k"]) +@pytest.mark.parametrize("b_major", ["k"]) +@pytest.mark.parametrize("c_major", ["n"]) +@pytest.mark.parametrize("fuse_alpha", [False, True]) +@pytest.mark.parametrize("alpha_dtype", ["float32"]) +@pytest.mark.parametrize("mma_tiler_mn", [(128, 128)]) +@pytest.mark.parametrize("cluster_shape_mn", [(1, 1)]) +@pytest.mark.parametrize("sm_count", [148]) +@pytest.mark.parametrize("tolerance", [1e-01]) +@pytest.mark.parametrize("iterations", [1]) +@pytest.mark.parametrize("enable_dst_signals", [False, True]) +@pytest.mark.parametrize("all_reduce", ["two_shot"]) +def test_cute_dsl_blockscaled_gemm_allreduce_two_shot( + world_size, + lm, + kn, + ab_dtype, + sf_dtype, + sf_vec_size, + c_dtype, + a_major, + b_major, + c_major, + fuse_alpha, + alpha_dtype, + mma_tiler_mn, + cluster_shape_mn, + sm_count, + tolerance, + iterations, + enable_dst_signals, + all_reduce, +): + available_gpus = torch.cuda.device_count() + if world_size > available_gpus: + pytest.skip( + f"world_size {world_size} is greater than available_gpus {available_gpus}" + ) + major, minor = torch.cuda.get_device_capability(torch.device("cuda:0")) + if not (major == 10 and minor == 0): + pytest.skip("Cute-dsl backend is only supported on SM100.") + if enable_dst_signals and (sm_count is None): + pytest.skip("dst_signals require sm_count") + + l, m = lm + k, n = kn + if not Sm100BlockScaledPersistentDenseGemmKernel.can_implement( + get_cutlass_dtype(ab_dtype), + 
get_cutlass_dtype(sf_dtype), + sf_vec_size, + get_cutlass_dtype(c_dtype), + mma_tiler_mn, + cluster_shape_mn, + m, + n, + k, + l, + a_major, + b_major, + c_major, + ): + pytest.skip( + f"Unsupported testcase {ab_dtype}, {sf_dtype}, {sf_vec_size}, {c_dtype}, {mma_tiler_mn}, {cluster_shape_mn}, {m}, {n}, {k}, {l}, {a_major}, {b_major}, {c_major}" + ) + + if not (a_major == "k" and b_major == "k" and c_major == "n"): + # not supported since we try to align deepgemm for now + pytest.skip( + f"Skip non deepgemm-like cases {a_major}, {b_major}, {c_major}. Might be added later" + ) + print(f"Running test for world_size={world_size}") + multi_process_parallel( + world_size, + _run_correctness_worker, + target_args=( + lm, + kn, + ab_dtype, + sf_dtype, + sf_vec_size, + c_dtype, + a_major, + b_major, + c_major, + fuse_alpha, + alpha_dtype, + mma_tiler_mn, + cluster_shape_mn, + sm_count, + tolerance, + iterations, + enable_dst_signals, + all_reduce, + ), + ) + print(f"cute_dsl_blockscaled_gemm_allreduce_two_shot on {world_size} GPUs: OK")
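A condensed usage sketch of the two_shot path, reusing the create_mc_tensor and create_barrier_flags helpers defined in this test. It assumes torch.distributed is initialized with one rank per GPU and that a_torch/sfa_torch, b_torch/sfb_torch, c_ref, masked_m, and the shape variables are prepared as in run_blockscaled_gemm_all_reduce_python_interface above; the dtype and tiler choices shown are just one of the parametrized combinations.

.. code-block:: python

    # Allocate the symmetric-memory output (unicast + multicast views) and the
    # barrier flags sized by compute_barrier_flag_size.
    c_tensor, c_tensor_mc, c_torch, c_torch_mc = create_mc_tensor(
        c_ref, get_cutlass_dtype("bfloat16"), is_dynamic_layout=True
    )
    barrier_flag, barrier_flag_mc, barrier_flag_torch, barrier_flag_mc_torch = (
        create_barrier_flags(m, n, l, (128, 128), (1, 1), 148)
    )

    # Launch the masked grouped GEMM with the multimem all-reduce epilogue.
    grouped_gemm_nt_masked(
        (a_torch, sfa_torch),
        (b_torch, sfb_torch),
        c_torch,
        masked_m,
        ab_dtype="float8_e4m3fn",
        sf_dtype="float8_e8m0fnu",
        c_dtype="bfloat16",
        sf_vec_size=32,
        mma_tiler_mn=(128, 128),
        cluster_shape_mn=(1, 1),
        alpha=None,
        alpha_dtype="float32",
        sm_count=148,
        all_reduce="two_shot",            # enables the multimem epilogue
        out_mc=c_tensor_mc,               # multicast view of the output
        out_mc_torch=c_torch_mc,
        barrier_flag=barrier_flag,
        barrier_flag_mc=barrier_flag_mc,
        barrier_flag_torch=barrier_flag_torch,
        barrier_flag_mc_torch=barrier_flag_mc_torch,
    )
    # After the call, each rank's c_torch holds the all-reduced output, which is
    # what the reference check in this test verifies against torch.distributed.all_reduce.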