diff --git a/benchmarks/distributed/__init__.py b/benchmarks/distributed/__init__.py
index 6a9315fb..46e5136f 100644
--- a/benchmarks/distributed/__init__.py
+++ b/benchmarks/distributed/__init__.py
@@ -1,3 +1,4 @@
 from __future__ import annotations
 
+from .all_gather_matmul import AGMatmulBench as AGMatmulBench
 from .all_reduce import AllReduceBench as AllReduceBench
diff --git a/benchmarks/distributed/all_gather_matmul.py b/benchmarks/distributed/all_gather_matmul.py
new file mode 100644
index 00000000..ca418b2e
--- /dev/null
+++ b/benchmarks/distributed/all_gather_matmul.py
@@ -0,0 +1,129 @@
+from __future__ import annotations
+
+import argparse
+
+import torch
+import torch.distributed as dist
+import torch.distributed._symmetric_memory as symm_mem
+
+from .experiment_util import BenchmarkOperator
+from .experiment_util import ExperimentConfig
+
+BUILDIN_SHAPES = [
+    (256, 256, 256),
+    (384, 384, 384),
+    (512, 512, 512),
+    (640, 640, 640),
+    (768, 768, 768),
+    (896, 896, 896),
+    (1024, 1024, 1024),
+    (1152, 1152, 1152),
+    (1280, 1280, 1280),
+    (1408, 1408, 1408),
+    (1536, 1536, 1536),
+    (1664, 1664, 1664),
+    (1792, 1792, 1792),
+    (1920, 1920, 1920),
+    (2048, 2048, 2048),
+    (2176, 2176, 2176),
+    (2304, 2304, 2304),
+    (2432, 2432, 2432),
+    (2560, 2560, 2560),
+    (2688, 2688, 2688),
+    (2816, 2816, 2816),
+    (2944, 2944, 2944),
+    (3072, 3072, 3072),
+    (3200, 3200, 3200),
+    (3328, 3328, 3328),
+    (3456, 3456, 3456),
+    (3584, 3584, 3584),
+    (3712, 3712, 3712),
+    (3840, 3840, 3840),
+    (3968, 3968, 3968),
+    (4096, 4096, 4096),
+]
+
+
+class AGMatmulBench(BenchmarkOperator):
+    def gen_configs(self, args: argparse.Namespace) -> list[ExperimentConfig]:
+        all_configs = []
+        for sz in args.shape:
+            all_configs.append(
+                ExperimentConfig(
+                    shape=sz,
+                    dtype=args.dtype,
+                    backends=args.backend,
+                    device=self.device,
+                )
+            )
+
+        return all_configs
+
+    def gen_inputs(self, config: ExperimentConfig) -> tuple:
+        M, N, K = config.shape
+        a = symm_mem.empty(
+            (M, K),
+            dtype=config.dtype,
+            device=config.device,
+        )
+        b = (
+            torch.randn((K, N), device=config.device, dtype=config.dtype)
+            .T.contiguous()
+            .T
+        )
+        assert dist.group.WORLD is not None
+        symm_mem.rendezvous(a, dist.group.WORLD.group_name)
+        return (a, b)
+
+    def additional_parser_args(
+        self, parser: argparse.ArgumentParser
+    ) -> argparse.ArgumentParser:
+        def matmul_shape_type(s: str) -> tuple[int, int, int]:
+            try:
+                M, N, K = map(int, s.split(","))
+                return M, N, K
+            except Exception as e:
+                raise argparse.ArgumentTypeError(
+                    "Matmul shape must be M, N, K. (M, K) @ (K, N) -> (M, N)"
+                ) from e
+
+        parser.add_argument(
+            "--shape",
+            type=matmul_shape_type,
+            nargs="+",
+            default=BUILDIN_SHAPES,
+            help="matmul shapes: (M, N, K). (M, K) @ (K, N) -> (M, N)",
+        )
+        return parser
+
+    def __init__(self) -> None:
+        self.op_name = "ag_matmul"
+        self.baseline = "nccl"
+        super().__init__()
+
+        def nccl_mem_ag_mm(
+            a_shared: torch.Tensor, b: torch.Tensor
+        ) -> tuple[torch.Tensor, torch.Tensor]:
+            from torch.distributed._functional_collectives import all_gather_tensor
+
+            a_gathered = all_gather_tensor(a_shared, 0, "0")
+            return a_gathered, torch.matmul(a_gathered, b)
+
+        def torch_symm_mem_ag_mm(
+            a_shared: torch.Tensor, b: torch.Tensor
+        ) -> tuple[torch.Tensor, torch.Tensor]:
+            assert dist.group.WORLD is not None
+            a_gathered, c = torch.ops.symm_mem.fused_all_gather_matmul(
+                a_shared, [b], gather_dim=0, group_name=dist.group.WORLD.group_name
+            )
+            return a_gathered, c[0]
+
+        assert dist.group.WORLD is not None
+
+        AG_MATMUL_DICT = {
+            "nccl": nccl_mem_ag_mm,
+            "torch_symm_mem": torch_symm_mem_ag_mm,
+            "helion": ("examples.all_gather_matmul", "helion_all_gather_matmul"),
+            "kraken": ("kraken.all_gather", "all_gather_matmul"),
+        }
+        self.backend_dict = AG_MATMUL_DICT
diff --git a/benchmarks/distributed/experiment_util.py b/benchmarks/distributed/experiment_util.py
index d362fec4..8ef8af03 100644
--- a/benchmarks/distributed/experiment_util.py
+++ b/benchmarks/distributed/experiment_util.py
@@ -37,7 +37,12 @@ def clone_symm_mem_tensor(tensor: torch.Tensor) -> torch.Tensor:
         device=tensor.device,
     )
     assert dist.group.WORLD is not None
-    symm_mem.rendezvous(symm_mem_tensor, dist.group.WORLD.group_name)
+    try:
+        symm_mem.rendezvous(symm_mem_tensor, dist.group.WORLD.group_name)
+    except RuntimeError as e:
+        raise RuntimeError(
+            f"Failed to rendezvous symmetric memory tensor of shape {tensor.shape}."
+        ) from e
     symm_mem_tensor.copy_(tensor)
     return symm_mem_tensor
 
@@ -68,7 +73,7 @@ class ExperimentConfig:
         device: Target device for the experiment, defaults to None (auto-detected)
     """
 
-    shape: tuple[int]
+    shape: tuple[int, ...]
    dtype: torch.dtype
     backends: list[str]
     device: torch.device | None = None
@@ -145,7 +150,7 @@ class BenchmarkOperator:
 --nnodes 1 --nproc-per-node 8 \
 --rdzv-backend c10d --rdzv-endpoint localhost:0 \
 --no_python python3 \
-benchmarks/run_distributed.py
+benchmarks/run_distributed.py <op>
 """
 
     experiments: list[Experiment]
@@ -207,6 +212,12 @@ def _parse_args(self) -> argparse.Namespace:
             description=f"Run benchmark for {self.__name__}. " + self.help_str
         )
 
+        parser.add_argument(
+            "op",
+            type=str,
+            help="Operator to benchmark.",
", + ) + parser.add_argument( "--backend", type=str, @@ -229,6 +240,8 @@ def _parse_args(self) -> argparse.Namespace: self.args = parser.parse_args() self.args.dtype = getattr(torch, self.args.dtype) + assert self.args.op == self.op_name + return self.args def __init__(self) -> None: @@ -244,7 +257,6 @@ def __init__(self) -> None: self.device = torch.device(f"cuda:{self.local_rank}") torch.cuda.set_device(self.device) - dist.init_process_group("nccl") torch.manual_seed(42 + self.local_rank) self.experiments = [] @@ -405,35 +417,43 @@ def get_results(self, metric: str = "speedup") -> defaultdict | None: def _run_experiment(self, config: ExperimentConfig) -> dict[str, float]: if self.baseline not in config.backends: - backends = config.backends.append(self.baseline) + backends = [*config.backends, self.baseline] else: backends = config.backends gloden_inp = self.gen_inputs(config) - inputs = {backend: clone_inputs(gloden_inp) for backend in backends} # pyright: ignore[reportOptionalIterable] gloden_fn = self.fn_dict[self.baseline] assert gloden_fn is not None + inp_og = clone_inputs(gloden_inp) gloden_o = gloden_fn(*gloden_inp) results = {} - for backend in backends: # pyright: ignore[reportOptionalIterable] + for backend in backends: fn = self.fn_dict[backend] if fn is None: results[backend] = float("nan") continue - inp = inputs[backend] + inp = clone_inputs(inp_og) target_fn = functools.partial(fn, *inp) try: test_o = target_fn() except RuntimeError: results[backend] = float("nan") continue + except AssertionError: + results[backend] = float("nan") + continue torch.testing.assert_close(test_o, gloden_o, atol=1e-1, rtol=1e-1) results[backend] = benchmark_distributed( target_fn, profile_ranks=[self.MASTER_RANK] ) + del test_o + del inp + + del gloden_inp + del gloden_o return results diff --git a/benchmarks/run_distributed.py b/benchmarks/run_distributed.py index 95d0b383..e5df9513 100644 --- a/benchmarks/run_distributed.py +++ b/benchmarks/run_distributed.py @@ -1,13 +1,47 @@ from __future__ import annotations -from benchmarks.distributed import AllReduceBench as AllReduceBenchmark +import sys + +from benchmarks.distributed import AGMatmulBench as AGMatmulBench +from benchmarks.distributed import AllReduceBench as AllReduceBench import torch.distributed as dist +OP_BENCH = { + "allreduce": AllReduceBench, + "ag_matmul": AGMatmulBench, +} + def main() -> None: - bench = AllReduceBenchmark() - bench.run() - bench.print_results(metric="time_us") + try: + dist.init_process_group("nccl") + except ValueError: + print(""" +Failed to initialize process group. Are you running with torchrun? 
+run distributed benchmark with:
+torchrun \
+--nnodes 1 --nproc-per-node 8 \
+--rdzv-backend c10d --rdzv-endpoint localhost:0 \
+--no_python python3 \
+benchmarks/run_distributed.py
+""")
+        sys.exit(1)
+
+    if len(sys.argv) < 2:
+        print("Usage: python3 benchmarks/run_distributed.py <op>")
+        print(f"Available ops: {OP_BENCH.keys()}")
+        sys.exit(1)
+
+    op = sys.argv[1]
+
+    if op not in OP_BENCH:
+        print(f"Unknown op: {op}")
+        print(f"Valid ops: {OP_BENCH.keys()}")
+        sys.exit(1)
+
+    op_bench = OP_BENCH[op]()
+    op_bench.run()
+    op_bench.print_results(metric="time_us")
 
     dist.destroy_process_group()
diff --git a/examples/all_gather_matmul.py b/examples/all_gather_matmul.py
index 70661ea0..202763e3 100644
--- a/examples/all_gather_matmul.py
+++ b/examples/all_gather_matmul.py
@@ -59,7 +59,7 @@ def copy_engine_all_gather_w_progress(
     backend_stream.wait_stream(torch.cuda.current_stream())
     with torch.cuda.stream(backend_stream):
         for step in range(world_size):
-            src_rank = (rank + step + 1) % world_size
+            src_rank = (rank + step) % world_size
             for split_id in range(splits_per_rank):
                 src_buf = symm_mem_hdl.get_buffer(
                     src_rank, chunks[0].shape, inp.dtype, chunks[0].numel() * split_id
@@ -81,7 +81,9 @@
         block_sizes=[128, 256, 64],
         num_warps=8,
         num_stages=3,
-        indexing="block_ptr",
+        indexing="tensor_descriptor",
+        pid_type="persistent_interleaved",
+        l2_groupings=[4],
     ),
     static_shapes=True,
 )
@@ -90,7 +92,7 @@ def helion_matmul_w_progress(
     a_shared: torch.Tensor,
     b: torch.Tensor,
     progress: torch.Tensor,
-    SPLITS_PER_RANK: int,
+    SPLITS_PER_RANK: hl.constexpr,
     RANK: int,
 ) -> torch.Tensor:
     """
@@ -114,16 +116,19 @@ def helion_matmul_w_progress(
     M_per_rank = a_shared.size(0)
 
     for tile_m, tile_n in hl.tile([M, N]):
         acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
-        hl.wait(
-            progress,
-            [
-                tile_m.begin // (M_per_rank // SPLITS_PER_RANK),
-            ],
-            signal=1,
-        )
+        # TODO(joydddd): natively support starting range from non_zero index.
+        comm_block_id = ((tile_m.begin + RANK * M_per_rank) % M) // (
+            M_per_rank // SPLITS_PER_RANK
+        )  # pyright: ignore[reportOperatorIssue]
+        hl.wait(progress, [comm_block_id], signal=1)
         for tile_k in hl.tile(K):
-            acc = torch.addmm(acc, a[tile_m, tile_k], b[tile_k, tile_n])
-        out[tile_m, tile_n] = acc
+            # TODO(joydddd): use a_shared and skip barrier when data is available on local rank.
+            acc = torch.addmm(
+                acc,
+                a[(tile_m.index + RANK * M_per_rank) % M, tile_k],
+                b[tile_k, tile_n],
+            )
+        out[(tile_m.index + RANK * M_per_rank) % M, tile_n] = acc
     return out
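Note on the indexing change in helion_matmul_w_progress: the short standalone sketch below (illustrative only, not part of the patch; WORLD_SIZE, RANK, M_PER_RANK, SPLITS_PER_RANK are made-up values, and it assumes one row tile per communication block) shows why the kernel offsets its row indices by RANK * M_per_rank. With the copy engine now visiting src_rank = (rank + step) % world_size, each rank's own shard is signalled first, so walking the output tiles in rank-shifted order consumes the gathered blocks in exactly the order they arrive.

# Illustrative check only; all sizes here are made up for the example.
WORLD_SIZE = 4
RANK = 2
M_PER_RANK = 8
SPLITS_PER_RANK = 2
M = WORLD_SIZE * M_PER_RANK
BLOCK_ROWS = M_PER_RANK // SPLITS_PER_RANK

# Order in which progress slots are signalled by copy_engine_all_gather_w_progress,
# now that it starts from the local rank: src_rank = (rank + step) % world_size.
arrival_order = [
    ((RANK + step) % WORLD_SIZE) * SPLITS_PER_RANK + split
    for step in range(WORLD_SIZE)
    for split in range(SPLITS_PER_RANK)
]

# Communication block each rank-shifted row tile waits on in helion_matmul_w_progress,
# assuming one tile per communication block of BLOCK_ROWS rows.
wait_order = [
    ((tile_begin + RANK * M_PER_RANK) % M) // BLOCK_ROWS
    for tile_begin in range(0, M, BLOCK_ROWS)
]

# Tiles are consumed in exactly the order the gathered blocks arrive.
assert wait_order == arrival_order
print(wait_order)  # [4, 5, 6, 7, 0, 1, 2, 3] for RANK=2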