Skip to content

Commit af062d3

Browse files
authored
gemmbench: Generate benchmarks from supplied dtypes instead of filtering (#76)
Prior to this commit, gemmbench's problems.py hardcoded a set of datatypes for each benchmark, generating all possibilities ahead of time. The --dtypes and --raw_accumulators command-line arguments were then used to filter the resulting set of benchmarks. This commit refactors problems.py to have less redundancy and to not hardcode any datatypes. Now only the shapes are hardcoded, and the --dtypes and --raw_accumulators command-line arguments control which datatypes are used in the result. One consequence of this is that now we get the same set of shapes no matter which datatype is requested, which provides more thorough testing for types like i8 that previously only had a very small number of hardcoded shapes. To avoid an explosion in the number of tests, the default set of datatypes is changed to just f16. CI is changed to run f16 and i8 (in separate job steps for better visibility); note that bf16 CI coverage is removed. Also, the padded LLaMA shapes are removed.
1 parent 69fe7f5 commit af062d3

File tree

4 files changed

+187
-311
lines changed

4 files changed

+187
-311
lines changed

.github/workflows/run_bench.yml

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,15 +46,20 @@ jobs:
4646
source bench_venv/bin/activate
4747
python -m iree_kernel_benchmark.attentionbench
4848
49-
- name: TK GEMM
49+
- name: TK GEMM FP16
5050
run: |
5151
source bench_venv/bin/activate
52-
python -m iree_kernel_benchmark.gemmbench --tk
52+
python -m iree_kernel_benchmark.gemmbench --tk --dtypes f16
5353
54-
- name: GEMM
54+
- name: GEMM FP16
5555
run: |
5656
source bench_venv/bin/activate
57-
python -m iree_kernel_benchmark.gemmbench
57+
python -m iree_kernel_benchmark.gemmbench --dtypes f16
58+
59+
- name: GEMM I8
60+
run: |
61+
source bench_venv/bin/activate
62+
python -m iree_kernel_benchmark.gemmbench --dtypes i8
5863
5964
- name: Roofline Plots
6065
run: |

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,16 @@ python -m iree_kernel_benchmark.convbench --tk
6060
python -m iree_kernel_benchmark.gemmbench
6161
```
6262

63+
By default, this generates only FP16 benchmarks. To benchmark a different set of types, pass `--dtypes`, e.g. `--dtypes i8 bf16`.
64+
6365
### TK GEMM Benchmarking
6466

6567
```
6668
python -m iree_kernel_benchmark.gemmbench --tk
6769
```
6870

71+
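As with the regular GEMM benchmarks, only FP16 benchmarks are generated by default; pass `--dtypes` to select a different set of types.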
72+
6973
### Attention Benchmarking
7074

7175
```

iree_kernel_benchmark/gemmbench/__main__.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,18 @@ def compile_gemm(
6262
"--dtypes",
6363
nargs="+",
6464
default=[],
65-
help="List of data types to benchmark. Defaults to all supported types.",
65+
help="List of data types to generate benchmarks for. Defaults to f16. Other options include f32, bf16, i8.",
66+
)
67+
parser.add_argument(
68+
"--raw_accumulators",
69+
action="store_true",
70+
help="If true, generate benchmark matmuls returning the raw accumulator type with no truncation. If false (default), generate benchmark matmuls where results are truncated and cast to the input element type.",
6671
)
6772
parser.add_argument(
6873
"--variants",
6974
nargs="+",
7075
default=[],
71-
help="List of matmul variants to benchmark. Default to all variants: NN, NT, TN, and TT.",
76+
help="List of matmul variants to filter benchmarks by. Default to all variants: NN, NT, TN, and TT.",
7277
)
7378
parser.add_argument(
7479
"--tag_regex",
@@ -102,15 +107,10 @@ def compile_gemm(
102107
default=None,
103108
help="Directory to which executable files will be dumped.",
104109
)
105-
parser.add_argument(
106-
"--raw_accumulators",
107-
action="store_true",
108-
help="If true, benchmark matmuls returning the raw accumulator type with no truncation. If false (default), the results are truncated and cast to the input element type.",
109-
)
110110

111111
args = parser.parse_args()
112112
# Handle default values here, since list args are not compatible with defaulted lists.
113-
requested_dtypes = ["f16", "bf16", "i8"] if not args.dtypes else list(args.dtypes)
113+
requested_dtypes = ["f16"] if not args.dtypes else list(args.dtypes)
114114
requested_variants = (
115115
["NN", "NT", "TN", "TT"] if not args.variants else list(args.variants)
116116
)
@@ -129,14 +129,18 @@ def compile_gemm(
129129
sys.exit()
130130

131131
tk = args.tk
132-
configs = get_tk_gemm_configs() if tk else get_gemm_configs()
132+
configs = []
133+
for dtype in requested_dtypes:
134+
configs += (
135+
get_tk_gemm_configs(dtype, args.raw_accumulators)
136+
if tk
137+
else get_gemm_configs(dtype, args.raw_accumulators)
138+
)
133139
configs = get_matching_configs(
134140
configs,
135-
requested_dtypes,
136141
requested_variants,
137142
args.tag_regex,
138143
args.config_regex,
139-
args.raw_accumulators,
140144
)
141145
print(f"Generated {len(configs)} gemm configs.")
142146

0 commit comments

Comments
 (0)