
Commit d38e9b6

[moe training] update bench script to compare fp8 dynamic quant scaled_grouped_mm fwd+bwd against bf16 (#2765)
1 parent 3d00e8f commit d38e9b6

File tree: 5 files changed, +91 -65 lines

benchmarks/prototype/moe_training/benchmark_rowwise_3d_quant_kernels.py

Lines changed: 2 additions & 1 deletion
@@ -87,12 +87,13 @@ def run_torch(input_tensor: torch.Tensor):
         return out
 
     def run_triton(input_tensor: torch.Tensor):
-        _ = triton_fp8_rowwise_3d_transpose_rhs(
+        out = triton_fp8_rowwise_3d_transpose_rhs(
             input_tensor,
             output_dtype=torch.float8_e4m3fn,
             round_scales_to_power_of_2=True,
         )
         torch.cuda.synchronize()
+        return out
 
     # bench torch
     compiled_run_torch = torch.compile(run_torch)
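A plausible reason run_triton now returns its result (inferred; the commit message does not say): a wrapper whose output never escapes invites the compiler, or a future torch.compile of run_triton, to treat the quantization as dead code. A minimal sketch of the hazard, under that assumption:

import torch

@torch.compile
def silent(x):
    _ = x * 2.0  # unused result: Inductor is free to drop this work entirely

@torch.compile
def observed(x):
    out = x * 2.0  # returned result keeps the kernel in the compiled graph
    return out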

benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm.py

Lines changed: 45 additions & 35 deletions
@@ -6,15 +6,17 @@
 # this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py
 import argparse
 import itertools
-import time
 from dataclasses import dataclass
 from typing import List
 
 import torch
 from tabulate import tabulate
 from tqdm import tqdm
+from utils import bench_fwd_bwd_microseconds
 
 from torchao.prototype.moe_training import _scaled_grouped_mm
+from torchao.prototype.moe_training.conversion_utils import MoEScalingType
+from torchao.prototype.moe_training.utils import generate_jagged_offs
 
 device = torch.device("cuda")
@@ -27,11 +29,14 @@ class ExperimentConfig:
     high_precision_dtype: torch.dtype
     A_shape: tuple[int]
     B_shape: tuple[int]
+    recipe: MoEScalingType
 
 
 @dataclass(frozen=True)
 class ExperimentResult:
-    time_us: float
+    bf16_us: float
+    fp8_us: float
+    fp8_speedup: float
 
 
 @dataclass(frozen=True)
@@ -41,19 +46,22 @@ class Experiment:
 
 
 def get_configs() -> List[ExperimentConfig]:
-    A_shapes = [(2**8, 8192), (2**12, 8192), (2**16, 8192)]
-    B_shapes = [(4, 8192, 8192), (8, 8192, 8192), (16, 8192, 8192)]
+    A_shapes = [(16640, 5120)]
+    B_shapes = [(16, 8192, 5120), (128, 8192, 5120)]
+    recipes = [MoEScalingType.FP8_ROWWISE]
     high_precision_dtypes = [torch.bfloat16]
     configs = []
-    for A_shape, B_shape, high_precision_dtype in itertools.product(
+    for A_shape, B_shape, recipe, high_precision_dtype in itertools.product(
         A_shapes,
         B_shapes,
+        recipes,
         high_precision_dtypes,
     ):
         configs.append(
             ExperimentConfig(
                 A_shape=A_shape,
                 B_shape=B_shape,
+                recipe=recipe,
                 high_precision_dtype=high_precision_dtype,
             )
         )
@@ -83,47 +91,47 @@ def run_experiment(
     # - the transposed tensor in col-major format with groups along the row dimension,
     # which represents the right operand.
     n_groups = config.B_shape[0]
-    group_size = A.shape[0] // n_groups
-    offs = torch.arange(
-        group_size,
-        group_size * n_groups + 1,
-        group_size,
-        device=device,
-        dtype=torch.int32,
-    )
+    offs = generate_jagged_offs(n_groups, A.shape[0], multiple_of=16)
 
-    def warmup(func, *args, **kwargs):
-        for _ in range(10):
-            func(*args, **kwargs)
+    labels = torch.ones(
+        (A.shape[0], B_t.shape[-1]), device=device, dtype=torch.bfloat16
+    )
 
-    def forward_backward(A, B_t, offs):
-        out = _scaled_grouped_mm(
-            A,
-            B_t,
-            offs=offs,
-            out_dtype=torch.bfloat16,
-        )
-        out.sum().backward()
-        torch.cuda.synchronize()
+    # benchmark bf16 grouped mm
+    bf16_us = bench_fwd_bwd_microseconds(
+        torch._grouped_mm,
+        A,
+        B_t,
+        offs,
+        labels=labels,
+        use_compile=args.compile,
+    )
 
-    # benchmark torch
-    torch_func = torch.compile(forward_backward) if args.compile else forward_backward
-    warmup(torch_func, A, B_t, offs)
-    start_time_ns = time.perf_counter_ns()
-    torch_func(A, B_t, offs)
-    torch_time_ns = time.perf_counter_ns() - start_time_ns
-    time_us = torch_time_ns / 1e3
+    # benchmark scaled grouped mm with dynamic fp8 rowwise quant
+    fp8_us = bench_fwd_bwd_microseconds(
+        _scaled_grouped_mm,
+        A,
+        B_t,
+        offs,
+        scaling_type=config.recipe,
+        labels=labels,
+        use_compile=args.compile,
+    )
 
     return ExperimentResult(
-        time_us=round(time_us, 3),
+        bf16_us=round(bf16_us, 3),
+        fp8_us=round(fp8_us, 3),
+        fp8_speedup=round(bf16_us / fp8_us, 3),
     )
 
 
 def print_results(experiments: List[Experiment]):
     headers = [
         "A_shape",
         "B_shape",
-        "time_us",
+        "bf16_time_us",
+        "scaled_time_us",
+        "fp8_speedup",
    ]
    rows = []
    for experiment in experiments:
@@ -133,7 +141,9 @@ def print_results(experiments: List[Experiment]):
             [
                 A_shape,
                 B_shape,
-                experiment.result.time_us,
+                experiment.result.bf16_us,
+                experiment.result.fp8_us,
+                f"{experiment.result.fp8_speedup}x",
             ]
         )
     print(tabulate(rows, headers=headers))
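One notable swap above: the old uniform torch.arange offsets are replaced by generate_jagged_offs, so expert groups get uneven token counts, closer to real MoE routing. A small illustration of the difference (the jagged values below are hypothetical; generate_jagged_offs draws them randomly, subject to multiple_of=16 and a final offset equal to the total token count):

import torch

total_tokens, n_groups = 128, 4
# Old scheme: equal group boundaries every total_tokens // n_groups tokens.
uniform_offs = torch.arange(32, 128 + 1, 32, dtype=torch.int32)  # [32, 64, 96, 128]
# New scheme (one hypothetical draw): uneven groups of 16, 48, 32, 32 tokens.
jagged_offs = torch.tensor([16, 64, 96, 128], dtype=torch.int32)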
benchmarks/prototype/moe_training/utils.py

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+import statistics
+from time import perf_counter_ns
+
+import torch
+from torch.nn import functional as F
+
+
+def bench_fwd_bwd_microseconds(fn, *args, labels=None, use_compile=False, **kwargs):
+    assert labels is not None
+    fn = torch.compile(fn, fullgraph=False) if use_compile else fn
+    times = []
+    for _ in range(10):
+        start_ns = perf_counter_ns()
+        out = fn(*args, **kwargs)
+        loss = F.mse_loss(out, labels)
+        loss.backward()
+        torch.cuda.synchronize()
+        end_ns = perf_counter_ns()
+        duration_us = (end_ns - start_ns) / 1000
+        times.append(duration_us)
+    return statistics.median(times)
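For reference, a minimal standalone use of the new helper; the shapes below are illustrative assumptions, not from the commit. The MSE loss against a dummy labels tensor forces gradients to flow back through the op under test, so backward kernels land inside the timed region, and taking the median of 10 runs absorbs the first-iteration warmup/compile outlier.

import torch
from utils import bench_fwd_bwd_microseconds  # the helper defined above

A = torch.randn(1024, 512, device="cuda", dtype=torch.bfloat16, requires_grad=True)
W = torch.randn(512, 256, device="cuda", dtype=torch.bfloat16, requires_grad=True)
labels = torch.ones(1024, 256, device="cuda", dtype=torch.bfloat16)

# Times forward + loss + backward, synchronizing before the stop-watch read.
med_us = bench_fwd_bwd_microseconds(torch.mm, A, W, labels=labels, use_compile=False)
print(f"median fwd+bwd: {med_us:.1f} us")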

torchao/prototype/moe_training/kernels/float8_rowwise.py

Lines changed: 10 additions & 15 deletions
@@ -51,7 +51,6 @@ def triton_fp8_rowwise_3d_transpose_rhs(
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     assert hp_tensor.ndim == 3, "input tensor must be 3D"
 
-    num_elements = hp_tensor.numel()
     tl_input_dtype = FP8_DTYPE_MAP[hp_tensor.dtype]
     tl_output_dtype = FP8_DTYPE_MAP[output_dtype]
@@ -89,7 +88,6 @@ def triton_fp8_rowwise_3d_transpose_rhs(
         e,
         n,
         k,
-        num_elements,
         fp8_dtype_min,
         fp8_dtype_max,
         tl_input_dtype,
@@ -113,7 +111,6 @@ def triton_fp8_rowwise_3d_transpose_rhs(
         e,
         n,
         k,
-        num_elements,
         fp8_dtype_min,
         fp8_dtype_max,
         tl_input_dtype,
@@ -138,20 +135,19 @@ def _fake_triton_fp8_rowwise_3d_transpose_rhs(
     return output_buffer, scales_buffer
 
 
-@triton.autotune(configs=kernel_configs_2D, key=["num_elements"])
+@triton.autotune(configs=kernel_configs_2D, key=["K", "N"])
 @triton.jit
 def _triton_fp8_rowwise_3d_transpose_scales_rhs_kernel(
     input_ptr,
-    stride_input_dim0: int,
-    stride_input_dim1: int,
-    stride_input_dim2: int,
+    stride_input_dim0: tl.int64,
+    stride_input_dim1: tl.int64,
+    stride_input_dim2: tl.int64,
     scales_ptr,
     stride_scales_dim0: int,
     stride_scales_dim1: int,
     E: int,
     N: int,
     K: int,
-    num_elements: int,
     fp8_dtype_min: tl.constexpr,
     fp8_dtype_max: tl.constexpr,
     input_dtype: tl.constexpr,
@@ -202,20 +198,19 @@ def _triton_fp8_rowwise_3d_transpose_scales_rhs_kernel(
 @triton.jit
 def _triton_fp8_rowwise_3d_transpose_cast_rhs_kernel(
     input_ptr,
-    stride_input_dim0: int,
-    stride_input_dim1: int,
-    stride_input_dim2: int,
+    stride_input_dim0: tl.int64,
+    stride_input_dim1: tl.int64,
+    stride_input_dim2: tl.int64,
     output_ptr,
-    stride_output_dim0: int,
-    stride_output_dim1: int,
-    stride_output_dim2: int,
+    stride_output_dim0: tl.int64,
+    stride_output_dim1: tl.int64,
+    stride_output_dim2: tl.int64,
     scales_ptr,
     stride_scales_dim0: int,
     stride_scales_dim1: int,
     E: int,
     N: int,
     K: int,
-    num_elements: int,
     fp8_dtype_min: tl.constexpr,
     fp8_dtype_max: tl.constexpr,
     input_dtype: tl.constexpr,
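Two related changes here: the autotune key moves from num_elements to the K and N dimensions that actually shape the 2D tiles, and the stride parameters widen from int to tl.int64. A likely motive for the wider strides (an inference; the commit message does not state it) is that the new 128-expert benchmark shape pushes element-offset arithmetic past the int32 range:

# Offset range for the largest new bench shape, B = (128, 8192, 5120):
e, n, k = 128, 8192, 5120
numel = e * n * k
print(numel)              # 5368709120
print(numel > 2**31 - 1)  # True: 32-bit pointer-offset math would wrap here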

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 13 additions & 14 deletions
@@ -48,7 +48,7 @@ def _scaled_grouped_mm(
     """
     # TODO: Remove logging once prototype is more mature. This is currently very useful for development and debugging.
     if scaling_type == MoEScalingType.FP8_ROWWISE:
-        print("Using fp8 rowwise scaled_grouped_mm")
+        # print("Using fp8 rowwise scaled_grouped_mm")
         return _Float8GroupedMM.apply(
             A,
             B_t,
@@ -140,17 +140,8 @@ def forward(
         B_t_scaled = B_t.to(torch.float32) * B_t_scales
         B_t_fp8_col_major = to_fp8_saturated(B_t_scaled, torch.float8_e4m3fn)
 
-        # Precompute non-transposed B column-major for backward, to save memory by storing the
-        # low precision B tensor instead of the high precision B tensor.
-        # In the backward this is needed for grad_A: grad_output @ B.
-        B_fp8_col_major, B_scales = triton_fp8_rowwise_3d_transpose_rhs(
-            B_t._data,
-            output_dtype=torch.float8_e4m3fn,
-            round_scales_to_power_of_2=True,
-        )
-
         # Store what we need for backward.
-        ctx.save_for_backward(A, B_fp8_col_major, B_scales, offs)
+        ctx.save_for_backward(A, B_t, offs)
         ctx.out_dtype = out_dtype
 
         # Perform scaled grouped GEMM and return result.
@@ -179,7 +170,7 @@ def forward(
 
     @staticmethod
     def backward(ctx, grad_output: torch.Tensor):
-        A, B_fp8_col_major, B_scales, offs = ctx.saved_tensors
+        A, B_t, offs = ctx.saved_tensors
         out_dtype = ctx.out_dtype
 
         # Convert grad_output to float8, row-major for left operand of grouped GEMM
@@ -199,6 +190,14 @@ def backward(ctx, grad_output: torch.Tensor):
             grad_output_scaled, torch.float8_e4m3fn
         )
 
+        # Compute B fp8 column-major for right operand of grouped GEMM:
+        # grad_A = grad_output @ B.
+        B_fp8_col_major, B_scales = triton_fp8_rowwise_3d_transpose_rhs(
+            B_t._data if hasattr(B_t, "_data") else B_t,
+            output_dtype=torch.float8_e4m3fn,
+            round_scales_to_power_of_2=True,
+        )
+
         # Compute grad_A.
         # grad_A = grad_output @ B
         # grad_A = scaled grouped mm of (M,N) @ (B,N,K) = (M,K)
@@ -217,8 +216,8 @@ def backward(ctx, grad_output: torch.Tensor):
         grad_A = torch._scaled_grouped_mm(
             grad_output_fp8_row_major,
             B_fp8_col_major,
-            grad_output_scales.squeeze().reciprocal(),
-            B_scales.squeeze().reciprocal(),
+            grad_output_scales.reciprocal(),
+            B_scales.reciprocal(),
             offs,
             out_dtype=out_dtype,
             use_fast_accum=True,
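Net effect of the autograd changes: the fp8 transpose-quant of B moves from forward to backward, so the context saves B_t (a weight that is typically alive anyway) instead of an extra fp8 copy plus scales. A minimal sketch of this recompute-in-backward pattern, using a crude cast as a stand-in for the real rowwise quantization:

import torch

class RecomputeQuantInBackward(torch.autograd.Function):
    """Save the high-precision operand; quantize it lazily in backward."""

    @staticmethod
    def forward(ctx, x, w):
        ctx.save_for_backward(x, w)  # w is usually a live parameter: no extra copy
        return x @ w.t()

    @staticmethod
    def backward(ctx, grad_out):
        x, w = ctx.saved_tensors
        # Stand-in for triton_fp8_rowwise_3d_transpose_rhs: build the quantized
        # right operand only now, so it never persists between fwd and bwd.
        w_q = w.to(torch.float8_e4m3fn).to(w.dtype)
        grad_x = grad_out @ w_q
        grad_w = grad_out.t() @ x
        return grad_x, grad_w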
