Skip to content

Commit e358bbc

Browse files
authored
Add --config_regex option to gemmbench to allow executing a subset (#72)
Essentially 759cd6f but for gemmbench, also intended for debugging.
1 parent f82c28c commit e358bbc

File tree

2 files changed

+10
-0
lines changed

2 files changed

+10
-0
lines changed

iree_kernel_benchmark/gemmbench/__main__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,11 @@ def compile_gemm(
7575
help="Regular expression for allowed benchmark tags. Defaults to all tags allowed.",
7676
default=".*",
7777
)
78+
parser.add_argument(
79+
"--config_regex",
80+
help="Regular expression for allowed benchmark configurations. Defaults to all allowed.",
81+
default=".*",
82+
)
7883
parser.add_argument(
7984
"--roofline",
8085
help="Comma separated csv file list to generate roofline plot with",
@@ -130,6 +135,7 @@ def compile_gemm(
130135
requested_dtypes,
131136
requested_variants,
132137
args.tag_regex,
138+
args.config_regex,
133139
args.raw_accumulators,
134140
)
135141
print(f"Generated {len(configs)} gemm configs.")

iree_kernel_benchmark/gemmbench/problems.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1114,9 +1114,11 @@ def get_matching_configs(
11141114
dtypes: list[str],
11151115
variants: list[str],
11161116
tag_regex: str,
1117+
config_regex: str,
11171118
raw_accumulators: bool,
11181119
) -> list[tuple[str, GemmConfig]]:
11191120
tag_re = re.compile(tag_regex)
1121+
config_re = re.compile(config_regex)
11201122
matching_configs: list[tuple[str, GemmConfig]] = []
11211123
for tag, config in tagged_configs:
11221124
if config.operand_element_type not in dtypes:
@@ -1125,6 +1127,8 @@ def get_matching_configs(
11251127
continue
11261128
if not tag_re.match(tag):
11271129
continue
1130+
if not config_re.match(config.get_name()):
1131+
continue
11281132
# The raw_accumulators arg means "test configs where the result element
11291133
# type is different from what it would be in the default mode".
11301134
# We can't just test for (result_element_type == accumulator_element_type),

0 commit comments

Comments
 (0)