
Commit f600b83

[moe training] use llama4 shapes for kernel benchmarks (#2756)
Parent: 478c5f2

6 files changed (+22 −19 lines)

benchmarks/float8/bench_grouped_mm.py

Lines changed: 3 additions & 8 deletions

```diff
@@ -64,7 +64,7 @@ def run(
 
     # Run bf16 torch._grouped_mm baseline.
     A = torch.randn(M, K, device=device, dtype=dtype)
-    B = torch.randn(E, K, N, device=device, dtype=dtype)
+    B = torch.randn(E, N, K, device=device, dtype=dtype)
     offs = generate_jagged_offs(E, M)
     print(f"offs: {offs}")
     ref_time_sec, ref_tops_sec, ref_pct_top_peak = do_benchmarks(
@@ -73,7 +73,7 @@ def run(
         use_gpu_kernel_time,
         torch._grouped_mm,
         A,
-        B,
+        B.transpose(-2, -1),
         offs,
     )
     print(
@@ -84,12 +84,7 @@ def run(
 
     # Run scaled_grouped_mm.
     A_hp = torch.randn(M, K, device=device)
-    B_hp_t = (
-        torch.randn(E, K, N, device=device)
-        .transpose(-2, -1)
-        .contiguous()
-        .transpose(-2, -1)
-    )
+    B_hp_t = torch.randn(E, N, K, device=device).transpose(-2, -1)
 
     if recipe == "rowwise":
         # TODO: add e5m2
```
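Note on the layout change above: allocating B as (E, N, K) and taking a transposed view produces the same strides as the old allocate, transpose, contiguous, transpose chain, without the extra copy. A minimal sketch with toy sizes (not the benchmark's shapes):

```python
import torch

E, K, N = 2, 8, 4

# New approach: allocate (E, N, K), then view as (E, K, N) via transpose.
# Each expert's (K, N) matrix ends up column-major, the layout the
# benchmark passes as B to torch._grouped_mm and scaled_grouped_mm.
b_new = torch.randn(E, N, K).transpose(-2, -1)

# Old approach: allocate (E, K, N) and round-trip through a contiguous copy.
b_old = torch.randn(E, K, N).transpose(-2, -1).contiguous().transpose(-2, -1)

print(b_new.shape, b_old.shape)        # torch.Size([2, 8, 4]) for both
print(b_new.stride(), b_old.stride())  # (32, 1, 8) for both
```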

benchmarks/float8/utils.py

Lines changed: 3 additions & 3 deletions

```diff
@@ -219,7 +219,7 @@ def get_name_to_moe_shapes_iter(
     N: Optional[int] = None,
     E: Optional[int] = None,
 ):
-    M = 8192 if M is None else M
+    M = 16640 if M is None else M
     if shape_gen_name == "llama4_17bx16e":
         # num_experts=16, dim=5120
         names_to_shapes = {
@@ -232,8 +232,8 @@ def get_name_to_moe_shapes_iter(
         # num_experts=128, dim=5120
         names_to_shapes = {
             # M, K, N, E
-            "moe.experts.w1": (M, 5120, 8192, 128),
-            "moe.experts.w2": (M, 8192, 5120, 128),
+            "moe.experts.w1": (M, 5120, 4 * 5120, 128),
+            "moe.experts.w2": (M, 4 * 5120, 5120, 128),
         }
         return names_to_shapes.items()
     elif shape_gen_name == "custom":
```
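With these changes, the default M is 16640 total tokens and the 128-expert hidden dim is spelled 4 * 5120 (= 20480). A quick sketch of the shapes the iterator now yields for llama4_17bx128e, plus the grouped-GEMM FLOPs they imply (illustrative only):

```python
# Sketch: shapes yielded for "llama4_17bx128e" at the new default M
# (total tokens routed across all experts).
M = 16640
names_to_shapes = {
    # M, K, N, E
    "moe.experts.w1": (M, 5120, 4 * 5120, 128),
    "moe.experts.w2": (M, 4 * 5120, 5120, 128),
}
for name, (m, k, n, e) in names_to_shapes.items():
    tflop = 2 * m * k * n / 1e12  # FLOPs for one grouped GEMM at these dims
    print(f"{name}: M={m} K={k} N={n} E={e} -> {tflop:.2f} TFLOP")
```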

benchmarks/prototype/moe_training/benchmark_kernels.py

Lines changed: 5 additions & 2 deletions

```diff
@@ -49,8 +49,8 @@ class Experiment:
 
 
 def get_configs() -> List[ExperimentConfig]:
-    input_shapes = [(2**8, 4096), (2**12, 4096), (2**16, 4096)]
-    n_groups_list = [4, 8, 16]
+    input_shapes = [(16640, 5120)]  # (Mg, K)
+    n_groups_list = [16, 128]
     high_precision_dtypes = [torch.bfloat16]
     configs = []
     for input_shape, n_groups, high_precision_dtype in itertools.product(
@@ -129,6 +129,7 @@ def run_triton(
 
     # bench torch
    compiled_run_torch = torch.compile(run_torch)
+    warmup(compiled_run_torch, input_row_major, input_col_major, offs)
     torch_time_us = benchmark_cuda_function_in_microseconds(
         compiled_run_torch, input_row_major, input_col_major, offs
     )
@@ -152,6 +153,7 @@ def print_results(experiments: List[Experiment]):
         "high_precision_dtype",
         "torch_time_us",
         "triton_time_us",
+        "triton_speedup",
     ]
     rows = []
     for experiment in experiments:
@@ -165,6 +167,7 @@ def print_results(experiments: List[Experiment]):
                 experiment.config.high_precision_dtype,
                 experiment.result.torch_time_us,
                 experiment.result.triton_time_us,
+                f"{experiment.result.torch_time_us / experiment.result.triton_time_us:.2f}x",
             ]
         )
     print(tabulate(rows, headers=headers))
```
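The added warmup call runs the compiled function before timing, keeping torch.compile's first-call compilation and any autotuning out of the measured window. A minimal stand-in for the helper this benchmark uses (hypothetical signature, for illustration):

```python
import torch

def warmup(fn, *args, n_iters: int = 3):
    # Execute a few untimed iterations so one-time compilation and
    # autotuning costs are paid before measurement, then sync the GPU.
    for _ in range(n_iters):
        fn(*args)
    torch.cuda.synchronize()
```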

benchmarks/prototype/moe_training/benchmark_rowwise_3d_quant_kernels.py

Lines changed: 7 additions & 2 deletions

```diff
@@ -46,8 +46,11 @@ class Experiment:
 
 
 def get_configs() -> List[ExperimentConfig]:
-    # Llama4 and DeepSeekV3 shapes
-    input_shapes = [(8, 4096, 1024), (16, 5120 * 4, 5120)]
+    # Llama4 shapes
+    input_shapes = [
+        (16, 8192, 5120),  # w1, w3
+        (16, 5120, 8192),  # w2
+    ]
     high_precision_dtypes = [torch.bfloat16]
     configs = []
     for input_shape, high_precision_dtype in itertools.product(
@@ -117,6 +120,7 @@ def print_results(experiments: List[Experiment]):
         "input_shape",
         "torch_time_us",
         "triton_time_us",
+        "triton_speedup",
     ]
     rows = []
     for experiment in experiments:
@@ -126,6 +130,7 @@ def print_results(experiments: List[Experiment]):
                 input_shape,
                 experiment.result.torch_time_us,
                 experiment.result.triton_time_us,
+                f"{experiment.result.torch_time_us / experiment.result.triton_time_us:.2f}x",
             ]
         )
     print(tabulate(rows, headers=headers))
```
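The new input shapes are per-expert (E, N, K) weight tensors at Llama4-style dims. For orientation, the math of rowwise float8 quantization over such a tensor can be sketched in plain PyTorch; the e4m3 target dtype and the 1e-12 epsilon are assumptions of this sketch, and the repo's Triton kernel fuses the equivalent computation with tiling:

```python
import torch

# One scale per row (per expert, per output feature), derived from the
# row's absolute maximum over the K dimension.
E, N, K = 16, 8192, 5120
w = torch.randn(E, N, K, dtype=torch.bfloat16, device="cuda")

f8_max = torch.finfo(torch.float8_e4m3fn).max
row_amax = w.abs().amax(dim=-1, keepdim=True).float()  # (E, N, 1)
scale = f8_max / row_amax.clamp(min=1e-12)             # per-row scales
w_f8 = (w.float() * scale).clamp(-f8_max, f8_max).to(torch.float8_e4m3fn)
```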

torchao/prototype/moe_training/kernels/float8_rowwise.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -29,7 +29,7 @@
 block_sizes_n = [32, 128, 512]  # large dim (output_features)
 block_sizes_k = [32, 128, 512]  # small dim (input_features)
 num_warps = [8]
-num_stages = [2, 3]
+num_stages = [2, 4]
 kernel_configs_2D = [
     triton.Config(
         {"BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k},
```

torchao/prototype/moe_training/kernels/jagged_float8_scales.py

Lines changed: 3 additions & 3 deletions

```diff
@@ -32,9 +32,9 @@
 }
 
 block_sizes = [1, 16, 32, 64]
-block_sizes_iter = [32, 64, 128, 256]
-num_warps = [1, 4]
-num_stages = [2, 3]
+block_sizes_iter = [64, 128, 256]
+num_warps = [4]
+num_stages = [3]
 kernel_configs_2D = [
     triton.Config(
         {"BLOCK_SIZE": block_size, "BLOCK_SIZE_ITER": block_size_iter},
```
