Skip to content

Commit 80bbaf8

Browse files
authored
[TRITON-BENCHMARKS] Add shape filter support (#4147)
1 parent db6adfb commit 80bbaf8

File tree

11 files changed

+462
-124
lines changed

11 files changed

+462
-124
lines changed

benchmarks/setup.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,11 +122,17 @@ def get_git_commit_hash(length=8):
122122
"matplotlib",
123123
],
124124
package_dir={"triton_kernels_benchmark": "triton_kernels_benchmark"},
125-
package_data={"triton_kernels_benchmark": ["xetla_kernel.cpython-*.so"]},
125+
package_data={"triton_kernels_benchmark": [
126+
"xetla_kernel.cpython-*.so",
127+
"cutlass_kernel.cpython-*.so",
128+
]},
126129
cmdclass={
127130
"build_ext": build_ext,
128131
},
129-
ext_modules=[CMakeExtension("triton_kernels_benchmark.xetla_kernel")],
132+
ext_modules=[
133+
CMakeExtension("triton_kernels_benchmark.xetla_kernel"),
134+
CMakeExtension("triton_kernels_benchmark.cutlass_kernel"),
135+
],
130136
entry_points={
131137
"console_scripts": [
132138
"triton-benchmarks = triton_kernels_benchmark.benchmark_utils:main",

benchmarks/tests/test_entry_point.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,10 @@
33
import pytest
44

55
from triton_kernels_benchmark.benchmark_testing import MarkArgs
6-
from triton_kernels_benchmark.becnhmark_config_templates import CONFIGS
76
from triton_kernels_benchmark.benchmark_utils import BenchmarkCategory, BenchmarkConfigs
87

98
ALL_CATEGORIES = {cat.value for cat in BenchmarkCategory}
10-
ALL_CONFIGS = {config.key: config for config in CONFIGS}
9+
ALL_CONFIGS = set(BenchmarkConfigs._get_all_configs().keys()) # pylint: disable=W0212
1110

1211

1312
@pytest.mark.parametrize(
@@ -22,7 +21,7 @@
2221
"providers_count",
2322
),
2423
(
25-
[True, ALL_CONFIGS, True, ALL_CATEGORIES, [], None, lambda x: x > 1, lambda x: x > 1],
24+
[True, ALL_CONFIGS, True, ALL_CATEGORIES, [], None, lambda x: x > 1, lambda x: x >= 1],
2625
[True, {"softmax", "gemm"}, True, ALL_CATEGORIES, [], None, lambda x: x > 1, lambda x: x > 1],
2726
[True, {"softmax", "gemm"}, True, {"core", "gemm", "softmax"}, [], None, lambda x: x > 1, lambda x: x > 1],
2827
[False, {"softmax"}, False, {"optional"}, ["triton"], AssertionError, None, None],
@@ -49,6 +48,7 @@ def benchmark_configs():
4948
select_all=select_all,
5049
categories_filter=categories_filter,
5150
providers_filter=providers_filter,
51+
shape_pattern=None,
5252
json_output=False,
5353
detailed_output=False,
5454
tag="",

benchmarks/tests/test_mocks.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
from typing import Optional
1+
from typing import List, Optional
22

33
import re
4-
54
import io
65

76
import pytest
@@ -28,34 +27,52 @@
2827
}
2928

3029

30+
@pytest.fixture(autouse=True)
31+
def capture_hw_details_env(monkeypatch):
32+
monkeypatch.setenv("GPU_DEVICE", "Intel(R) Data Center GPU Max 1100")
33+
yield
34+
35+
3136
@pytest.mark.parametrize("command", ["run"])
32-
@pytest.mark.parametrize("benchmark", ["softmax"])
37+
@pytest.mark.parametrize(
38+
"benchmark",
39+
[
40+
["softmax"],
41+
["softmax", "--shape-pattern", "[256]"],
42+
["softmax", "--shape-pattern", "[*]"],
43+
],
44+
)
3345
@pytest.mark.parametrize("provider", ["triton", None])
34-
@pytest.mark.parametrize("n_runs", [None, 1, 2])
46+
@pytest.mark.parametrize("n_runs", [1, 2, 3])
3547
@pytest.mark.parametrize("show_details", [False, True])
3648
@pytest.mark.parametrize("json_output", [False, True])
49+
@pytest.mark.parametrize("reports", [False, True])
3750
def test_benchmark_run_monkeypatched(
3851
command: str,
39-
benchmark: str,
52+
benchmark: List[str],
4053
provider: Optional[str],
41-
n_runs: Optional[int],
54+
n_runs: int,
4255
show_details: bool,
4356
json_output: bool,
57+
reports: bool,
4458
capsys,
59+
tmp_path,
4560
):
46-
args = [command, benchmark]
61+
args = [command] + benchmark
4762
if provider:
4863
args.extend(["--provider", provider])
49-
if n_runs:
64+
if n_runs > 1:
5065
args.extend(["--n_runs", str(n_runs)])
5166
if show_details:
5267
args.extend(["--show-details"])
5368
if json_output:
5469
args.extend(["--json"])
70+
if reports:
71+
args.extend(["--reports", str(tmp_path)])
5572

5673
configs = BenchmarkConfigs.from_args(args)
5774
for config in configs.configs:
58-
config.res_df = pd.read_csv(io.StringIO(PERFORMANCE_CSVS[config.key]))
75+
config.res_df_list = [pd.read_csv(io.StringIO(PERFORMANCE_CSVS[config.key]))] * n_runs
5976
configs.run()
6077

6178
captured_output = capsys.readouterr().out
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
import pytest
2+
3+
from triton_kernels_benchmark.benchmark_utils import BenchmarkCategory, BenchmarkConfigs
4+
from triton_kernels_benchmark.benchmark_shapes_parser import ShapePatternParser
5+
6+
CONFIG_SHAPES = [
7+
shape for config in BenchmarkConfigs._get_configs( # pylint: disable=W0212
8+
configs_filter=BenchmarkConfigs._get_all_configs().keys(), # pylint: disable=W0212
9+
categories_filter=[cat.value for cat in BenchmarkCategory],
10+
providers_filter=[],
11+
shape_pattern=None,
12+
) for shape in config.shapes
13+
]
14+
15+
16+
@pytest.mark.parametrize(
17+
"input_string, expected",
18+
[
19+
("[1-2-3]", [1, 2, 3]),
20+
("[16-32-1024-64]", [16, 32, 1024, 64]),
21+
("[True-False-123]", ["True", "False", 123]),
22+
],
23+
)
24+
def test_parse_valid_shapes(input_string, expected):
25+
assert ShapePatternParser.parse(input_string) == expected
26+
27+
28+
@pytest.mark.parametrize("known_shape", CONFIG_SHAPES)
29+
def test_parse_all_known_shapes(known_shape):
30+
ShapePatternParser.parse(known_shape)
31+
32+
33+
@pytest.mark.parametrize(
34+
"pattern, tokens",
35+
[
36+
("[16-*-1024-*]", [16, "*", 1024, "*"]),
37+
],
38+
)
39+
def test_parse_valid_pattern_shapes(pattern, tokens):
40+
parsed_tokens = ShapePatternParser.parse(pattern, pattern_shape=True)
41+
assert tokens == parsed_tokens
42+
43+
44+
def test_parse_bracketed_star_without_pattern_shape_fails():
45+
with pytest.raises(ValueError):
46+
ShapePatternParser.parse("[*]")
47+
48+
49+
@pytest.mark.parametrize(
50+
"invalid_string",
51+
[
52+
"1-2-3",
53+
"[]",
54+
"[ ]",
55+
"[ - ]",
56+
"[-]",
57+
"[--]",
58+
"[- -]",
59+
"[1-2-@-4]",
60+
],
61+
)
62+
def test_parse_bracketed_invalid(invalid_string):
63+
with pytest.raises(ValueError):
64+
ShapePatternParser.parse(invalid_string, pattern_shape=True)
65+
with pytest.raises(ValueError):
66+
ShapePatternParser(invalid_string)
67+
68+
69+
@pytest.mark.parametrize(
70+
"pattern",
71+
[
72+
"[1-2-3]",
73+
"[a-b-c]",
74+
"[*-b-3]",
75+
],
76+
)
77+
def test_init_valid(pattern):
78+
parser = ShapePatternParser(pattern)
79+
assert parser.pattern == pattern
80+
assert len(parser.pattern_tokens) == parser.pattern_dims
81+
82+
83+
@pytest.fixture
84+
def sample_shapes():
85+
return [
86+
"[16-32-1024-64-False-bwd]",
87+
"[16-64-1024-32-True-fwd]",
88+
"[32-32-512-64-False-bwd]",
89+
"[16-32-1024-64-False-fwd]",
90+
]
91+
92+
93+
def test_filter_exact_match(sample_shapes): # pylint: disable=W0621
94+
parser = ShapePatternParser("[16-32-1024-64-False-bwd]")
95+
assert parser.filter_by_pattern(sample_shapes) == ["[16-32-1024-64-False-bwd]"]
96+
97+
98+
def test_filter_wildcard_match(sample_shapes): # pylint: disable=W0621
99+
parser = ShapePatternParser("[16-*-1024-64-False-*]")
100+
expected = [
101+
"[16-32-1024-64-False-bwd]",
102+
"[16-32-1024-64-False-fwd]",
103+
]
104+
assert parser.filter_by_pattern(sample_shapes) == expected
105+
106+
107+
@pytest.mark.parametrize(
108+
"pattern, shapes",
109+
[
110+
("[*]", ["bad-shape", "[foo]", "[123]"]),
111+
],
112+
)
113+
def test_filter_invalid_shape_strings(pattern, shapes):
114+
parser = ShapePatternParser(pattern)
115+
with pytest.raises(ValueError):
116+
parser.filter_by_pattern(shapes)
117+
118+
119+
@pytest.mark.parametrize(
120+
"pattern, shapes",
121+
[
122+
("[1-2-3]", ["[1-2]"]),
123+
],
124+
)
125+
def test_filter_dims_mismatch(pattern, shapes):
126+
parser = ShapePatternParser(pattern)
127+
with pytest.raises(ValueError) as exc_info:
128+
parser.filter_by_pattern(shapes)
129+
assert "mismatch" in str(exc_info.value)
130+
131+
132+
@pytest.mark.parametrize(
133+
"pattern, shapes",
134+
[
135+
("[*]", ["[foo]", "[bar]", "[123]"]),
136+
],
137+
)
138+
def test_filter_single_wildcard_matches_all(pattern, shapes):
139+
parser = ShapePatternParser(pattern)
140+
assert parser.filter_by_pattern(shapes) == shapes

benchmarks/triton_kernels_benchmark/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
BENCHMARKING_METHOD,
1212
)
1313

14+
from .benchmark_shapes_parser import ShapePatternParser
15+
1416
if BENCHMARKING_METHOD == "UPSTREAM_PYTORCH_PROFILER":
1517
os.environ["INJECT_PYTORCH"] = "True"
1618

@@ -23,4 +25,5 @@
2325
"BenchmarkCategory",
2426
"BenchmarkConfig",
2527
"BENCHMARKING_METHOD",
28+
"ShapePatternParser",
2629
]
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
from typing import List, Union
2+
from dataclasses import dataclass, field
3+
4+
import re
5+
6+
7+
@dataclass
8+
class ShapePatternParser:
9+
pattern: str
10+
pattern_tokens: List[str] = field(init=False)
11+
pattern_dims: int = field(init=False)
12+
13+
def __post_init__(self):
14+
self.pattern_tokens = self.parse(self.pattern, pattern_shape=True)
15+
self.pattern_dims = len(self.pattern_tokens)
16+
17+
def __str__(self):
18+
return self.pattern
19+
20+
@staticmethod
21+
def parse(shape_string: str, pattern_shape: bool = False) -> List[Union[int, str]]:
22+
pattern_match = re.fullmatch(r"\[(.*)\]", shape_string)
23+
if not pattern_match:
24+
raise ValueError(
25+
f"Invalid format: {shape_string!r}, only patterns similar to [16-*-1024-*-bwd] are supported", )
26+
inner_string = pattern_match.group(1)
27+
if not inner_string:
28+
raise ValueError(f"Empty shape - {inner_string}")
29+
tokens = inner_string.split("-")
30+
if not any(tokens):
31+
raise ValueError(f"Empty shape - {inner_string}")
32+
result: List[Union[int, str]] = []
33+
for token in tokens:
34+
try:
35+
result.append(int(token))
36+
except ValueError:
37+
if token.isalnum() or token == "*" and pattern_shape:
38+
result.append(token)
39+
else:
40+
raise ValueError( # pylint: disable=W0707
41+
f"Unsupported shape or shape pattern {shape_string}"
42+
"Each shape element could be either int, alphanumeric string or '*' in shape pattern")
43+
return result
44+
45+
def matches_pattern(self, shape_string: str) -> bool:
46+
tokens = self.parse(shape_string)
47+
shape_dims = len(tokens)
48+
if shape_dims != self.pattern_dims:
49+
raise ValueError(f"Input shape dims {shape_dims} and pattern shape dims {self.pattern_dims} mismatch")
50+
for pattern_token, token in zip(self.pattern_tokens, tokens):
51+
if pattern_token == "*":
52+
continue
53+
if pattern_token != token:
54+
break
55+
else:
56+
return True
57+
return False
58+
59+
def __call__(self, shape_string: str) -> bool:
60+
return self.matches_pattern(shape_string)
61+
62+
def filter_by_pattern(self, shape_strings: List[str]) -> List[str]:
63+
return [shape_string for shape_string in shape_strings if self.matches_pattern(shape_string)]

0 commit comments

Comments
 (0)