
Commit 2a5733b (1 parent: a1ebf4f)

[Benchmark] Add all reduce benchmark

stack-info: PR: #393, branch: joydddd/stack/21

File tree: 5 files changed (+515, −0 lines)

benchmarks/distributed/__init__.py (3 additions, 0 deletions)

    from __future__ import annotations

    from .all_reduce import AllReduceBench as AllReduceBench
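
The `AllReduceBench as AllReduceBench` spelling is the explicit re-export idiom: strict type checkers treat a bare import as private to the module, while `X as X` marks the name as part of the package's public API. A hedged usage sketch (the import path assumes the repository root is on PYTHONPATH; this driver wiring is not part of the diff):

    import torch.distributed as dist

    from benchmarks.distributed import AllReduceBench

    # AllReduceBench.__init__ asserts dist.group.WORLD, so a default
    # process group must exist before the benchmark is constructed.
    dist.init_process_group("nccl")
    bench = AllReduceBench()  # op_name="allreduce", baseline="nccl"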

benchmarks/distributed/all_reduce.py (98 additions, 0 deletions)
    from __future__ import annotations

    import functools
    from typing import TYPE_CHECKING

    import torch
    import torch.distributed as dist
    import torch.distributed._symmetric_memory as symm_mem

    from .experiment_util import BenchmarkOperator
    from .experiment_util import ExperimentConfig

    if TYPE_CHECKING:
        import argparse

    # Default tensor lengths: a mix of aligned sizes (4096, 8192, 16384)
    # and deliberately ragged ones (4093, 8193, 16380, 16387).
    BUILTIN_SHAPES = [
        4093,
        4096,
        5000,
        8192,
        8193,
        16384,
        16380,
        16387,
    ]
    # Large sizes: 2**15 (32768) through 2**20 (1048576) elements.
    LARGE_K_SHAPES = [2**exp for exp in range(15, 21)]


    class AllReduceBench(BenchmarkOperator):
        def gen_configs(self, args: argparse.Namespace) -> list[ExperimentConfig]:
            all_configs = []
            for sz in args.shape:
                all_configs.append(
                    ExperimentConfig(
                        shape=(sz,),
                        dtype=args.dtype,
                        backends=args.backend,
                        device=self.device,
                    )
                )
            return all_configs

        def gen_inputs(self, config: ExperimentConfig) -> tuple:
            # Allocate from the symmetric-memory pool and rendezvous on the
            # world group so the symm_mem all-reduce ops can address the buffer.
            input_tensor = symm_mem.empty(
                config.shape,
                dtype=config.dtype,
                device=config.device,
            )
            assert dist.group.WORLD is not None
            symm_mem.rendezvous(input_tensor, dist.group.WORLD.group_name)
            input_tensor = input_tensor.normal_()
            return (input_tensor,)

        def additional_parser_args(
            self, parser: argparse.ArgumentParser
        ) -> argparse.ArgumentParser:
            parser.add_argument(
                "--shape",
                type=int,
                nargs="+",
                default=BUILTIN_SHAPES + LARGE_K_SHAPES,
                help="Tensor lengths",
            )
            return parser

        def __init__(self) -> None:
            self.op_name = "allreduce"
            self.baseline = "nccl"
            super().__init__()

            def nccl_ring(msg: torch.Tensor) -> torch.Tensor:
                dist.all_reduce(msg)
                return msg

            assert dist.group.WORLD is not None

            # Backends are either ready-to-call functions or
            # (module, function) string pairs, presumably resolved
            # lazily by the harness (experiment_util, not in this diff).
            ALLREDUCE_DICT = {
                "multimem": functools.partial(
                    torch.ops.symm_mem.multimem_all_reduce_,
                    reduce_op="sum",
                    group_name=dist.group.WORLD.group_name,
                ),
                "oneshot": functools.partial(
                    torch.ops.symm_mem.one_shot_all_reduce,
                    reduce_op="sum",
                    group_name=dist.group.WORLD.group_name,
                ),
                "twoshot": functools.partial(
                    torch.ops.symm_mem.two_shot_all_reduce_,
                    reduce_op="sum",
                    group_name=dist.group.WORLD.group_name,
                ),
                "nccl": nccl_ring,
                "helion_oneshot": ("examples.all_reduce", "helion_one_shot_all_reduce"),
                "kraken_oneshot": ("kraken.all_reduce", "one_shot_all_reduce"),
            }
            self.backend_dict = ALLREDUCE_DICT
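
Note the two value shapes in ALLREDUCE_DICT: most entries are ready-to-call functions, while helion_oneshot and kraken_oneshot are (module, function) string pairs, presumably imported on demand by BenchmarkOperator (that code is not part of this diff). A minimal sketch of how such lazy resolution could look; resolve_backend is a hypothetical helper, not code from this commit:

    from __future__ import annotations

    import importlib
    from collections.abc import Callable

    def resolve_backend(entry: Callable | tuple[str, str]) -> Callable:
        # Plain callables are used as-is; (module, function) pairs are
        # imported on demand so optional dependencies (e.g. kraken) are
        # only required when that backend is actually selected.
        if callable(entry):
            return entry
        module_name, fn_name = entry
        return getattr(importlib.import_module(module_name), fn_name)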

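gen_inputs also shows the required lifecycle for symmetric-memory buffers: allocate with symm_mem.empty, rendezvous on the group, then use. A self-contained sketch of the same pattern around the one-shot all-reduce op, assuming a CUDA build of PyTorch with symmetric-memory support and a launch along the lines of torchrun --nproc-per-node=8:

    import os

    import torch
    import torch.distributed as dist
    import torch.distributed._symmetric_memory as symm_mem

    # torchrun sets RANK / WORLD_SIZE / LOCAL_RANK for each process.
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    dist.init_process_group("nccl")

    assert dist.group.WORLD is not None
    group_name = dist.group.WORLD.group_name

    # Allocate from the symmetric-memory pool and rendezvous so every
    # rank can address the buffer, then run the one-shot all-reduce.
    msg = symm_mem.empty((4096,), dtype=torch.bfloat16, device="cuda")
    symm_mem.rendezvous(msg, group_name)
    msg.fill_(float(dist.get_rank()))
    out = torch.ops.symm_mem.one_shot_all_reduce(msg, "sum", group_name)

    # Each element should equal sum(range(world_size)) on every rank.
    expected = float(sum(range(dist.get_world_size())))
    torch.testing.assert_close(out, torch.full_like(out, expected))
    dist.destroy_process_group()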