-
Notifications
You must be signed in to change notification settings - Fork 33
An example where gemm and all-scatter are independent #232
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
d8d2377
Add an example for separate gemm/all-scatter
neoblizz f18ed99
Apply Ruff auto-fixes
github-actions[bot] c18aded
Add validation for separate GEMM and all-scatter operations in exampl…
Copilot 1664585
[WIP] Add validation for example 20 (#236)
Copilot 5aee08f
Delete unused.
neoblizz File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,347 @@ | ||
| #!/usr/bin/env python3 | ||
| # SPDX-License-Identifier: MIT | ||
| # Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. | ||
|
|
||
| import torch | ||
| import torch.distributed as dist | ||
| import torch.multiprocessing as mp | ||
| import triton | ||
| import random | ||
| import sys | ||
| import os | ||
| import argparse | ||
| import json | ||
|
|
||
| from examples.common.utils import JSONWriter, Timestamps, is_triton_interpret_set | ||
| from examples.common.validation import validate_gemm, validate_all_scatter | ||
|
|
||
| import iris | ||
|
|
||
| from matmul_wrapper import matmul | ||
| from gemm_all_scatter_bulk_synchronous import persistent_all_scatter | ||
|
|
||
# Fix both PyTorch and Python RNG seeds so the randomly initialized
# matrices (A, B) are reproducible across runs and ranks.
torch.manual_seed(123)
random.seed(123)
|
|
||
|
|
||
def parse_args():
    """Build the CLI parser for the GEMM/all-scatter example.

    Returns:
        dict: Parsed options keyed by argument name (via ``vars``), e.g.
        ``m``/``n``/``k`` GEMM dimensions, block sizes, SM partitioning,
        and mode flags (``debug``, ``validate``, ``trace_tiles``,
        ``benchmark``).
    """
    parser = argparse.ArgumentParser(
        description="Parse matrix dimensions and configuration.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # Bind the method once; every option below registers through it.
    add = parser.add_argument

    # GEMM problem dimensions.
    add("-m", type=int, default=8192, help="Number of rows in matrix A (GEMM)")
    add("-n", type=int, default=4608, help="Number of columns in matrix B (GEMM)")
    add("-k", type=int, default=36864, help="Common dimension between matrices A and B (GEMM)")

    # Communication-tensor dimensions (fall back to m/n when omitted).
    add("--m_comm", type=int, default=None, help="Number of rows for communication tensor (defaults to m)")
    add("--n_comm", type=int, default=None, help="Total number of columns for communication tensor (defaults to n)")

    # Mode flags.
    add("-d", "--debug", action="store_true", help="Enable debug mode")
    add("-v", "--validate", action="store_true", help="Enable validation mode")
    add("-t", "--trace_tiles", action="store_true", help="Enable tile-tracing mode")
    add("-b", "--benchmark", action="store_true", help="Enable benchmarking mode")

    add(
        "--datatype",
        type=str,
        default="fp16",
        choices=["fp16", "fp32", "int8", "bf16"],
        help="Datatype of computation",
    )
    add(
        "--output_file",
        type=str,
        default="log.json",
        help="Output file",
    )

    # Kernel tiling / scheduling knobs.
    add("--BLK_M", type=int, default=256, help="Block size M")
    add("--BLK_N", type=int, default=64, help="Block size N")
    add("--BLK_K", type=int, default=64, help="Block size K")
    add("--gsize_m", type=int, default=6, help="L2-cache locality swizzle parameter")
    add("--heap_size", type=int, default=1 << 33, help="Iris heap size")
    add("--gemm_sms", type=int, default=256, help="Number of SMs for workgroup-specialized GEMM algorithm")
    add("--comm_sms", type=int, default=48, help="Number of SMs for All-Scatter kernel")
    add("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes")

    return vars(parser.parse_args())
|
|
||
|
|
||
def _worker(local_rank: int, world_size: int, init_url: str, args: dict):
    """Worker function for PyTorch distributed execution.

    Runs on one spawned process per rank: initializes the process group and
    the Iris shared heap, allocates GEMM and communication tensors, then runs
    the GEMM kernel and the all-scatter kernel on two independent CUDA
    streams. Optionally validates results, benchmarks, and dumps JSON/trace
    output.

    Args:
        local_rank: Rank index of this process (also the CUDA device index).
        world_size: Number of participating processes.
        init_url: Rendezvous URL for ``torch.distributed`` initialization.
        args: Parsed CLI options (see ``parse_args``); mutated in place to
            fill in ``m_comm``/``n_comm`` defaults.
    """
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(
        backend=backend,
        init_method=init_url,
        world_size=world_size,
        rank=local_rank,
        device_id=torch.device(f"cuda:{local_rank}"),
    )
    # Iris shared-memory heap; rank/world size are re-read from iris so the
    # remainder of the function uses iris's view of the topology.
    shmem = iris.iris(args["heap_size"])
    rank = shmem.get_rank()
    world_size = shmem.get_num_ranks()
    cu_count = shmem.get_cu_count()

    # GEMM
    # Map the CLI datatype string to a torch dtype (default fp32).
    datatype = torch.float32
    if args["datatype"] == "fp16":
        datatype = torch.float16
    elif args["datatype"] == "fp32":
        datatype = torch.float32
    elif args["datatype"] == "int8":
        datatype = torch.int8
    elif args["datatype"] == "bf16":
        datatype = torch.bfloat16
    else:
        # Unreachable when invoked via parse_args (argparse restricts
        # choices), but kept as a guard for direct callers.
        print("Unknown datatype.")
        exit(1)

    assert args["n"] % world_size == 0, f"N ({args['n']}) must be divisible by world size ({world_size})."
    assert args["k"] % world_size == 0, f"K ({args['k']}) must be divisible by world size ({world_size})."

    # Set default values for communication dimensions if not provided
    if args["m_comm"] is None:
        args["m_comm"] = args["m"]
    if args["n_comm"] is None:
        args["n_comm"] = args["n"]

    # Validate communication dimensions
    assert args["n_comm"] % world_size == 0, (
        f"Communication N ({args['n_comm']}) must be divisible by world size ({world_size})"
    )

    # Calculate per-rank communication columns
    n_comm_local = args["n_comm"] // world_size

    # A is (m, k); B is allocated (n, k) then transposed, yielding a (k, n)
    # view — presumably to get a column-major-like layout for the kernel.
    A = shmem.randn(args["m"], args["k"], device="cuda", dtype=datatype)
    B = shmem.randn(args["n"], args["k"], device="cuda", dtype=datatype).T

    json_writer = JSONWriter(args["output_file"])
    json_writer.add_field("world_size", world_size)

    # Aliases only — in this example each rank uses the full A and B.
    local_B = B
    local_A = A

    # Record every CLI option in the JSON log for reproducibility.
    for key, value in args.items():
        json_writer.add_field(key, value)

    C = shmem.zeros((args["m"], args["n"]), device="cuda", dtype=A.dtype)
    # Create global communication tensor that will hold scattered results from all ranks
    C_comm_global = shmem.zeros((args["m_comm"], args["n_comm"]), device="cuda", dtype=datatype)
    # Create local communication tensor with rank-specific data
    C_comm = shmem.full((args["m_comm"], n_comm_local), rank + 1.0, device="cuda", dtype=datatype)
    # Initialize this rank's portion in the global tensor with the local data
    C_comm_global[:, rank * n_comm_local : (rank + 1) * n_comm_local] = C_comm

    # Tile counts for the GEMM output, used for tile tracing.
    total_blocks_M = triton.cdiv(args["m"], args["BLK_M"])
    total_blocks_N = triton.cdiv(args["n"], args["BLK_N"])
    total_tiles = total_blocks_M * total_blocks_N

    bias = None

    num_xcds = iris.hip.get_num_xcc()

    # This is one after another.
    # Two streams so GEMM and communication can be launched independently.
    gemm_stream = torch.cuda.Stream()
    comm_stream = torch.cuda.Stream()

    json_writer.add_field("gemm_sms", args["gemm_sms"])
    json_writer.add_field("comm_sms", args["comm_sms"])

    # Per-kernel CUDA events and accumulated timings (ms) across experiments.
    kernel_timing = {
        "gemm": {
            "start_event": torch.cuda.Event(enable_timing=True),
            "end_event": torch.cuda.Event(enable_timing=True),
            "ms": 0,
            "experiments": 0,
        },
        "communication": {
            "start_event": torch.cuda.Event(enable_timing=True),
            "end_event": torch.cuda.Event(enable_timing=True),
            "ms": 0,
            "experiments": 0,
        },
    }

    # Allocate Timestamps
    timestamps = Timestamps(num_tiles=total_tiles)

    def run_experiment():
        # One iteration: launch GEMM on gemm_stream and all-scatter on
        # comm_stream (independently), then accumulate event timings.
        nonlocal C
        nonlocal C_comm
        nonlocal C_comm_global
        nonlocal kernel_timing

        # Align all ranks before launching kernels.
        shmem.barrier()

        if args["trace_tiles"]:
            timestamps.reset()
            shmem.barrier()

        torch.cuda.nvtx.range_push("GEMM + Communication")
        torch.cuda.nvtx.range_push("GEMM")
        with torch.cuda.stream(gemm_stream):
            kernel_timing["gemm"]["start_event"].record()
            C = matmul.apply(
                local_A,
                local_B,
                C,
                bias,
                rank,
                world_size,
                args["gemm_sms"],
                args["BLK_M"],
                args["BLK_N"],
                args["BLK_K"],
                args["gsize_m"],
                shmem.get_heap_bases(),
                "gfx942",
                args["trace_tiles"],
                timestamps.mm_begin_timestamp,
                timestamps.mm_end_timestamp,
            )
            kernel_timing["gemm"]["end_event"].record()
            kernel_timing["gemm"]["experiments"] += 1

        torch.cuda.nvtx.range_pop()
        torch.cuda.nvtx.range_push("Communication")
        with torch.cuda.stream(comm_stream):
            kernel_timing["communication"]["start_event"].record()
            # NOTE(review): the all-scatter reuses timestamps.mm_begin/
            # mm_end_timestamp, the same buffers passed to the GEMM above —
            # confirm the two kernels do not clobber each other's trace data.
            persistent_all_scatter[(args["comm_sms"],)](
                C_comm_global,
                args["m_comm"],
                n_comm_local,
                C_comm_global.stride(0),
                C_comm_global.stride(1),
                args["BLK_M"],
                args["BLK_N"],
                args["gsize_m"],
                args["comm_sms"],
                num_xcds,
                shmem.get_heap_bases(),
                rank,
                world_size,
                args["trace_tiles"],
                timestamps.mm_begin_timestamp,
                timestamps.mm_end_timestamp,
            )
            kernel_timing["communication"]["end_event"].record()
            kernel_timing["communication"]["experiments"] += 1
        torch.cuda.nvtx.range_pop()
        shmem.barrier()

        # NOTE(review): elapsed_time requires the end events to have
        # completed; this presumably relies on shmem.barrier() above
        # synchronizing the device — confirm.
        for k in ["gemm", "communication"]:
            ms = kernel_timing[k]["start_event"].elapsed_time(kernel_timing[k]["end_event"])
            kernel_timing[k]["ms"] += ms

        torch.cuda.nvtx.range_pop()

    # Synchronize across all GPUs
    shmem.barrier()

    # Warmup
    run_experiment()

    shmem.barrier()

    # Discard warmup timings before any measured runs.
    for k in ["gemm", "communication"]:
        kernel_timing[k]["ms"] = 0
        kernel_timing[k]["experiments"] = 0

    if args["validate"]:
        # Ensure all GPU kernels have completed before validation
        torch.cuda.synchronize()
        shmem.barrier()

        shmem.info("Validating...")
        matmul.set_debug(True)

        # Validate GEMM result
        shmem.info("Validating GEMM operation...")
        success_gemm = validate_gemm(A, B, C, shmem)
        passed_str = "passed" if success_gemm else "failed"
        shmem.info(f"GEMM validation {passed_str}.")

        # Wait for all to finish GEMM validation
        shmem.barrier()

        # Validate all-scatter result
        shmem.info("Validating all-scatter operation...")
        success_comm = validate_all_scatter(C_comm, C_comm_global, shmem)
        passed_str = "passed" if success_comm else "failed"
        shmem.info(f"All-scatter validation {passed_str}.")

        # Overall success
        success = success_gemm and success_comm
        overall_str = "passed" if success else "failed"
        shmem.info(f"Overall validation {overall_str}.")

        # Wait for all to finish validation
        shmem.barrier()

        json_writer.add_field("success", success)
        json_writer.add_field("success_gemm", success_gemm)
        json_writer.add_field("success_comm", success_comm)

        # Register/spill counts are unavailable under the Triton interpreter.
        if not is_triton_interpret_set():
            gemm_registers = matmul.get_matmul_registers()
            gemm_spills = matmul.get_matmul_spills()

            json_writer.add_field("gemm_registers", gemm_registers)
            json_writer.add_field("gemm_spills", gemm_spills)

        shmem.info("Validation completed")

    if args["benchmark"]:
        matmul.set_debug(False)
        shmem.info("Benchmarking...")
        # TFLOP/s for the GEMM: 2*m*n*k FLOPs over the measured total time.
        perf = lambda ms: 2 * args["m"] * args["n"] * args["k"] * 1e-12 / (ms * 1e-3)
        triton_ms = iris.do_bench(run_experiment, shmem.barrier)
        triton_tflops = perf(triton_ms)
        algo_string = "all_scatter"
        shmem.info(
            f"tile matmul + {algo_string} (total_tiles={total_tiles}): {triton_ms:.3f} ms {triton_tflops:.3f} tflops"
        )

        json_writer.add_field("tflops", triton_tflops)
        json_writer.add_field("total_ms", triton_ms)

        # Average per-kernel times over the benchmark iterations.
        for k in ["gemm", "communication"]:
            json_writer.add_field(k + "_ms", kernel_timing[k]["ms"] / kernel_timing[k]["experiments"])
            json_writer.add_field(k + "_experiments", kernel_timing[k]["experiments"])

        # Wait for all to finish benchmarking
        shmem.barrier()

    # Only rank 0 writes the JSON log.
    if rank == 0:
        json_writer.flush()
        json_writer.display()

    if args["trace_tiles"] and rank == 0:
        # Convert raw GPU clock ticks to time using the wall-clock rate (MHz).
        gpu_freq = iris.hip.get_wall_clock_rate(rank) * 1e-3
        algo_string = "all_scatter"
        filename = f"gemm_tiles_{algo_string}_trace_rank{rank}.json"
        timestamps.to_json(filename, gpu_freq)

    shmem.barrier()
    dist.destroy_process_group()
|
|
||
|
|
||
def main():
    """Entry point: parse CLI options and fan out one worker per rank."""
    cli_args = parse_args()
    nprocs = cli_args["num_ranks"]

    # Local TCP rendezvous for torch.distributed initialization.
    rendezvous_url = "tcp://127.0.0.1:29500"
    mp.spawn(
        fn=_worker,
        args=(nprocs, rendezvous_url, cli_args),
        nprocs=nprocs,
        join=True,
    )


if __name__ == "__main__":
    main()
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.