Skip to content

Commit 3d00e8f

Browse files
[moe training] remove duplicate benchmark script (#2762)
1 parent f600b83 commit 3d00e8f

File tree

2 files changed

+5
-195
lines changed

2 files changed

+5
-195
lines changed

benchmarks/prototype/moe_training/benchmark_kernels.py

Lines changed: 0 additions & 193 deletions
This file was deleted.

benchmarks/prototype/moe_training/benchmark_per_group_scaling_kernels.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@ class Experiment:
4949

5050

5151
def get_configs() -> List[ExperimentConfig]:
52-
input_shapes = [(2**8, 4096), (2**12, 4096), (2**16, 4096)]
53-
n_groups_list = [4, 8, 16]
52+
input_shapes = [(16640, 5120)] # (Mg, K)
53+
n_groups_list = [16, 128]
5454
high_precision_dtypes = [torch.bfloat16]
5555
configs = []
5656
for input_shape, n_groups, high_precision_dtype in itertools.product(
@@ -129,6 +129,7 @@ def run_triton(
129129

130130
# bench torch
131131
compiled_run_torch = torch.compile(run_torch)
132+
warmup(compiled_run_torch, input_row_major, input_col_major, offs)
132133
torch_time_us = benchmark_cuda_function_in_microseconds(
133134
compiled_run_torch, input_row_major, input_col_major, offs
134135
)
@@ -152,6 +153,7 @@ def print_results(experiments: List[Experiment]):
152153
"high_precision_dtype",
153154
"torch_time_us",
154155
"triton_time_us",
156+
"triton_speedup",
155157
]
156158
rows = []
157159
for experiment in experiments:
@@ -165,6 +167,7 @@ def print_results(experiments: List[Experiment]):
165167
experiment.config.high_precision_dtype,
166168
experiment.result.torch_time_us,
167169
experiment.result.triton_time_us,
170+
f"{experiment.result.torch_time_us / experiment.result.triton_time_us:.2f}x",
168171
]
169172
)
170173
print(tabulate(rows, headers=headers))

0 commit comments

Comments
 (0)