You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
gemmbench: Generate benchmarks from supplied dtypes instead of filtering (#76)
Prior to this commit, gemmbench's problems.py hardcoded a set of
datatypes for each benchmark, generating all possibilities ahead of
time. The --dtypes and --raw_accumulators command-line arguments were
then used to filter the resulting set of benchmarks.
This commit refactors problems.py to have less redundancy and to not
hardcode any datatypes. Now only the shapes are hardcoded, and the
--dtypes and --raw_accumulators command-line arguments control which
datatypes are used in the result. One consequence of this is that now we
get the same set of shapes no matter which datatype is requested, which
provides more thorough testing for types like i8 that previously only
had a very small number of hardcoded shapes.
To avoid an explosion in the amount of tests, the default set of
datatypes is changed to just f16. CI is changed to run f16 and i8 (in
separate job steps for better visibility); note that bf16 CI coverage is
removed.
Also, the padded LLaMA shapes are removed.
Copy file name to clipboardExpand all lines: iree_kernel_benchmark/gemmbench/__main__.py
+15-11Lines changed: 15 additions & 11 deletions
Original file line number
Diff line number
Diff line change
@@ -62,13 +62,18 @@ def compile_gemm(
62
62
"--dtypes",
63
63
nargs="+",
64
64
default=[],
65
-
help="List of data types to benchmark. Defaults to all supported types.",
65
+
help="List of data types to generate benchmarks for. Defaults to f16. Other options include f32, bf16, i8.",
66
+
)
67
+
parser.add_argument(
68
+
"--raw_accumulators",
69
+
action="store_true",
70
+
help="If true, generate benchmark matmuls returning the raw accumulator type with no truncation. If false (default), generate benchmark matmuls where results are truncated and cast to the input element type.",
66
71
)
67
72
parser.add_argument(
68
73
"--variants",
69
74
nargs="+",
70
75
default=[],
71
-
help="List of matmul variants to benchmark. Default to all variants: NN, NT, TN, and TT.",
76
+
help="List of matmul variants to filter benchmarks by. Default to all variants: NN, NT, TN, and TT.",
72
77
)
73
78
parser.add_argument(
74
79
"--tag_regex",
@@ -102,15 +107,10 @@ def compile_gemm(
102
107
default=None,
103
108
help="Directory to which executable files will be dumped.",
104
109
)
105
-
parser.add_argument(
106
-
"--raw_accumulators",
107
-
action="store_true",
108
-
help="If true, benchmark matmuls returning the raw accumulator type with no truncation. If false (default), the results are truncated and cast to the input element type.",
109
-
)
110
110
111
111
args=parser.parse_args()
112
112
# Handle default values here, since list args are not compatible with defaulted lists.
0 commit comments