Skip to content

Commit d87a64a

Browse files
committed
[Benchmark] Add all gather matmul benchmark
stack-info: PR: #400, branch: joydddd/stack/22
1 parent 8c301ef commit d87a64a

File tree

5 files changed

+206
-26
lines changed

5 files changed

+206
-26
lines changed

benchmarks/distributed/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from __future__ import annotations
22

3+
from .all_gather_matmul import AGMatmulBench as AGMatmulBench
34
from .all_reduce import AllReduceBench as AllReduceBench
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
from __future__ import annotations
2+
3+
import argparse
4+
5+
import torch
6+
import torch.distributed as dist
7+
import torch.distributed._symmetric_memory as symm_mem
8+
9+
from .experiment_util import BenchmarkOperator
10+
from .experiment_util import ExperimentConfig
11+
12+
# Default (M, N, K) problem sizes: cubic shapes whose edge grows from 256 to
# 4096 (inclusive) in steps of 128, i.e. 31 shapes in total.
BUILDIN_SHAPES = [(edge, edge, edge) for edge in range(256, 4096 + 1, 128)]
45+
46+
47+
class AGMatmulBench(BenchmarkOperator):
    """Benchmark for all-gather followed by matmul (``ag_matmul``).

    Inputs are a symmetric-memory tensor ``a`` of shape ``(M, K)`` and a
    regular tensor ``b`` of shape ``(K, N)``. The locally defined backends
    gather ``a`` along dim 0 and return ``(a_gathered, a_gathered @ b)``.
    ``"nccl"`` is the baseline backend.
    """

    def gen_configs(self, args: argparse.Namespace) -> list[ExperimentConfig]:
        """Build one ``ExperimentConfig`` per shape given in ``args.shape``."""
        return [
            ExperimentConfig(
                shape=shape,
                dtype=args.dtype,
                backends=args.backend,
                device=self.device,
            )
            for shape in args.shape
        ]

    def gen_inputs(self, config: ExperimentConfig) -> tuple:
        """Allocate the ``(a, b)`` input pair for one experiment.

        Returns:
            ``a``: ``(M, K)`` symmetric-memory tensor, rendezvoused on the
            world group and filled with random normal values.
            ``b``: ``(K, N)`` random tensor stored as the transpose of a
            contiguous ``(N, K)`` tensor.
        """
        M, N, K = config.shape
        a = symm_mem.empty(
            (M, K),
            dtype=config.dtype,
            device=config.device,
        )
        # `empty` leaves the buffer uninitialized. Fill it so the reference
        # comparison is run on well-defined values; stray NaN bit patterns in
        # stale memory would otherwise make the closeness check fail.
        a.normal_()
        # Transpose -> contiguous -> transpose keeps the logical (K, N) shape
        # while laying the data out as a contiguous (N, K) buffer.
        b = (
            torch.randn((K, N), device=config.device, dtype=config.dtype)
            .T.contiguous()
            .T
        )
        assert dist.group.WORLD is not None
        symm_mem.rendezvous(a, dist.group.WORLD.group_name)
        return (a, b)

    def additional_parser_args(
        self, parser: argparse.ArgumentParser
    ) -> argparse.ArgumentParser:
        """Register the ``--shape`` option (one or more ``M,N,K`` triples)."""

        def matmul_shape_type(s: str) -> tuple[int, int, int]:
            """Parse ``"M,N,K"`` into a tuple of three positive ints."""
            try:
                # A wrong number of fields fails the unpacking with ValueError,
                # the same exception int() raises for non-numeric text.
                M, N, K = (int(dim) for dim in s.split(","))
                if min(M, N, K) <= 0:
                    raise ValueError(f"non-positive dimension in {s!r}")
            except ValueError as e:
                raise argparse.ArgumentTypeError(
                    "Matmul shape must be M, N, K. (M, K) @ (K, N) -> (M, N)"
                ) from e
            return M, N, K

        parser.add_argument(
            "--shape",
            type=matmul_shape_type,
            nargs="+",
            default=BUILDIN_SHAPES,
            help="matmul shapes: (M, N, K). (M, K) @ (K, N) -> (M, N)",
        )
        return parser

    def __init__(self) -> None:
        """Set the operator name and baseline, then register the backends."""
        self.op_name = "ag_matmul"
        self.baseline = "nccl"
        super().__init__()

        # The default process group must already exist at this point.
        assert dist.group.WORLD is not None
        # Resolve the group name once and share it between the backends
        # instead of hard-coding "0" for the NCCL baseline.
        group_name = dist.group.WORLD.group_name

        def nccl_mem_ag_mm(
            a_shared: torch.Tensor, b: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            """Baseline: functional all-gather along dim 0, then matmul."""
            from torch.distributed._functional_collectives import all_gather_tensor

            a_gathered = all_gather_tensor(a_shared, 0, group_name)
            return a_gathered, torch.matmul(a_gathered, b)

        def torch_symm_mem_ag_mm(
            a_shared: torch.Tensor, b: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            """PyTorch fused all-gather + matmul op over symmetric memory."""
            a_gathered, c = torch.ops.symm_mem.fused_all_gather_matmul(
                a_shared, [b], gather_dim=0, group_name=group_name
            )
            # The op takes a list of right-hand sides and returns one result
            # per entry; only a single `b` is passed here.
            return a_gathered, c[0]

        # Values are either a callable or a pair of strings -- presumably
        # (module path, function name) resolved by the base class; verify
        # against BenchmarkOperator.
        self.backend_dict = {
            "nccl": nccl_mem_ag_mm,
            "torch_symm_mem": torch_symm_mem_ag_mm,
            "helion": ("examples.all_gather_matmul", "helion_all_gather_matmul"),
            "kraken": ("kraken.all_gather", "all_gather_matmul"),
        }

benchmarks/distributed/experiment_util.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,10 @@ def clone_symm_mem_tensor(tensor: torch.Tensor) -> torch.Tensor:
3737
device=tensor.device,
3838
)
3939
assert dist.group.WORLD is not None
40-
symm_mem.rendezvous(symm_mem_tensor, dist.group.WORLD.group_name)
40+
try:
41+
symm_mem.rendezvous(symm_mem_tensor, dist.group.WORLD.group_name)
42+
except RuntimeError:
43+
print(tensor.shape)
4144
symm_mem_tensor.copy_(tensor)
4245
return symm_mem_tensor
4346

@@ -96,7 +99,7 @@ class BenchmarkOperator:
9699
--nnodes 1 --nproc-per-node 8 \
97100
--rdzv-backend c10d --rdzv-endpoint localhost:0 \
98101
--no_python python3 \
99-
benchmarks/run_distributed.py
102+
benchmarks/run_distributed.py <op>
100103
"""
101104

102105
experiments: list[Experiment]
@@ -131,6 +134,12 @@ def parse_args(self) -> argparse.Namespace:
131134
description=f"Run benchmark for {self.__name__}. " + self.help_str
132135
)
133136

137+
parser.add_argument(
138+
"op",
139+
type=str,
140+
help="Operator to benchmark. ",
141+
)
142+
134143
parser.add_argument(
135144
"--backend",
136145
type=str,
@@ -153,6 +162,8 @@ def parse_args(self) -> argparse.Namespace:
153162
self.args = parser.parse_args()
154163
self.args.dtype = getattr(torch, self.args.dtype)
155164

165+
assert self.args.op == self.op_name
166+
156167
return self.args
157168

158169
def __init__(self) -> None:
@@ -168,7 +179,6 @@ def __init__(self) -> None:
168179

169180
self.device = torch.device(f"cuda:{self.local_rank}")
170181
torch.cuda.set_device(self.device)
171-
dist.init_process_group("nccl")
172182
torch.manual_seed(42 + self.local_rank)
173183

174184
self.experiments = []
@@ -292,35 +302,42 @@ def get_results(self, metric: str = "speedup") -> defaultdict | None:
292302

293303
def run_experiment(self, config: ExperimentConfig) -> dict[str, float]:
    """Run every requested backend on one config and collect its result.

    The baseline backend (``self.baseline``) is always run, even when it is
    not listed in ``config.backends``. Its output on the generated inputs is
    the reference: each backend is executed on a fresh clone of those inputs,
    compared against the reference with ``torch.testing.assert_close``
    (``atol=rtol=1e-1``), and then measured with ``benchmark_distributed``.

    Returns:
        Mapping from backend name to the value returned by
        ``benchmark_distributed``, or NaN when the backend has no function
        registered or raised ``RuntimeError`` / ``AssertionError``.

    Raises:
        AssertionError: if the baseline has no function registered, or a
            backend's output does not match the reference.
    """
    # Build a new list rather than appending, so config.backends is not
    # mutated.
    if self.baseline in config.backends:
        backends = config.backends
    else:
        backends = [*config.backends, self.baseline]

    golden_inp = self.gen_inputs(config)

    golden_fn = self.fn_dict[self.baseline]
    assert golden_fn is not None

    golden_o = golden_fn(*golden_inp)

    results = {}
    for backend in backends:
        fn = self.fn_dict[backend]
        if fn is None:
            results[backend] = float("nan")
            continue

        # Clone per backend so only one extra copy of the inputs exists at a
        # time.
        inp = clone_inputs(golden_inp)
        target_fn = functools.partial(fn, *inp)
        try:
            test_o = target_fn()
        except (RuntimeError, AssertionError):
            results[backend] = float("nan")
            # The partial holds references to the cloned inputs; drop both so
            # they are not kept until the next iteration rebinds the names.
            del target_fn, inp
            continue
        torch.testing.assert_close(test_o, golden_o, atol=1e-1, rtol=1e-1)

        results[backend] = benchmark_distributed(
            target_fn, profile_ranks=[self.MASTER_RANK]
        )
        # Deleting `inp` alone is not enough: `target_fn` still references
        # the same tensors through its bound arguments.
        del test_o, target_fn, inp

    del golden_inp
    del golden_o

    return results

benchmarks/run_distributed.py

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,47 @@
11
from __future__ import annotations
22

3-
from benchmarks.distributed import AllReduceBench as AllReduceBenchmark
3+
import sys
4+
5+
from benchmarks.distributed import AGMatmulBench as AGMatmulBench
6+
from benchmarks.distributed import AllReduceBench as AllReduceBench
47
import torch.distributed as dist
58

9+
# Maps the <op> command-line argument to the benchmark class that main()
# instantiates for it.
OP_BENCH = {
10+
"allreduce": AllReduceBench,
11+
"ag_matmul": AGMatmulBench,
12+
}
13+
614

715
def main() -> None:
    """Run the distributed benchmark named by the first command-line argument.

    The operator name is validated before the NCCL process group is created,
    so a usage error exits without leaving an initialized process group
    behind. Exits with status 1 on a missing or unknown operator, or when the
    process group cannot be initialized.
    """
    available_ops = ", ".join(OP_BENCH)

    if len(sys.argv) < 2:
        print("Usage: python3 benchmarks/run_distributed.py <op>", file=sys.stderr)
        print(f"Available ops: {available_ops}", file=sys.stderr)
        sys.exit(1)

    op = sys.argv[1]
    if op not in OP_BENCH:
        print(f"Unknown op: {op}", file=sys.stderr)
        print(f"Valid ops: {available_ops}", file=sys.stderr)
        sys.exit(1)

    try:
        dist.init_process_group("nccl")
    except ValueError:
        # Backslashes are escaped so the printed command keeps its line
        # continuations and can be copy-pasted as shown.
        print(
            "Failed to initialize process group. Are you running with torchrun?\n"
            "run distributed benchmark with:\n"
            "torchrun \\\n"
            "    --nnodes 1 --nproc-per-node 8 \\\n"
            "    --rdzv-backend c10d --rdzv-endpoint localhost:0 \\\n"
            "    --no_python python3 \\\n"
            "    benchmarks/run_distributed.py <op>",
            file=sys.stderr,
        )
        sys.exit(1)

    op_bench = OP_BENCH[op]()
    op_bench.run()
    op_bench.print_results(metric="time_us")

    dist.destroy_process_group()
1347

examples/all_gather_matmul.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def copy_engine_all_gather_w_progress(
4242

4343
with torch.cuda.stream(backend_stream):
4444
for step in range(world_size):
45-
src_rank = (rank + step + 1) % world_size
45+
src_rank = (rank + step) % world_size
4646
for split_id in range(splits_per_rank):
4747
src_buf = symm_mem_hdl.get_buffer(
4848
src_rank, chunks[0].shape, inp.dtype, chunks[0].numel() * split_id
@@ -66,7 +66,9 @@ def copy_engine_all_gather_w_progress(
6666
block_sizes=[128, 256, 64],
6767
num_warps=8,
6868
num_stages=3,
69-
indexing="block_ptr",
69+
indexing="tensor_descriptor",
70+
pid_type="persistent_interleaved",
71+
l2_groupings=[4],
7072
),
7173
static_shapes=True,
7274
)
@@ -75,7 +77,7 @@ def helion_matmul_w_progress(
7577
a_shared: torch.Tensor,
7678
b: torch.Tensor,
7779
progress: torch.Tensor,
78-
SPLITS_PER_RANK: int,
80+
SPLITS_PER_RANK: hl.constexpr,
7981
RANK: int,
8082
) -> torch.Tensor:
8183
M, K = a.size()
@@ -90,20 +92,18 @@ def helion_matmul_w_progress(
9092

9193
for tile_m, tile_n in hl.tile([M, N]):
9294
acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
93-
hl.wait(
94-
progress,
95-
[
96-
tile_m.begin // (M_per_rank // SPLITS_PER_RANK),
97-
],
98-
signal=1,
99-
)
95+
# TODO(joydddd): natively support starting the range from a non-zero index.
96+
comm_block_id = ((tile_m.begin + RANK * M_per_rank) % M) // (
97+
M_per_rank // SPLITS_PER_RANK
98+
) # pyright: ignore[reportOperatorIssue]
99+
hl.wait(progress, [comm_block_id], signal=1)
100100
for tile_k in hl.tile(K):
101101
# TODO(joydddd): use a_shared and skip barrier when data is available on local rank.
102-
# if tile_k.begin // M_per_rank == RANK:
103-
# acc = torch.addmm(acc, a_shared[tile_m.index - RANK * M_per_rank, tile_k], b[tile_k, tile_n])
104-
# else:
105-
# hl.wait(progress, [tile_m.begin // (M_per_rank // SPLITS_PER_RANK),], signal=1, update=None, op="ld", scope="gpu", sem="acquire")
106-
acc = torch.addmm(acc, a[tile_m, tile_k], b[tile_k, tile_n])
102+
acc = torch.addmm(
103+
acc,
104+
a[((tile_m.index + RANK * M_per_rank) % M), tile_k],
105+
b[tile_k, tile_n],
106+
)
107107
out[tile_m, tile_n] = acc
108108
return out
109109

0 commit comments

Comments
 (0)