Commit aed4f84

[moe training] refactor to share benchmarking and profiling utils (#2767)
1 parent: d38e9b6

File tree: 3 files changed (+117 −73 lines)

benchmarks/prototype/moe_training/benchmark_moe_layer.py renamed to benchmarks/prototype/moe_training/benchmark_moe_fsdp.py

Lines changed: 54 additions & 69 deletions
@@ -14,8 +14,6 @@
 import argparse
 import copy
 import os
-import statistics
-from time import perf_counter_ns
 
 import pytest
 import torch
@@ -24,6 +22,11 @@
 from torch.distributed._composable.fsdp import fully_shard
 from torch.nn import functional as F
 
+from benchmarks.prototype.moe_training.utils import (
+    bench_fwd_bwd_microseconds,
+    profile_fn,
+)
+
 # this feature requires CUDA and SM89+
 if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (8, 9):
     pytest.skip(
@@ -48,8 +51,12 @@
 )
 
 
-def bench_moe_float8_training_fsdp(enable_profile=False):
+def bench_moe_float8_training_fsdp(
+    recipe_name: str, enable_profile: bool, use_compile: bool
+):
     assert torch.cuda.is_available()
+    assert recipe_name in ["fp8_rowwise", "mxfp8"]
+    recipe = MoEScalingType[recipe_name.upper()]
 
     # setup distributed for fsdp
     setup_distributed()
@@ -62,15 +69,19 @@ def bench_moe_float8_training_fsdp(enable_profile=False):
     init_std = 0.02
     device = torch.device("cuda")
 
-    # reference bf16 MoE
-    dim, hidden_dim = 5120, 4 * 5120
+    # reference bf16 MoE using llama4 shapes
+    dim, hidden_dim = 5120, 8192
     ref_model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).cuda()
     torch.manual_seed(42)
     ref_model.init_weights(init_std, device)
 
     # target MoE for testing conversion
     model = copy.deepcopy(ref_model)
 
+    # Token group alignment size must be 16 for fp8 rowwise training
+    alignment_size = 32 if recipe == MoEScalingType.MXFP8 else 16
+    set_token_group_alignment_size_m(alignment_size)
+
     # assert starting params are identical for both models
     for param1, param2 in zip(model.parameters(), ref_model.parameters()):
         assert torch.equal(param1, param2)
@@ -83,15 +94,15 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
         return False
 
     # quantize test model
-    config = MoETrainingConfig(scaling_type=MoEScalingType.FP8_ROWWISE)
+    config = MoETrainingConfig(scaling_type=recipe)
     quantize_(model, config=config, filter_fn=moe_module_filter_fn)
 
     # FSDP2
     fully_shard(model)
     fully_shard(ref_model)
 
     # inputs (llama4 shapes)
-    batch, seq = 1, 8192
+    batch, seq = 1, 16640
     ref_x = torch.randn(
         batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
     )
@@ -104,70 +115,34 @@ def warmup(model, input):
         loss.backward()
         torch.cuda.synchronize()
 
-    def bench_fn_microseconds(model, input):
-        labels = torch.ones_like(input)
-        times = []
-        for _ in range(10):
-            start_ns = perf_counter_ns()
-            out = model(input)
-            loss = F.mse_loss(out, labels)
-            loss.backward()
-            torch.cuda.synchronize()
-            end_ns = perf_counter_ns()
-            duration_us = (end_ns - start_ns) / 1000
-            times.append(duration_us)
-        return statistics.median(times)
-
-    def profile_fn(model, input, profile_name="profile"):
-        # Only profile on rank 0
-        if torch.distributed.get_rank() == 0:
-            labels = torch.ones_like(input)
-            wait, warmup, active = 1, 3, 1
-            total_steps = wait + warmup + active
-            with torch.profiler.profile(
-                activities=[
-                    torch.profiler.ProfilerActivity.CPU,
-                    torch.profiler.ProfilerActivity.CUDA,
-                ],
-                schedule=torch.profiler.schedule(
-                    wait=wait, warmup=warmup, active=active, repeat=0
-                ),
-                record_shapes=True,
-                with_stack=True,
-            ) as prof:
-                for _ in range(total_steps):
-                    out = model(input)
-                    loss = F.mse_loss(out, labels)
-                    loss.backward()
-                    prof.step()
-
-            # Save profiler results
-            prof.export_chrome_trace(f"{profile_name}.json")
-            print(f"Saved: {profile_name}.json")
-
-    # Compile models
-    ref_model = torch.compile(ref_model, fullgraph=False)
-    model = torch.compile(model, fullgraph=False)
-
-    print("Benchmarking MoE with FSDP2 using bf16 training")
-    warmup(ref_model, ref_x)
-    bf16_us = bench_fn_microseconds(ref_model, ref_x)
-    print(f"bf16 time: {bf16_us} us")
-    if enable_profile:
-        print("Profiling bf16 model")
-        profile_fn(ref_model, ref_x, profile_name="bf16_profile")
+    labels = torch.ones_like(x)
 
-    # Token group alignment size must be 16 for fp8 rowwise training
-    set_token_group_alignment_size_m(16)
-
-    print("Benchmarking MoE with FSDP2 using fp8 rowwise training")
-    warmup(model, x)
-    fp8_us = bench_fn_microseconds(model, x)
-    print(f"fp8 time: {fp8_us} us")
+    # TODO: bench with fullgraph=True if/when it is supported
+    bf16_us = bench_fwd_bwd_microseconds(
+        ref_model,
+        ref_x,
+        labels=labels,
+        use_compile=use_compile,
+        fullgraph=False,
+    )
+    print(f"BF16 time: {bf16_us} us")
+    if enable_profile:
+        print("Profiling bf16 training")
+        profile_fn(ref_model, ref_x, labels=labels, profile_name="bf16_profile")
+
+    scaled_us = bench_fwd_bwd_microseconds(
+        model,
+        x,
+        labels=labels,
+        use_compile=use_compile,
+        fullgraph=False,
+    )
+    print(f"Scaled time: {scaled_us} us")
     if enable_profile:
-        print("Profiling fp8 model")
-        profile_fn(model, x, profile_name="fp8_profile")
+        print("Profiling quantized training")
+        profile_fn(model, x, labels=labels, profile_name=f"{recipe_name}_profile")
 
+    print(f"Speedup: {bf16_us / scaled_us:.3f}x")
     dist.destroy_process_group()
 
 
@@ -185,5 +160,15 @@ def setup_distributed():
         action="store_true",
         help="Enable PyTorch profiling and save results to file",
     )
+    parser.add_argument("--recipe", type=str, help="[fp8_rowwise, mxfp8]")
+    parser.add_argument(
+        "--compile",
+        action="store_true",
+        help="use torch.compile",
+    )
     args = parser.parse_args()
-    bench_moe_float8_training_fsdp(enable_profile=args.profile)
+    bench_moe_float8_training_fsdp(
+        recipe_name=args.recipe,
+        enable_profile=args.profile,
+        use_compile=args.compile,
+    )
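
The recipe string is mapped to a scaling recipe by enum-name lookup, and the token-group alignment follows from the recipe. Below is a minimal standalone sketch of that idiom; the `Recipe` enum here is a hypothetical stand-in for torchao's `MoEScalingType`:

```python
from enum import Enum, auto


class Recipe(Enum):
    # Hypothetical stand-in for the MoEScalingType members used above.
    FP8_ROWWISE = auto()
    MXFP8 = auto()


recipe_name = "mxfp8"
assert recipe_name in ["fp8_rowwise", "mxfp8"]
recipe = Recipe[recipe_name.upper()]  # name-based lookup, as in MoEScalingType[...]

# Alignment rule from the diff: mxfp8 uses 32-row token groups, fp8 rowwise 16.
alignment_size = 32 if recipe == Recipe.MXFP8 else 16
print(recipe.name, alignment_size)  # MXFP8 32
```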

benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm.py renamed to benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm_dq.py

Lines changed: 24 additions & 2 deletions
@@ -12,7 +12,7 @@
 import torch
 from tabulate import tabulate
 from tqdm import tqdm
-from utils import bench_fwd_bwd_microseconds
+from utils import bench_fwd_bwd_microseconds, profile_fn
 
 from torchao.prototype.moe_training import _scaled_grouped_mm
 from torchao.prototype.moe_training.conversion_utils import MoEScalingType
@@ -47,7 +47,7 @@ class Experiment:
 
 def get_configs() -> List[ExperimentConfig]:
     A_shapes = [(16640, 5120)]
-    B_shapes = [(16, 8192, 5120), (128, 8192, 5120)]
+    B_shapes = [(16, 8192, 5120)]
     recipes = [MoEScalingType.FP8_ROWWISE]
     high_precision_dtypes = [torch.bfloat16]
     configs = []
@@ -106,6 +106,16 @@ def run_experiment(
         labels=labels,
         use_compile=args.compile,
     )
+    if args.profile:
+        profile_fn(
+            torch._grouped_mm,
+            A,
+            B_t,
+            offs,
+            labels=labels,
+            use_compile=args.compile,
+            profile_name="bf16_profile",
+        )
 
     # benchmark scaled grouped mm with dynamic fp8 rowwise quant
     fp8_us = bench_fwd_bwd_microseconds(
@@ -117,6 +127,17 @@ def run_experiment(
         labels=labels,
         use_compile=args.compile,
     )
+    if args.profile:
+        profile_fn(
+            _scaled_grouped_mm,
+            A,
+            B_t,
+            offs,
+            scaling_type=config.recipe,
+            labels=labels,
+            use_compile=args.compile,
+            profile_name="scaled_profile",
+        )
 
     return ExperimentResult(
         bf16_us=round(bf16_us, 3),
@@ -164,5 +185,6 @@ def main(args: argparse.Namespace):
 if __name__ == "__main__":
     arg_parser = argparse.ArgumentParser()
     arg_parser.add_argument("--compile", action="store_true")
+    arg_parser.add_argument("--profile", action="store_true")
     args = arg_parser.parse_args()
     main(args)
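
The shapes in `get_configs` describe a jagged grouped GEMM: `A` packs 16640 tokens of width 5120, `B_t` holds 16 expert weight matrices, and an `offs` tensor marks where each expert's token group ends along `A`'s first dimension. A hedged sketch of what such an offsets tensor can look like, assuming evenly sized groups (the real benchmark may generate jagged group sizes):

```python
import torch

# 16 experts sharing 16640 tokens; `offs` holds cumulative end indices.
n_groups, total_tokens = 16, 16640
tokens_per_group = total_tokens // n_groups  # 1040, a multiple of the 16-row
                                             # alignment fp8 rowwise requires
offs = torch.arange(
    tokens_per_group, total_tokens + 1, tokens_per_group, dtype=torch.int32
)
print(offs.shape, offs[:4].tolist())  # torch.Size([16]) [1040, 2080, 3120, 4160]
```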

benchmarks/prototype/moe_training/utils.py

Lines changed: 39 additions & 2 deletions
@@ -5,9 +5,11 @@
 from torch.nn import functional as F
 
 
-def bench_fwd_bwd_microseconds(fn, *args, labels=None, use_compile=False, **kwargs):
+def bench_fwd_bwd_microseconds(
+    fn, *args, labels=None, use_compile=False, fullgraph=True, **kwargs
+):
     assert labels is not None
-    fn = torch.compile(fn, fullgraph=False) if use_compile else fn
+    fn = torch.compile(fn, fullgraph=fullgraph) if use_compile else fn
     times = []
     for _ in range(10):
         start_ns = perf_counter_ns()
@@ -19,3 +21,38 @@ def bench_fwd_bwd_microseconds(fn, *args, labels=None, use_compile=False, **kwar
         duration_us = (end_ns - start_ns) / 1000
         times.append(duration_us)
     return statistics.median(times)
+
+
+def profile_fn(
+    fn,
+    *args,
+    labels=None,
+    use_compile=False,
+    fullgraph=True,
+    profile_name="profile",
+    **kwargs,
+):
+    assert labels is not None
+    fn = torch.compile(fn, fullgraph=fullgraph) if use_compile else fn
+    wait, warmup, active = 1, 3, 1
+    total_steps = wait + warmup + active
+    with torch.profiler.profile(
+        activities=[
+            torch.profiler.ProfilerActivity.CPU,
+            torch.profiler.ProfilerActivity.CUDA,
+        ],
+        schedule=torch.profiler.schedule(
+            wait=wait, warmup=warmup, active=active, repeat=0
+        ),
+        record_shapes=True,
+        with_stack=True,
+    ) as prof:
+        for _ in range(total_steps):
+            out = fn(*args, **kwargs)
+            loss = F.mse_loss(out, labels)
+            loss.backward()
+            prof.step()
+
+    # Save profiler results
+    prof.export_chrome_trace(f"{profile_name}.json")
+    print(f"Saved: {profile_name}.json")
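
To exercise the shared helpers end to end, a toy workload suffices. This is a sketch assuming a CUDA device (the profiler schedules CUDA activity) and that the benchmarks package is importable; `nn.Linear` is just a hypothetical stand-in for the MoE layer:

```python
import torch
from torch import nn

from benchmarks.prototype.moe_training.utils import (
    bench_fwd_bwd_microseconds,
    profile_fn,
)

# Any callable returning a tensor works: both helpers run fn(*args, **kwargs),
# take an MSE loss against `labels`, and call backward().
model = nn.Linear(512, 512, dtype=torch.bfloat16, device="cuda")
x = torch.randn(64, 512, dtype=torch.bfloat16, device="cuda")
labels = torch.ones_like(x)

median_us = bench_fwd_bwd_microseconds(model, x, labels=labels, use_compile=False)
print(f"median fwd+bwd: {median_us:.1f} us")

# Writes toy_profile.json, a Chrome trace viewable in chrome://tracing or Perfetto.
profile_fn(model, x, labels=labels, profile_name="toy_profile")
```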
