
Commit 9352d1a

GEMM + ReduceScatter with Workgroup Specialization Example (#317)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent: bf93968 · commit: 9352d1a

File tree: 4 files changed, +700 −0 lines changed

Lines changed: 291 additions & 0 deletions
#!/usr/bin/env python3
# SPDX-License-Identifier: MIT
# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import triton
import random
import argparse
import math

from examples.common.utils import JSONWriter, Timestamps, is_triton_interpret_set
from examples.common.validation import validate_reduce_scatter

import iris
from matmul_wrapper import MatMulReduceScatterWgSpecialized

torch.manual_seed(0)
random.seed(0)


def parse_args():
    parser = argparse.ArgumentParser(
        description="GEMM + ReduceScatter Benchmark with Workgroup Specialization",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("-m", type=int, default=8192, help="Number of rows in matrix A (M)")
    parser.add_argument("-n", type=int, default=4096, help="Number of columns in matrix B (N)")
    parser.add_argument("-k", type=int, default=12288, help="Common dimension (K), will be split across ranks")
    parser.add_argument("-d", "--debug", action="store_true", help="Enable debug mode")
    parser.add_argument("-v", "--validate", action="store_true", help="Enable validation mode")
    parser.add_argument("-t", "--trace_tiles", action="store_true", help="Enable tile-tracing mode")
    parser.add_argument("-b", "--benchmark", action="store_true", help="Enable benchmarking mode")
    parser.add_argument(
        "--datatype",
        type=str,
        default="fp16",
        choices=["fp16", "fp32", "bf16"],
        help="Datatype of computation",
    )
    parser.add_argument(
        "--output_file",
        type=str,
        default="log.json",
        help="Output file",
    )
    parser.add_argument("--BLK_M", type=int, default=128, help="Block size M")
    parser.add_argument("--BLK_N", type=int, default=256, help="Block size N")
    parser.add_argument("--BLK_K", type=int, default=32, help="Block size K")
    parser.add_argument("--gsize_m", type=int, default=1, help="L2-cache locality swizzle parameter")
    parser.add_argument("--heap_size", type=int, default=1 << 33, help="Iris heap size")
    parser.add_argument(
        "--num_sms",
        type=int,
        default=None,
        help="Number of total SMs (default: auto-detected)",
    )
    parser.add_argument(
        "--gemm_sms",
        type=int,
        default=None,
        help="Number of SMs for GEMM (default: auto-detected as power of 2)",
    )
    parser.add_argument("--num_stages", type=int, default=2, help="Number of stages")
    parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes")

    return vars(parser.parse_args())


def _worker(local_rank: int, world_size: int, init_url: str, args: dict):
    """Worker function for PyTorch distributed execution."""
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(
        backend=backend,
        init_method=init_url,
        world_size=world_size,
        rank=local_rank,
        device_id=torch.device(f"cuda:{local_rank}"),
    )

    shmem = iris.iris(args["heap_size"])
    rank = shmem.get_rank()
    world_size = shmem.get_num_ranks()

    cu_count = torch.cuda.get_device_properties(rank).multi_processor_count
    if args["num_sms"] is None:
        args["num_sms"] = cu_count
    if args["gemm_sms"] is None:
        # Use next smaller power of 2 for GEMM SMs
        args["gemm_sms"] = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1

    datatype = torch.float16
    if args["datatype"] == "fp16":
        datatype = torch.float16
    elif args["datatype"] == "fp32":
        datatype = torch.float32
    elif args["datatype"] == "bf16":
        datatype = torch.bfloat16
    else:
        print("Unknown datatype.")
        exit(1)

    M, N, K = args["m"], args["n"], args["k"]

    assert M % world_size == 0, f"M ({M}) must be divisible by world size ({world_size})"
    assert K % world_size == 0, f"K ({K}) must be divisible by world size ({world_size})"
    assert (M // world_size) % args["BLK_M"] == 0, (
        f"M_per_rank ({M // world_size}) must be divisible by BLK_M ({args['BLK_M']})"
    )

    local_K = K // world_size
    M_per_rank = M // world_size

    A_full = shmem.randn(M, K, device="cuda", dtype=datatype)
    B_full = shmem.randn(K, N, device="cuda", dtype=datatype)

    # Each rank gets a portion of K dimension as input
    local_A = A_full[:, rank * local_K : (rank + 1) * local_K].clone()
    local_B = B_full[rank * local_K : (rank + 1) * local_K, :].clone()
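    # With the K dimension split across ranks, local_A @ local_B is a partial (M, N)
    # product; summing these partials across ranks and scattering the rows leaves each
    # rank with its own (M_per_rank, N) slice of the reduced result.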

    json_writer = JSONWriter(args["output_file"])
    json_writer.add_field("world_size", world_size)
    json_writer.add_field("M", M)
    json_writer.add_field("N", N)
    json_writer.add_field("K", K)
    json_writer.add_field("local_K", local_K)

    for key, value in args.items():
        json_writer.add_field(key, value)

    local_buf = shmem.zeros((M, N), device="cuda", dtype=datatype)

    output_buf = shmem.zeros((M_per_rank, N), device="cuda", dtype=datatype)
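    # Note: local_buf holds this rank's full (M, N) partial GEMM result, while output_buf
    # receives only this rank's (M_per_rank, N) slice after the reduce-scatter. Both are
    # allocated from the Iris symmetric heap (shmem.zeros), which is what lets the kernel
    # address them across ranks via the heap bases passed in below.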

    total_blocks_M = triton.cdiv(M, args["BLK_M"])
    total_blocks_N = triton.cdiv(N, args["BLK_N"])
    total_tiles = total_blocks_M * total_blocks_N

    locks = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32)
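    # One int32 flag per output tile: presumably the GEMM workgroups mark a tile as ready
    # here so the communication workgroups know when it can be reduced and scattered.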

    gemm_stream = torch.cuda.Stream()

    json_writer.add_field("num_sms", args["num_sms"])
    json_writer.add_field("gemm_sms", args["gemm_sms"])

    kernel_timing = {
        "gemm_rs": {
            "start_event": torch.cuda.Event(enable_timing=True),
            "end_event": torch.cuda.Event(enable_timing=True),
            "ms": 0,
            "experiments": 0,
        },
    }

    timestamps = Timestamps(num_tiles=total_tiles)
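    # kernel_timing brackets the fused kernel with CUDA events (GPU time only), while
    # `timestamps` records optional per-tile begin/end clocks when --trace_tiles is set.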

    def run_experiment():
        nonlocal local_buf, output_buf

        local_buf.zero_()
        output_buf.zero_()
        locks.zero_()
        shmem.barrier()

        if args["trace_tiles"]:
            timestamps.reset()
            shmem.barrier()

        torch.cuda.nvtx.range_push("GEMM + ReduceScatter")
        with torch.cuda.stream(gemm_stream):
            kernel_timing["gemm_rs"]["start_event"].record()
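            # Single fused launch: the GEMM and ReduceScatter workgroups run concurrently
            # inside this call (see MatMulReduceScatterWgSpecialized for details).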
            MatMulReduceScatterWgSpecialized.apply(
                local_A,
                local_B,
                local_buf,
                output_buf,
                locks,
                rank,
                world_size,
                args["gemm_sms"],
                args["num_sms"],
                args["BLK_M"],
                args["BLK_N"],
                args["BLK_K"],
                args["gsize_m"],
                args["num_stages"],
                shmem.get_heap_bases(),
                torch.cuda.get_device_properties(rank).name,
                args["trace_tiles"],
                timestamps.mm_begin_timestamp,
                timestamps.mm_end_timestamp,
            )
            kernel_timing["gemm_rs"]["end_event"].record()
            kernel_timing["gemm_rs"]["experiments"] += 1

        torch.cuda.nvtx.range_pop()
        shmem.barrier()

        for k in ["gemm_rs"]:
            ms = kernel_timing[k]["start_event"].elapsed_time(kernel_timing[k]["end_event"])
            kernel_timing[k]["ms"] += ms

    shmem.barrier()

    # Warmup
    run_experiment()

    shmem.barrier()

    for k in ["gemm_rs"]:
        kernel_timing[k]["ms"] = 0
        kernel_timing[k]["experiments"] = 0

    if args["validate"]:
        shmem.info("Validating...")
        MatMulReduceScatterWgSpecialized.set_debug(True)

        local_gemm = local_buf.clone()
        local_output = output_buf.clone()

        # Allow larger tolerance for fp16 due to accumulated rounding errors in atomic operations
        atol = 1.0 if datatype == torch.float16 else 0.5

        tp_group = dist.new_group(ranks=list(range(world_size)))
        success = validate_reduce_scatter(local_gemm, local_output, shmem, tp_group, atol=atol)

        if success:
            shmem.info("✅ Triton and Torch match")
        else:
            shmem.info("❌ Triton and Torch differ")

        json_writer.add_field("success", success)

        if not is_triton_interpret_set():
            gemm_registers = MatMulReduceScatterWgSpecialized.get_matmul_registers()
            gemm_spills = MatMulReduceScatterWgSpecialized.get_matmul_spills()
            json_writer.add_field("gemm_registers", gemm_registers)
            json_writer.add_field("gemm_spills", gemm_spills)

        shmem.barrier()
        shmem.info("Validation completed")

    if args["benchmark"]:
        MatMulReduceScatterWgSpecialized.set_debug(False)
        shmem.info("Benchmarking...")

        perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
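        # 2*M*N*K FLOPs for the full GEMM; 1e-12 converts to TFLOPs, ms * 1e-3 to seconds.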

        triton_ms = iris.do_bench(run_experiment, shmem.barrier)
        triton_tflops = perf(triton_ms)

        shmem.info(f"GEMM + ReduceScatter (total_tiles={total_tiles}): {triton_ms:.3f} ms {triton_tflops:.3f} tflops")

        json_writer.add_field("tflops", triton_tflops)
        json_writer.add_field("total_ms", triton_ms)

        for k in ["gemm_rs"]:
            json_writer.add_field(k + "_ms", kernel_timing[k]["ms"] / kernel_timing[k]["experiments"])
            json_writer.add_field(k + "_experiments", kernel_timing[k]["experiments"])

        shmem.barrier()

    if rank == 0:
        json_writer.flush()
        json_writer.display()

    if args["trace_tiles"] and rank == 0:
        gpu_freq = iris.hip.get_wall_clock_rate(rank) * 1e-3
        filename = f"gemm_tiles_reduce_scatter_trace_rank{rank}.json"
        timestamps.to_json(filename, gpu_freq)

    shmem.barrier()
    dist.destroy_process_group()


def main():
    args = parse_args()
    num_ranks = args["num_ranks"]

    init_url = "tcp://127.0.0.1:29500"
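    # mp.spawn prepends the process index, so each child runs
    # _worker(local_rank, num_ranks, init_url, args) and binds to cuda:<local_rank>.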
    mp.spawn(
        fn=_worker,
        args=(num_ranks, init_url, args),
        nprocs=num_ranks,
        join=True,
    )


if __name__ == "__main__":
    main()
