Skip to content

Commit d635ba8

Browse files
authored
[ENH] Benchmark entry point outputs, cli args (#4121)
1 parent 3d23844 commit d635ba8

File tree

7 files changed

+604
-231
lines changed

7 files changed

+604
-231
lines changed

benchmarks/setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ def get_git_commit_hash(length=8):
116116
install_requires=[
117117
"torch>=2.6",
118118
"pandas",
119+
"scipy",
119120
"psutil",
120121
"tabulate",
121122
"matplotlib",

benchmarks/tests/test_entry_point.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33
import pytest
44

55
from triton_kernels_benchmark.benchmark_testing import MarkArgs
6+
from triton_kernels_benchmark.becnhmark_config_templates import CONFIGS
67
from triton_kernels_benchmark.benchmark_utils import BenchmarkCategory, BenchmarkConfigs
78

89
ALL_CATEGORIES = {cat.value for cat in BenchmarkCategory}
10+
ALL_CONFIGS = {config.key: config for config in CONFIGS}
911

1012

1113
@pytest.mark.parametrize(
@@ -20,7 +22,7 @@
2022
"providers_count",
2123
),
2224
(
23-
[True, set(), True, ALL_CATEGORIES, [], None, lambda x: x > 1, lambda x: x > 1],
25+
[True, ALL_CONFIGS, True, ALL_CATEGORIES, [], None, lambda x: x > 1, lambda x: x > 1],
2426
[True, {"softmax", "gemm"}, True, ALL_CATEGORIES, [], None, lambda x: x > 1, lambda x: x > 1],
2527
[True, {"softmax", "gemm"}, True, {"core", "gemm", "softmax"}, [], None, lambda x: x > 1, lambda x: x > 1],
2628
[False, {"softmax"}, False, {"optional"}, ["triton"], AssertionError, None, None],
@@ -47,6 +49,8 @@ def benchmark_configs():
4749
select_all=select_all,
4850
categories_filter=categories_filter,
4951
providers_filter=providers_filter,
52+
json_output=False,
53+
detailed_output=False,
5054
tag="",
5155
)
5256

@@ -57,6 +61,6 @@ def benchmark_configs():
5761
configs = benchmark_configs().configs
5862
benchmark_configs().run()
5963
assert configs_count(len(configs))
60-
providers_counts = [len(config.config_summary.selected_providers) for config in configs]
64+
providers_counts = [len(config.selected_providers) for config in configs]
6165
assert providers_count(max(providers_counts))
6266
assert providers_count(min(providers_counts))

benchmarks/tests/test_mocks.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
from typing import Optional
2+
3+
import re
4+
5+
import io
6+
7+
import pytest
8+
9+
import pandas as pd
10+
11+
from triton_kernels_benchmark.benchmark_utils import BenchmarkCategory, BenchmarkConfigs
12+
13+
ALL_CATEGORIES = {cat.value for cat in BenchmarkCategory}
14+
15+
SOFTMAX_PERFORMANCE_CSV = """
16+
N,Triton-GB/s,XeTLA-GB/s,Triton-GB/s-min,XeTLA-GB/s-min,Triton-GB/s-max,XeTLA-GB/s-max,Triton-TFlops,XeTLA-TFlops,Triton-TFlops-min,XeTLA-TFlops-min,Triton-TFlops-max,XeTLA-TFlops-max,Triton-CV,XeTLA-CV,datetime,run_counter
17+
256.000000,473.397771,568.333815,90.083848,514.007860,494.611303,582.542232,0.473398,0.568334,0.090084,0.514008,0.494611,0.582542,0.019154,0.018093,2025-05-05 21:45:29.943213,1
18+
1024.000000,683.111432,541.549931,672.164101,537.731297,689.852609,548.992673,0.683111,0.541550,0.672164,0.537731,0.689853,0.548993,0.006031,0.004731,2025-05-05 21:45:29.943213,1
19+
2048.000000,677.320009,726.915809,672.164101,708.497308,683.111380,825.650389,0.677320,0.726916,0.672164,0.708497,0.683111,0.825650,0.003426,0.018620,2025-05-05 21:45:29.943213,1
20+
4096.000000,627.302921,477.032066,616.809404,474.468764,641.330889,488.846612,0.627303,0.477032,0.616809,0.474469,0.641331,0.488847,0.008189,0.003735,2025-05-05 21:45:29.943213,1
21+
8192.000000,679.033333,611.916382,665.234595,604.802311,762.600731,637.916958,0.679033,0.611916,0.665235,0.604802,0.762601,0.637917,0.016350,0.008740,2025-05-05 21:45:29.943213,1
22+
16384.000000,712.219329,677.833087,703.447226,661.562161,760.871449,688.437266,0.712219,0.677833,0.703447,0.661562,0.760871,0.688437,0.009147,0.009317,2025-05-05 21:45:29.943213,1
23+
32768.000000,733.450281,729.424324,727.861837,726.286411,756.582488,737.136026,0.733450,0.729424,0.727862,0.726286,0.756582,0.737136,0.003869,0.002001,2025-05-05 21:45:29.943213,1
24+
"""
25+
26+
PERFORMANCE_CSVS = {
27+
"softmax": SOFTMAX_PERFORMANCE_CSV,
28+
}
29+
30+
31+
@pytest.mark.parametrize("command", ["run"])
32+
@pytest.mark.parametrize("benchmark", ["softmax"])
33+
@pytest.mark.parametrize("provider", ["triton", None])
34+
@pytest.mark.parametrize("n_runs", [None, 1, 2])
35+
@pytest.mark.parametrize("show_details", [False, True])
36+
@pytest.mark.parametrize("json_output", [False, True])
37+
def test_benchmark_run_monkeypatched(
38+
command: str,
39+
benchmark: str,
40+
provider: Optional[str],
41+
n_runs: Optional[int],
42+
show_details: bool,
43+
json_output: bool,
44+
capsys,
45+
):
46+
args = [command, benchmark]
47+
if provider:
48+
args.extend(["--provider", provider])
49+
if n_runs:
50+
args.extend(["--n_runs", str(n_runs)])
51+
if show_details:
52+
args.extend(["--show-details"])
53+
if json_output:
54+
args.extend(["--json"])
55+
56+
configs = BenchmarkConfigs.from_args(args)
57+
for config in configs.configs:
58+
config.res_df = pd.read_csv(io.StringIO(PERFORMANCE_CSVS[config.key]))
59+
configs.run()
60+
61+
captured_output = capsys.readouterr().out
62+
output_lines = captured_output.splitlines()
63+
if provider and not json_output:
64+
assert "Selected providers: {'triton': 'Triton'}" in output_lines
65+
# Check if the prettified result table have CV column, example - "metric GB/s GB/s TFlops TFlops CV CV"
66+
if show_details and not json_output:
67+
assert not show_details or re.search(r"^metric.* CV", captured_output, flags=re.MULTILINE)

benchmarks/triton_kernels_benchmark/benchmark_config_templates.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
gemm_tensor_of_ptr_benchmark,
88
flash_attention_benchmark,
99
flash_attention_tensor_desc_benchmark,
10+
prefix_sums,
1011
)
1112

1213
CONFIGS = [
@@ -29,14 +30,14 @@
2930
get_benchmark=gemm_tensor_of_ptr_benchmark.get_benchmark,
3031
run_opts={},
3132
categories={BenchmarkCategory.EXPERIMENTAL, BenchmarkCategory.GEMM},
32-
description="Triton GEMM kernel benchmark - with tensor of pointer",
33+
description="GEMM kernel benchmark - with tensor of pointer",
3334
),
3435
BenchmarkConfig(
3536
key="gemm-tensor-desc",
3637
get_benchmark=gemm_tensor_desc_benchmark.get_benchmark,
3738
run_opts={},
3839
categories={BenchmarkCategory.EXPERIMENTAL, BenchmarkCategory.GEMM},
39-
description="Triton GEMM kernel benchmark - with tensor descriptor",
40+
description="GEMM kernel benchmark - with tensor descriptor",
4041
),
4142
BenchmarkConfig(
4243
key="gemm_bt",
@@ -49,19 +50,36 @@
4950
key="gemm_at",
5051
get_benchmark=gemm_benchmark.get_benchmark,
5152
run_opts={"transpose_a": True},
52-
categories={BenchmarkCategory.EXPERIMENTAL, BenchmarkCategory.GEMM},
53+
categories={BenchmarkCategory.OPTIONAL, BenchmarkCategory.GEMM},
5354
description="Triton GEMM (A^t@B) kernel benchmark",
5455
),
5556
BenchmarkConfig(
5657
key="flash_attention",
5758
get_benchmark=flash_attention_benchmark.get_benchmark,
5859
run_opts={"fa_kernel_mode": "fwd"},
5960
categories={BenchmarkCategory.CORE, BenchmarkCategory.FLASH_ATTENTION},
61+
description="FlashAttention forward kernel benchmark",
6062
),
6163
BenchmarkConfig(
6264
key="flash_attention_tensor_desc",
6365
get_benchmark=flash_attention_tensor_desc_benchmark.get_benchmark,
6466
run_opts={"fa_kernel_mode": "fwd"},
6567
categories={BenchmarkCategory.EXPERIMENTAL, BenchmarkCategory.FLASH_ATTENTION},
6668
),
69+
BenchmarkConfig(
70+
key="flash_attention_bwd",
71+
get_benchmark=flash_attention_benchmark.get_benchmark,
72+
run_opts={"fa_kernel_mode": "bwd"},
73+
categories={BenchmarkCategory.OPTIONAL, BenchmarkCategory.FLASH_ATTENTION},
74+
description="FlashAttention backward kernel benchmark",
75+
),
76+
BenchmarkConfig(
77+
key="prefix-sums",
78+
get_benchmark=prefix_sums.get_benchmark,
79+
run_opts={},
80+
categories={BenchmarkCategory.OPTIONAL, BenchmarkCategory.PREFIX_SUMS},
81+
description="Prefix Sums kernel benchmark",
82+
),
83+
# FIXME: add optional - splitK, streamk, gemm with pre-op or postops, microbenchmarks
84+
# FIXME: Experimental - FlexAttention
6785
]

0 commit comments

Comments
 (0)