#!/usr/bin/env python3
# SPDX-License-Identifier: MIT
# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import triton
import random
import argparse
import math

from examples.common.utils import JSONWriter, Timestamps, is_triton_interpret_set
from examples.common.validation import validate_reduce_scatter

import iris
from matmul_wrapper import MatMulReduceScatterWgSpecialized

torch.manual_seed(0)
random.seed(0)


def parse_args():
    parser = argparse.ArgumentParser(
        description="GEMM + ReduceScatter Benchmark with Workgroup Specialization",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("-m", type=int, default=8192, help="Number of rows in matrix A (M)")
    parser.add_argument("-n", type=int, default=4096, help="Number of columns in matrix B (N)")
    parser.add_argument("-k", type=int, default=12288, help="Common dimension (K), will be split across ranks")
    parser.add_argument("-d", "--debug", action="store_true", help="Enable debug mode")
    parser.add_argument("-v", "--validate", action="store_true", help="Enable validation mode")
    parser.add_argument("-t", "--trace_tiles", action="store_true", help="Enable tile-tracing mode")
    parser.add_argument("-b", "--benchmark", action="store_true", help="Enable benchmarking mode")
    parser.add_argument(
        "--datatype",
        type=str,
        default="fp16",
        choices=["fp16", "fp32", "bf16"],
        help="Datatype of computation",
    )
    parser.add_argument(
        "--output_file",
        type=str,
        default="log.json",
        help="Output file",
    )
    parser.add_argument("--BLK_M", type=int, default=128, help="Block size M")
    parser.add_argument("--BLK_N", type=int, default=256, help="Block size N")
    parser.add_argument("--BLK_K", type=int, default=32, help="Block size K")
    parser.add_argument("--gsize_m", type=int, default=1, help="L2-cache locality swizzle parameter")
    parser.add_argument("--heap_size", type=int, default=1 << 33, help="Iris heap size in bytes")
    parser.add_argument(
        "--num_sms",
        type=int,
        default=None,
        help="Number of total SMs (default: auto-detected)",
    )
    parser.add_argument(
        "--gemm_sms",
        type=int,
        default=None,
        help="Number of SMs for GEMM (default: auto-detected as power of 2)",
    )
    parser.add_argument("--num_stages", type=int, default=2, help="Number of stages")
    parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes")

    return vars(parser.parse_args())


def _worker(local_rank: int, world_size: int, init_url: str, args: dict):
    """Worker function for PyTorch distributed execution."""
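    # The "nccl" backend resolves to RCCL on ROCm builds of PyTorch; "gloo" is
    # a CPU fallback so initialization still works without GPUs.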
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(
        backend=backend,
        init_method=init_url,
        world_size=world_size,
        rank=local_rank,
        device_id=torch.device(f"cuda:{local_rank}"),
    )

    shmem = iris.iris(args["heap_size"])
    rank = shmem.get_rank()
    world_size = shmem.get_num_ranks()

    cu_count = torch.cuda.get_device_properties(rank).multi_processor_count
    if args["num_sms"] is None:
        args["num_sms"] = cu_count
    if args["gemm_sms"] is None:
        # Largest power of two not exceeding the CU count for GEMM SMs
        args["gemm_sms"] = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1
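    # Example (illustrative; CU counts vary by device): with 304 CUs this
    # yields gemm_sms = 256, leaving 48 CUs for the communication workgroups.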

    dtype_map = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16}
    datatype = dtype_map[args["datatype"]]  # argparse choices guarantee a valid key

    M, N, K = args["m"], args["n"], args["k"]

    assert M % world_size == 0, f"M ({M}) must be divisible by world size ({world_size})"
    assert K % world_size == 0, f"K ({K}) must be divisible by world size ({world_size})"
    assert (M // world_size) % args["BLK_M"] == 0, (
        f"M_per_rank ({M // world_size}) must be divisible by BLK_M ({args['BLK_M']})"
    )

    local_K = K // world_size
    M_per_rank = M // world_size

    A_full = shmem.randn(M, K, device="cuda", dtype=datatype)
    B_full = shmem.randn(K, N, device="cuda", dtype=datatype)

    # Each rank takes its own slice of the K dimension as input. Since every
    # rank sets the same seeds at module scope, A_full and B_full should match
    # across ranks, so local slicing yields consistent shards.
    local_A = A_full[:, rank * local_K : (rank + 1) * local_K].clone()
    local_B = B_full[rank * local_K : (rank + 1) * local_K, :].clone()
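    # K-split decomposition: C = sum over r of
    #   A[:, r*local_K:(r+1)*local_K] @ B[r*local_K:(r+1)*local_K, :],
    # so each rank computes one partial (M, N) product and the reduce-scatter
    # sums the partials, leaving each rank only its (M_per_rank, N) row block.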

    json_writer = JSONWriter(args["output_file"])
    json_writer.add_field("world_size", world_size)
    json_writer.add_field("M", M)
    json_writer.add_field("N", N)
    json_writer.add_field("K", K)
    json_writer.add_field("local_K", local_K)

    for key, value in args.items():
        json_writer.add_field(key, value)

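    # local_buf holds this rank's full (M, N) partial product; output_buf
    # receives the (M_per_rank, N) shard this rank owns after the
    # reduce-scatter.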
    local_buf = shmem.zeros((M, N), device="cuda", dtype=datatype)

    output_buf = shmem.zeros((M_per_rank, N), device="cuda", dtype=datatype)

    total_blocks_M = triton.cdiv(M, args["BLK_M"])
    total_blocks_N = triton.cdiv(N, args["BLK_N"])
    total_tiles = total_blocks_M * total_blocks_N

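    # One int32 lock per output tile, used by the kernel to signal per-tile
    # completion from the GEMM workgroups to the communication workgroups
    # (producer/consumer handoff within a single launch).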
    locks = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32)

    gemm_stream = torch.cuda.Stream()

    json_writer.add_field("num_sms", args["num_sms"])
    json_writer.add_field("gemm_sms", args["gemm_sms"])

    kernel_timing = {
        "gemm_rs": {
            "start_event": torch.cuda.Event(enable_timing=True),
            "end_event": torch.cuda.Event(enable_timing=True),
            "ms": 0,
            "experiments": 0,
        },
    }
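    # CUDA events bracket the kernel so elapsed_time() reports on-device time,
    # independent of host-side launch overhead.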

    timestamps = Timestamps(num_tiles=total_tiles)

    def run_experiment():
        nonlocal local_buf, output_buf

        local_buf.zero_()
        output_buf.zero_()
        locks.zero_()
        shmem.barrier()

        if args["trace_tiles"]:
            timestamps.reset()
            shmem.barrier()

        torch.cuda.nvtx.range_push("GEMM + ReduceScatter")
        with torch.cuda.stream(gemm_stream):
            kernel_timing["gemm_rs"]["start_event"].record()
            MatMulReduceScatterWgSpecialized.apply(
                local_A,
                local_B,
                local_buf,
                output_buf,
                locks,
                rank,
                world_size,
                args["gemm_sms"],
                args["num_sms"],
                args["BLK_M"],
                args["BLK_N"],
                args["BLK_K"],
                args["gsize_m"],
                args["num_stages"],
                shmem.get_heap_bases(),
                torch.cuda.get_device_properties(rank).name,
                args["trace_tiles"],
                timestamps.mm_begin_timestamp,
                timestamps.mm_end_timestamp,
            )
            kernel_timing["gemm_rs"]["end_event"].record()
            kernel_timing["gemm_rs"]["experiments"] += 1

        torch.cuda.nvtx.range_pop()
        shmem.barrier()

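        # Assuming shmem.barrier() synchronizes the device across ranks, both
        # events have completed here and elapsed_time() is safe to read.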
        for k in ["gemm_rs"]:
            ms = kernel_timing[k]["start_event"].elapsed_time(kernel_timing[k]["end_event"])
            kernel_timing[k]["ms"] += ms

        shmem.barrier()

    # Warmup: the first call triggers Triton JIT compilation and fills caches.
    run_experiment()

    shmem.barrier()

    # Discard the warmup timings before measuring.
    for k in ["gemm_rs"]:
        kernel_timing[k]["ms"] = 0
        kernel_timing[k]["experiments"] = 0

    if args["validate"]:
        shmem.info("Validating...")
        MatMulReduceScatterWgSpecialized.set_debug(True)

        local_gemm = local_buf.clone()
        local_output = output_buf.clone()

        # Allow larger tolerance for fp16 due to accumulated rounding errors in atomic operations
        atol = 1.0 if datatype == torch.float16 else 0.5

        tp_group = dist.new_group(ranks=list(range(world_size)))
        success = validate_reduce_scatter(local_gemm, local_output, shmem, tp_group, atol=atol)

        if success:
            shmem.info("✅ Triton and Torch match")
        else:
            shmem.info("❌ Triton and Torch differ")

        json_writer.add_field("success", success)

        if not is_triton_interpret_set():
            gemm_registers = MatMulReduceScatterWgSpecialized.get_matmul_registers()
            gemm_spills = MatMulReduceScatterWgSpecialized.get_matmul_spills()
            json_writer.add_field("gemm_registers", gemm_registers)
            json_writer.add_field("gemm_spills", gemm_spills)

        shmem.barrier()
        shmem.info("Validation completed")

    if args["benchmark"]:
        MatMulReduceScatterWgSpecialized.set_debug(False)
        shmem.info("Benchmarking...")

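        # A GEMM of shape M x N x K performs 2 * M * N * K FLOPs (one multiply
        # and one add per multiply-accumulate); ms * 1e-3 converts milliseconds
        # to seconds and 1e-12 scales FLOP/s to TFLOP/s.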
        perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)

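        # iris.do_bench is assumed to run the experiment repeatedly, using the
        # barrier to keep ranks in lockstep, and to return a representative
        # time in milliseconds.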
        triton_ms = iris.do_bench(run_experiment, shmem.barrier)
        triton_tflops = perf(triton_ms)

        shmem.info(f"GEMM + ReduceScatter (total_tiles={total_tiles}): {triton_ms:.3f} ms {triton_tflops:.3f} tflops")

        json_writer.add_field("tflops", triton_tflops)
        json_writer.add_field("total_ms", triton_ms)

        for k in ["gemm_rs"]:
            json_writer.add_field(k + "_ms", kernel_timing[k]["ms"] / kernel_timing[k]["experiments"])
            json_writer.add_field(k + "_experiments", kernel_timing[k]["experiments"])

    shmem.barrier()

    if rank == 0:
        json_writer.flush()
        json_writer.display()

    if args["trace_tiles"] and rank == 0:
        gpu_freq = iris.hip.get_wall_clock_rate(rank) * 1e-3
        filename = f"gemm_tiles_reduce_scatter_trace_rank{rank}.json"
        timestamps.to_json(filename, gpu_freq)

    shmem.barrier()
    dist.destroy_process_group()


def main():
    args = parse_args()
    num_ranks = args["num_ranks"]

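    # Single-node rendezvous over localhost; 29500 is PyTorch's conventional
    # default master port.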
    init_url = "tcp://127.0.0.1:29500"
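    # mp.spawn passes the process index as the first argument to _worker,
    # followed by the entries of `args=`.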
    mp.spawn(
        fn=_worker,
        args=(num_ranks, init_url, args),
        nprocs=num_ranks,
        join=True,
    )


if __name__ == "__main__":
    main()