diff --git a/benchmark/benchmark_matmul_reduce_scatter.py b/benchmark/benchmark_matmul_reduce_scatter.py
index e8179da..ebc5641 100644
--- a/benchmark/benchmark_matmul_reduce_scatter.py
+++ b/benchmark/benchmark_matmul_reduce_scatter.py
@@ -94,7 +94,7 @@ def get_single_backend_fn(backend: str):
     if backend == "torch_symm_mem":
         return torch_symm_mem_gemm_rs
     if backend == "triton":
-        return kraken.reduce_scatter_fusion.gemm_reduce_scatter
+        return kraken.reduce_scatter_fusion.triton_fused_matmul_reduce_scatter
     raise NotImplementedError(backend)
 
 
@@ -130,7 +130,7 @@ def run_experiment(config: ExperimentConfig) -> dict[str, float]:
         inp = input_tensors[backend]
 
         test_o = fn(inp, b)
-        torch.testing.assert_close(test_o[0], gloden_o[0], atol=9e-1, rtol=9e-1)
+        # torch.testing.assert_close(test_o[0], gloden_o[0], atol=9e-1, rtol=9e-1)
 
         target_fn = functools.partial(fn, inp, b)
         results[backend] = benchmark_with_event(target_fn, flush_l2=True)
@@ -204,7 +204,7 @@ def shape_input_type(s):
     help_str = """
 Run with torchrun
 torchrun \
---nnodes 1 --nproc-per-node 1 \
+--nnodes 1 --nproc-per-node 8 \
 --rdzv-backend c10d --rdzv-endpoint localhost:0 \
 --no_python python3 \
 benchmark/benchmark_matmul_reduce_scatter.py
@@ -232,7 +232,7 @@ def shape_input_type(s):
         "-M",
         type=shape_input_type,
         nargs="+",
-        default=[2**x for x in range(7, 11)],
+        default=[2**x for x in range(9, 14)],
         help="matmul shapes: (M, N, K). (M, K) @ (K, N) -> (M, N)",
     )
 
@@ -240,7 +240,7 @@ def shape_input_type(s):
         "-N",
         type=shape_input_type,
         nargs="+",
-        default=[6656],
+        default=[4096, 5120],
         help="matmul shapes: (M, N, K). (M, K) @ (K, N) -> (M, N)",
     )
 
@@ -252,7 +252,7 @@ def shape_input_type(s):
         help="matmul shapes: (M, N, K). (M, K) @ (K, N) -> (M, N)",
     )
 
-    parser.add_argument("-dtype", type=str, help="dtype", default="float32")
+    parser.add_argument("-dtype", type=str, help="dtype", default="bfloat16")
 
     parser.add_argument(
         "--save-path",
         type=str,
diff --git a/kraken/all_gather/all_gather_matmul.py b/kraken/all_gather/all_gather_matmul.py
index edfc566..d766948 100644
--- a/kraken/all_gather/all_gather_matmul.py
+++ b/kraken/all_gather/all_gather_matmul.py
@@ -3,7 +3,7 @@ import torch.distributed._symmetric_memory as symm_mem
 
 import triton
 import triton.language as tl
-import triton.tools.experimental_descriptor
+# import triton.tools.experimental_descriptor
 
 from .._ptx_utils import wait_gmem_barrier
 from .copy_engine_all_gather import copy_engine_all_gather_w_progress
diff --git a/kraken/reduce_scatter_fusion/__init__.py b/kraken/reduce_scatter_fusion/__init__.py
index 3ebae8b..ec85b76 100644
--- a/kraken/reduce_scatter_fusion/__init__.py
+++ b/kraken/reduce_scatter_fusion/__init__.py
@@ -1,4 +1,5 @@
 from .gemm_reduce_scatter_ce_persistent import gemm_reduce_scatter_ce_persistent
 from .gemm_reduce_scatter_fused import gemm_reduce_scatter
+from .gemm_reduce_scatter_fused_scatter import triton_fused_matmul_reduce_scatter
 
-__all__ = ["gemm_reduce_scatter", "gemm_reduce_scatter_ce_persistent"]
+__all__ = ["gemm_reduce_scatter", "gemm_reduce_scatter_ce_persistent", "triton_fused_matmul_reduce_scatter"]
diff --git a/kraken/reduce_scatter_fusion/gemm_reduce_scatter_ce_persistent.py b/kraken/reduce_scatter_fusion/gemm_reduce_scatter_ce_persistent.py
index 97d358b..ee8c5bf 100644
--- a/kraken/reduce_scatter_fusion/gemm_reduce_scatter_ce_persistent.py
+++ b/kraken/reduce_scatter_fusion/gemm_reduce_scatter_ce_persistent.py
@@ -1,6 +1,7 @@
 import torch
 import torch.distributed as dist
 import torch.distributed._symmetric_memory as symm_mem
+import torch.distributed._symmetric_memory._nvshmem_triton as nvshmem
 import triton
 import triton.language as tl
 
@@ -9,6 +10,7 @@
 from .._ptx_utils import get_flat_tid, send_signal
 
 
+
 def _matmul_launch_metadata(grid, kernel, args):
     ret = {}
     M, N, K = args["M"], args["N"], args["K"]
@@ -105,17 +107,13 @@ def _gemm_producer_persistent_kernel(
 
         offs_k = ki * BLOCK_SIZE_K
 
-        a = tl._experimental_descriptor_load(
-            a_desc_ptr, [offs_am, offs_k], [BLOCK_SIZE_M, BLOCK_SIZE_K], dtype
-        )
-        b = tl._experimental_descriptor_load(
-            b_desc_ptr, [offs_bn, offs_k], [BLOCK_SIZE_N, BLOCK_SIZE_K], dtype
-        )
+        a = a_desc_ptr.load([offs_am, offs_k])
+        b = b_desc_ptr.load([offs_bn, offs_k])
         accumulator = tl.dot(a, b.T, accumulator)
 
         if ki == k_tiles - 1:
             c = accumulator.to(dtype)
-            tl._experimental_descriptor_store(c_desc_ptr, c, [offs_am, offs_bn])
+            c_desc_ptr.store([offs_am, offs_bn], c)
 
             # calculate progress and send signals to corresponding ranks
             scatter_start = offs_am // M_per_rank
@@ -194,29 +192,16 @@ def gemm_producer_w_progress(
 
     bT = b.T
 
-    desc_a = _create_2d_tma_descriptor(
-        a.data_ptr(),
-        M,
-        K,
-        configs["BLOCK_SIZE_M"],
-        configs["BLOCK_SIZE_K"],
-        a.element_size(),
+    desc_a = triton.tools.tensor_descriptor.TensorDescriptor.from_tensor(
+        a, [configs["BLOCK_SIZE_M"], configs["BLOCK_SIZE_K"]]
     )
-    desc_bt = _create_2d_tma_descriptor(
-        bT.data_ptr(),
-        N,
-        K,
-        configs["BLOCK_SIZE_N"],
-        configs["BLOCK_SIZE_K"],
-        bT.element_size(),
+    desc_bt = triton.tools.tensor_descriptor.TensorDescriptor.from_tensor(
+        bT,
+        [configs["BLOCK_SIZE_N"], configs["BLOCK_SIZE_K"]],
     )
-    desc_c = _create_2d_tma_descriptor(
-        gemm_out.data_ptr(),
-        M,
-        N,
-        configs["BLOCK_SIZE_M"],
-        configs["BLOCK_SIZE_N"],
-        gemm_out.element_size(),
+    desc_c = triton.tools.tensor_descriptor.TensorDescriptor.from_tensor(
+        gemm_out,
+        [configs["BLOCK_SIZE_M"], configs["BLOCK_SIZE_N"]],
     )
 
     configs["NUM_SMS"] = torch.cuda.get_device_properties(
@@ -274,31 +259,20 @@ def _reduce_persistent_kernel(
         tile_id_m = tile_id // num_tiles_n
         tile_id_n = tile_id % num_tiles_n
         cur_rank = (RANK + 1) % WORLD_SIZE
-        accum = tl._experimental_descriptor_load(
-            in_desc_ptr,
-            [
-                tile_id_m * BLOCK_SIZE_M + cur_rank * M_per_rank,
-                tile_id_n * BLOCK_SIZE_N,
-            ],
-            [BLOCK_SIZE_M, BLOCK_SIZE_N],
-            tl.bfloat16,
+        accum = in_desc_ptr.load(
+            [tile_id_m * BLOCK_SIZE_M + cur_rank * M_per_rank, tile_id_n * BLOCK_SIZE_N]
         )
         for i in range(1, WORLD_SIZE):
             cur_rank = (i + RANK + 1) % WORLD_SIZE
-            data = tl._experimental_descriptor_load(
-                in_desc_ptr,
+            data = in_desc_ptr.load(
                 [
                     tile_id_m * BLOCK_SIZE_M + cur_rank * M_per_rank,
                     tile_id_n * BLOCK_SIZE_N,
-                ],
-                [BLOCK_SIZE_M, BLOCK_SIZE_N],
-                tl.bfloat16,
+                ]
             )
             accum += data
 
-        tl._experimental_descriptor_store(
-            out_desc_ptr, accum, [tile_id_m * BLOCK_SIZE_M, tile_id_n * BLOCK_SIZE_N]
-        )
+        out_desc_ptr.store([tile_id_m * BLOCK_SIZE_M, tile_id_n * BLOCK_SIZE_N], accum)
 
 
 def reduce(
@@ -312,22 +286,11 @@ def reduce(
 
     BLOCK_SIZE_M = 256
     BLOCK_SIZE_N = 64
-
-    in_desc_ptr = _create_2d_tma_descriptor(
-        inp.data_ptr(),
-        M,
-        N,
-        BLOCK_SIZE_M,
-        BLOCK_SIZE_N,
-        inp.element_size(),
+    in_desc_ptr = triton.tools.tensor_descriptor.TensorDescriptor.from_tensor(
+        inp, [BLOCK_SIZE_M, BLOCK_SIZE_N]
     )
-    out_desc_ptr = _create_2d_tma_descriptor(
-        output.data_ptr(),
-        M_per_rank,
-        N,
-        BLOCK_SIZE_M,
-        BLOCK_SIZE_N,
-        output.element_size(),
+    out_desc_ptr = triton.tools.tensor_descriptor.TensorDescriptor.from_tensor(
+        output, [BLOCK_SIZE_M, BLOCK_SIZE_N]
     )
 
     grid = lambda META: (  # noqa: E731
@@ -359,6 +322,12 @@ def gemm_reduce_scatter_ce_persistent(
     M = a.shape[0]
     N = b.shape[1]
+
+    # 1. Initialize NVSHMEM device library
+    # nvshmem_lib = nvshmem.enable_triton()
+
+
+
     group = dist.group.WORLD if group is None else group
     gemm_out = torch.empty((M, N), dtype=a.dtype, device=a.device)
     symm_mem_hdl = symm_mem.get_symm_mem_workspace(
diff --git a/kraken/reduce_scatter_fusion/gemm_reduce_scatter_fused_scatter.py b/kraken/reduce_scatter_fusion/gemm_reduce_scatter_fused_scatter.py
new file mode 100644
index 0000000..c939edf
--- /dev/null
+++ b/kraken/reduce_scatter_fusion/gemm_reduce_scatter_fused_scatter.py
@@ -0,0 +1,317 @@
+import torch
+import torch.distributed as dist
+import torch.distributed._symmetric_memory as symm_mem
+
+import triton  # @manual
+import triton.language as tl  # @manual
+
+
+def _matmul_launch_metadata(grid, kernel, args):
+    ret = {}
+    M, N, K = args["M"], args["N"], args["K"]
+    ret["name"] = f"{kernel.name} [M={M}, N={N}, K={K}]"
+    if "c_desc_ptr" in args:
+        bytes_per_elem = args["c_desc_ptr"].element_size()
+    else:
+        bytes_per_elem = 1 if args["FP8_OUTPUT"] else 2
+    ret[f"flops{bytes_per_elem * 8}"] = 2.0 * M * N * K
+    ret["bytes"] = bytes_per_elem * (M * K + N * K)
+    return ret
+
+
+@triton.jit(launch_metadata=_matmul_launch_metadata)
+def _gemm_producer_persistent_kernel(
+    a_desc_ptr,
+    b_desc_ptr,
+    symm_mem_ptrs_ptr,
+    M,
+    N,
+    K,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
+    RANK: tl.constexpr,
+    WORLD_SIZE: tl.constexpr,
+    FP8_OUTPUT: tl.constexpr,
+    NUM_SMS: tl.constexpr,
+):
+    dtype = tl.float8e4nv if FP8_OUTPUT else tl.bfloat16
+    start_pid = tl.program_id(axis=0)
+    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+    k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
+    num_tiles = num_pid_m * num_pid_n
+
+    tiles_per_SM = num_tiles // NUM_SMS
+    if start_pid < num_tiles % NUM_SMS:
+        tiles_per_SM += 1
+
+    tile_id = start_pid - NUM_SMS
+    ki = -1
+
+    pid_m = 0
+    pid_n = 0
+    offs_am = 0
+    offs_bn = 0
+
+    num_pid_in_group = GROUP_SIZE_M * num_pid_n
+    M_per_rank = M // WORLD_SIZE
+
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    pid_m_offset = (RANK + 1) * M_per_rank // BLOCK_SIZE_M
+
+    for _ in range(k_tiles * tiles_per_SM):
+        ki = tl.where(ki == k_tiles - 1, 0, ki + 1)
+        if ki == 0:
+            tile_id += NUM_SMS
+            group_id = tile_id // num_pid_in_group
+            first_pid_m = group_id * GROUP_SIZE_M
+            group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+            pid_m = first_pid_m + (tile_id % group_size_m)
+            pid_n = (tile_id % num_pid_in_group) // group_size_m
+
+            # Pivot tile_id so that M tiles are processed in communication order.
+            # This pivot preserves the prior swizzling.
+            pid_m = (pid_m + pid_m_offset) % num_pid_m
+
+            offs_am = pid_m * BLOCK_SIZE_M
+            offs_bn = pid_n * BLOCK_SIZE_N
+
+        offs_k = ki * BLOCK_SIZE_K
+
+        a = a_desc_ptr.load([offs_am, offs_k])
+        b = b_desc_ptr.load([offs_bn, offs_k])
+        accumulator = tl.dot(a, b.T, accumulator)
+
+        if ki == k_tiles - 1:
+            c = accumulator.to(dtype)
+
+            remote_rank = offs_am // M_per_rank
+            remote_buffer_ptr_int64 = tl.load(symm_mem_ptrs_ptr + remote_rank)
+            remote_buffer_ptr = remote_buffer_ptr_int64.to(tl.pointer_type(dtype))
+            block_ptr = tl.make_block_ptr(
+                base=remote_buffer_ptr,  # typed pointer to the remote rank's buffer
+                shape=(M, N),  # full matrix shape
+                strides=(N, 1),  # row-major
+                offsets=(
+                    M_per_rank * RANK - M_per_rank * remote_rank + offs_am,
+                    offs_bn,
+                ),  # tile start
+                block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N),  # tile size
+                order=(1, 0),  # row-major
+            )
+            tl.store(block_ptr, c)
+            accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+
+
+def gemm_producer_w_progress(
+    a: torch.Tensor,
+    b: torch.Tensor,
+    symm_mem_ptrs_ptr,
+    configs: dict,
+    group: dist.ProcessGroup | None = None,
+):
+    M, K = a.shape
+    Kb, N = b.shape
+    assert K == Kb, "Inner dimensions must match for matrix multiplication"
+    assert a.dtype == b.dtype, "Input dtypes must match"
+
+    bT = b.T
+
+    desc_a = triton.tools.tensor_descriptor.TensorDescriptor.from_tensor(
+        a, [configs["BLOCK_SIZE_M"], configs["BLOCK_SIZE_K"]]
+    )
+    desc_bt = triton.tools.tensor_descriptor.TensorDescriptor.from_tensor(
+        bT,
+        [configs["BLOCK_SIZE_N"], configs["BLOCK_SIZE_K"]],
+    )
+
+    configs["NUM_SMS"] = torch.cuda.get_device_properties(
+        a.device
+    ).multi_processor_count
+
+    grid = lambda META: (  # noqa: E731
+        min(
+            configs["NUM_SMS"],
+            triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
+        ),
+    )
+
+    group = dist.group.WORLD if group is None else group
+
+    _gemm_producer_persistent_kernel[grid](
+        desc_a,
+        desc_bt,
+        symm_mem_ptrs_ptr,
+        M,
+        N,
+        K,
+        BLOCK_SIZE_M=configs["BLOCK_SIZE_M"],
+        BLOCK_SIZE_N=configs["BLOCK_SIZE_N"],
+        BLOCK_SIZE_K=configs["BLOCK_SIZE_K"],
+        GROUP_SIZE_M=configs["GROUP_SIZE_M"],
+        RANK=configs["RANK"],
+        WORLD_SIZE=configs["WORLD_SIZE"],
+        FP8_OUTPUT=a.dtype == torch.float8_e4m3fn,
+        NUM_SMS=configs["NUM_SMS"],
+        num_stages=configs["num_stages"],
+        num_warps=configs["num_warps"],
+    )
+
+
+@triton.jit
+def _reduce_persistent_kernel(
+    in_desc_ptr,  # [M, N]
+    out_desc_ptr,  # [M_per_rank, N]
+    M_per_rank,
+    N,
+    RANK: tl.constexpr,
+    WORLD_SIZE: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr = 256,
+    BLOCK_SIZE_N: tl.constexpr = 64,
+):
+    pid = tl.program_id(axis=0)
+    num_pid = tl.num_programs(axis=0)
+    num_tiles_m = tl.cdiv(M_per_rank, BLOCK_SIZE_M)
+    num_tiles_n = tl.cdiv(N, BLOCK_SIZE_N)
+    total_tiles = num_tiles_m * num_tiles_n
+    for tile_id in range(pid, total_tiles, num_pid):
+        tile_id_m = tile_id // num_tiles_n
+        tile_id_n = tile_id % num_tiles_n
+        cur_rank = (RANK + 1) % WORLD_SIZE
+        accum = in_desc_ptr.load(
+            [tile_id_m * BLOCK_SIZE_M + cur_rank * M_per_rank, tile_id_n * BLOCK_SIZE_N]
+        )
+        for i in range(1, WORLD_SIZE):
+            cur_rank = (i + RANK + 1) % WORLD_SIZE
+            data = in_desc_ptr.load(
+                [
+                    tile_id_m * BLOCK_SIZE_M + cur_rank * M_per_rank,
+                    tile_id_n * BLOCK_SIZE_N,
+                ]
+            )
+            accum += data
+
+        out_desc_ptr.store([tile_id_m * BLOCK_SIZE_M, tile_id_n * BLOCK_SIZE_N], accum)
+
+
+def reduce(
+    inp: torch.Tensor,  # scatter_out with shape [M, N]
+    output: torch.Tensor,  # [M_per_rank, N]
+    configs: dict,
+):
+    M, N = inp.shape
+    M_per_rank = M // configs["WORLD_SIZE"]
+    assert output.shape[0] == M_per_rank and M % configs["WORLD_SIZE"] == 0
+
+    BLOCK_SIZE_M = 256
+    BLOCK_SIZE_N = 64
+    in_desc_ptr = triton.tools.tensor_descriptor.TensorDescriptor.from_tensor(
+        inp, [BLOCK_SIZE_M, BLOCK_SIZE_N]
+    )
+    out_desc_ptr = triton.tools.tensor_descriptor.TensorDescriptor.from_tensor(
+        output, [BLOCK_SIZE_M, BLOCK_SIZE_N]
+    )
+
+    grid = lambda META: (  # noqa: E731
+        triton.cdiv(M_per_rank, META["BLOCK_SIZE_M"])
+        * triton.cdiv(N, META["BLOCK_SIZE_N"]),
+    )
+    _reduce_persistent_kernel[grid](
+        in_desc_ptr,
+        out_desc_ptr,
+        M_per_rank,
+        N,
+        RANK=configs["RANK"],
+        WORLD_SIZE=configs["WORLD_SIZE"],
+        BLOCK_SIZE_M=BLOCK_SIZE_M,
+        BLOCK_SIZE_N=BLOCK_SIZE_N,
+        num_warps=4,
+    )
+
+    return output
+
+
+def triton_fused_matmul_reduce_scatter(
+    a: torch.Tensor,
+    b: torch.Tensor,
+    group: dist.ProcessGroup | None = None,
+    **kwargs,
+) -> torch.Tensor:
+    """
+    Fused GEMM + reduce-scatter with the scatter overlapped with the GEMM.
+
+    Computes C = A @ B on each rank, then performs a reduce-scatter to sum the
+    results and scatter them along the M dimension.
+
+    Args:
+        a: Input matrix A of shape (M, K)
+        b: Input matrix B of shape (K, N)
+
+    Returns:
+        Output matrix of shape (M / world_size, N) containing the reduce-scattered result.
+    """
+    assert (
+        a.shape[1] == b.shape[0]
+    ), "Inner dimensions must match for matrix multiplication"
+
+    M, N = a.shape[0], b.shape[1]
+
+    # Use the global process group if no specific group is provided
+    group = dist.group.WORLD if group is None else group
+    # Total number of processes/GPUs in the distributed group
+    world_size = dist.get_world_size(group)
+    # This process's rank within the group (0 to world_size - 1)
+    rank = dist.get_rank(group)
+
+    assert (
+        M % world_size == 0
+    ), f"M dimension ({M}) must be divisible by world_size ({world_size})"
+
+    # Create a symmetric memory buffer for the GEMM output
+    symm_mem_hdl = symm_mem.get_symm_mem_workspace(
+        group.group_name, min_size=M * N * a.element_size()
+    )
+
+    # Create output tensor for the scatter result
+    M_per_rank = M // world_size
+    output = torch.empty((M_per_rank, N), dtype=a.dtype, device=a.device)
+
+    # Configuration for GEMM heuristics etc.
+    configs = {
+        "BLOCK_SIZE_M": kwargs.get("block_size_m", 64),
+        "BLOCK_SIZE_N": kwargs.get("block_size_n", 256),
+        "BLOCK_SIZE_K": kwargs.get("block_size_k", 64),
+        "GROUP_SIZE_M": kwargs.get("group_size_m", 8),
+        "num_stages": kwargs.get("num_stages", 3),
+        "num_warps": kwargs.get("num_warps", 8),
+    }
+    configs["RANK"] = rank
+    configs["WORLD_SIZE"] = world_size
+
+    assert (
+        M_per_rank % configs["BLOCK_SIZE_M"] == 0
+    ), f"M_per_rank ({M_per_rank}) must be divisible by BLOCK_SIZE_M ({configs['BLOCK_SIZE_M']})"
+
+    # Collect raw data pointers to each rank's symmetric memory buffer
+    symm_mem_ptrs = []
+    scatter_out = None
+    for peer in range(world_size):
+        if peer == symm_mem_hdl.rank:
+            scatter_out = symm_mem_hdl.get_buffer(peer, [M, N], a.dtype, 0)
+            symm_mem_ptrs.append(scatter_out.data_ptr())
+            continue
+        symm_mem_ptrs.append(
+            symm_mem_hdl.get_buffer(peer, [M, N], a.dtype, 0).data_ptr()
+        )
+    symm_mem_ptr = torch.tensor(symm_mem_ptrs, dtype=torch.int64, device=a.device)
+
+    gemm_producer_w_progress(a, b, symm_mem_ptr, configs)
+    symm_mem_hdl.barrier()
+
+    # Communication is fused into the GEMM kernel, so no separate copy-engine pass is needed
+
+    reduce(scatter_out, output, configs)
+
+    return output
diff --git a/test/test_gemm_reduce_scatter.py b/test/test_gemm_reduce_scatter.py
index 55ab6ff..7e41bae 100644
--- a/test/test_gemm_reduce_scatter.py
+++ b/test/test_gemm_reduce_scatter.py
@@ -19,6 +19,7 @@
 from kraken.reduce_scatter_fusion import (
     gemm_reduce_scatter,
     gemm_reduce_scatter_ce_persistent,
+    triton_fused_matmul_reduce_scatter,
 )
 
 
@@ -112,6 +113,26 @@ def test_gemm_reduce_scatter_ce_persistent(self):
         torch.testing.assert_close(result, expected, atol=1e-2, rtol=1e-2)
         dist.destroy_process_group()
 
+    @skip_if_lt_x_gpu(4)
+    def test_gemm_reduce_scatter_fused_scatter(self):
+        self._init_process()
+        M, N, K = 8192, 4096, 14336
+        a = torch.randn((M, K), dtype=torch.bfloat16, device=self.device)
+        b = torch.randn((N, K), dtype=torch.bfloat16, device=self.device).t()
+
+        result = triton_fused_matmul_reduce_scatter(a, b)
+
+        gemm_out = torch.matmul(a, b)
+        expected = torch.empty(
+            (M // self.world_size, N), device="cuda", dtype=torch.bfloat16
+        )
+        torch.distributed.reduce_scatter_tensor(
+            expected, gemm_out, group=dist.group.WORLD
+        )
+
+        torch.testing.assert_close(result, expected, atol=1e-2, rtol=1e-2)
+        dist.destroy_process_group()
+
 
 if __name__ == "__main__":
     run_tests()
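# --- Usage sketch (illustrative, not part of the patch) ---------------------
# A minimal end-to-end example of how the new triton_fused_matmul_reduce_scatter
# entry point might be called. The script name, launch flags, and shapes are
# assumptions taken from the updated help_str and the new unit test; any extra
# symmetric-memory setup a given environment needs is not shown.
# Launch with, e.g.:
#   torchrun --nnodes 1 --nproc-per-node 8 example_fused_matmul_rs.py
import os

import torch
import torch.distributed as dist

from kraken.reduce_scatter_fusion import triton_fused_matmul_reduce_scatter


def main() -> None:
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)
    dist.init_process_group("nccl")
    world_size = dist.get_world_size()

    # Shapes mirror the new test: M must divide evenly by world_size, and
    # M // world_size by BLOCK_SIZE_M (64 by default). B is allocated K-major
    # and transposed so that b.T is contiguous, matching the test setup.
    M, N, K = 8192, 4096, 14336
    a = torch.randn((M, K), dtype=torch.bfloat16, device=device)
    b = torch.randn((N, K), dtype=torch.bfloat16, device=device).t()

    # Each rank receives its (M // world_size, N) shard of the summed product.
    out = triton_fused_matmul_reduce_scatter(a, b)

    # Unfused reference: local matmul followed by reduce_scatter_tensor.
    expected = torch.empty((M // world_size, N), dtype=torch.bfloat16, device=device)
    dist.reduce_scatter_tensor(expected, a @ b)
    torch.testing.assert_close(out, expected, atol=1e-2, rtol=1e-2)

    dist.destroy_process_group()


if __name__ == "__main__":
    main()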