Commit c31afba

gemmbench: Add support for four FP8 types and test one in CI (#78)
The added types are those supported by both MLIR and IREE that also have some AMD GPU support: f8E4M3FNUZ and f8E5M2FNUZ are supported on CDNA3, whereas f8E4M3FN and f8E5M2 are supported on RDNA4. f8E4M3FNUZ is tested in CI (on MI300).
1 parent 287c79d commit c31afba

5 files changed: 67 additions, 19 deletions

.github/workflows/run_bench.yml

Lines changed: 5 additions & 0 deletions
@@ -56,6 +56,11 @@ jobs:
           source bench_venv/bin/activate
           python -m iree_kernel_benchmark.gemmbench --dtypes f16
 
+      - name: GEMM FP8 (f8E4M3FNUZ)
+        run: |
+          source bench_venv/bin/activate
+          python -m iree_kernel_benchmark.gemmbench --dtypes f8E4M3FNUZ
+
       - name: GEMM I8
         run: |
           source bench_venv/bin/activate
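
Outside CI, the same sweep can be reproduced locally. Below is a minimal Python sketch, assuming the bench_venv environment created earlier in the workflow is already active so that sys.executable resolves to its interpreter:

import subprocess
import sys

# Mirror the new CI step: run the GEMM benchmarks for the f8E4M3FNUZ data type.
subprocess.run(
    [sys.executable, "-m", "iree_kernel_benchmark.gemmbench", "--dtypes", "f8E4M3FNUZ"],
    check=True,
)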

iree_kernel_benchmark/gemmbench/__main__.py

Lines changed: 1 addition & 1 deletion
@@ -62,7 +62,7 @@ def compile_gemm(
         "--dtypes",
         nargs="+",
         default=[],
-        help="List of data types to generate benchmarks for. Defaults to f16. Other options include f32, bf16, i8.",
+        help="List of data types to generate benchmarks for. Defaults to f16. Other options include (for example) f32, bf16, i8, f8E4M3FNUZ.",
     )
     parser.add_argument(
         "--raw_accumulators",

iree_kernel_benchmark/gemmbench/gemm_utils.py

Lines changed: 21 additions & 10 deletions
@@ -25,6 +25,21 @@
 from iree.compiler.dialects import arith, func, linalg, tensor
 
 
+def num_bytes(dtype: str) -> int:
+    dtype_to_bytes = {
+        "f32": 4,
+        "f16": 2,
+        "bf16": 2,
+        "f8E4M3FNUZ": 1,
+        "f8E5M2FNUZ": 1,
+        "f8E4M3FN": 1,
+        "f8E5M2": 1,
+        "i8": 1,
+        "i32": 4,
+    }
+    return dtype_to_bytes[dtype]
+
+
 @dataclass
 class GemmConfig:
     M: int

@@ -58,16 +73,8 @@ def get_out(self) -> str:
         return f"{self.M}x{self.N}x{self.result_element_type}"
 
     def get_byte_count(self) -> int:
-        dtype_to_bytes = {
-            "f32": 4,
-            "f16": 2,
-            "bf16": 2,
-            "f8E4M3FNUZ": 1,
-            "i8": 1,
-            "i32": 4,
-        }
-        operand_bytes_per_element = dtype_to_bytes[self.operand_element_type]
-        result_bytes_per_element = dtype_to_bytes[self.result_element_type]
+        operand_bytes_per_element = num_bytes(self.operand_element_type)
+        result_bytes_per_element = num_bytes(self.result_element_type)
         byte_count_input = (self.M + self.N) * self.K * operand_bytes_per_element
         byte_count_output = (self.M * self.N) * result_bytes_per_element
         return byte_count_input + byte_count_output
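
As a quick check of the arithmetic for the new 1-byte types, here is a standalone sketch of get_byte_count for the shape used in the new f8 test (M=512, N=4096, K=14336, with f8E4M3FNUZ operands and result; the per-element sizes come from num_bytes above):

M, N, K = 512, 4096, 14336
operand_bytes_per_element = 1  # num_bytes("f8E4M3FNUZ")
result_bytes_per_element = 1   # num_bytes("f8E4M3FNUZ")

byte_count_input = (M + N) * K * operand_bytes_per_element  # 66,060,288
byte_count_output = (M * N) * result_bytes_per_element      # 2,097,152
print(byte_count_input + byte_count_output)                 # 68,157,440
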
@@ -83,6 +90,10 @@ def _convert_dtype_to_mlir(dtype: str) -> ir.Type:
         "i16": lambda: ir.IntegerType.get_signless(16),
         "i32": lambda: ir.IntegerType.get_signless(32),
         "i64": lambda: ir.IntegerType.get_signless(64),
+        "f8E4M3FNUZ": lambda: ir.Float8E4M3FNUZType.get(),
+        "f8E5M2FNUZ": lambda: ir.Float8E5M2FNUZType.get(),
+        "f8E4M3FN": lambda: ir.Float8E4M3FNType.get(),
+        "f8E5M2": lambda: ir.Float8E5M2Type.get(),
         "f16": lambda: ir.F16Type.get(),
         "f32": lambda: ir.F32Type.get(),
         "f64": lambda: ir.F64Type.get(),

iree_kernel_benchmark/gemmbench/problems.py

Lines changed: 12 additions & 8 deletions
@@ -4,19 +4,23 @@
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-from .gemm_utils import GemmConfig
+from .gemm_utils import GemmConfig, num_bytes
 
 import re
 
 
-def num_bytes(dtype: str) -> int:
-    return {"f16": 2, "bf16": 2, "f32": 4, "i8": 1, "i32": 4}[dtype]
-
-
 def get_default_accumulator_element_type(operand_element_type: str) -> str:
-    return {"f16": "f32", "bf16": "f32", "f32": "f32", "i8": "i32", "i32": "i32"}[
-        operand_element_type
-    ]
+    return {
+        "f16": "f32",
+        "bf16": "f32",
+        "f32": "f32",
+        "f8E4M3FNUZ": "f32",
+        "f8E5M2FNUZ": "f32",
+        "f8E4M3FN": "f32",
+        "f8E5M2": "f32",
+        "i8": "i32",
+        "i32": "i32",
+    }[operand_element_type]
 
 
 def get_default_result_element_type(
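
In other words, every FP8 operand type now defaults to an f32 accumulator, matching f16 and bf16. A small sketch, assuming the package is importable as iree_kernel_benchmark.gemmbench.problems:

from iree_kernel_benchmark.gemmbench.problems import (
    get_default_accumulator_element_type,
)

# All four FP8 variants accumulate in f32 by default.
for dtype in ("f8E4M3FNUZ", "f8E5M2FNUZ", "f8E4M3FN", "f8E5M2"):
    assert get_default_accumulator_element_type(dtype) == "f32"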

tests/test_gemmbench_mlir_gen.py

Lines changed: 28 additions & 0 deletions
@@ -41,6 +41,34 @@ def test_n_t_f16_f32_f16():
     )
 
 
+def test_n_t_f8_f32_f8():
+    # From 'llama8b_prefill' (f8 version is synthetic)
+    cfg = GemmConfig(
+        M=512,
+        N=4096,
+        K=14336,
+        tA="N",
+        tB="T",
+        operand_element_type="f8E4M3FNUZ",
+        accumulator_element_type="f32",
+        result_element_type="f8E4M3FNUZ",
+    )
+    mlir = generate_mlir(cfg)
+    match_lines(
+        mlir,
+        [
+            "module {",
+            "func.func @main(%arg0: tensor<512x14336xf8E4M3FNUZ>, %arg1: tensor<4096x14336xf8E4M3FNUZ>) -> tensor<512x4096xf8E4M3FNUZ> {",
+            "%cst = arith.constant 0.000000e+00 : f32",
+            "%0 = tensor.empty() : tensor<512x4096xf32>",
+            "%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<512x4096xf32>) -> tensor<512x4096xf32>",
+            "%2 = linalg.matmul_transpose_b {cast = #linalg.type_fn<cast_signed>} ins(%arg0, %arg1 : tensor<512x14336xf8E4M3FNUZ>, tensor<4096x14336xf8E4M3FNUZ>) outs(%1 : tensor<512x4096xf32>) -> tensor<512x4096xf32>",
+            "%3 = arith.truncf %2 : tensor<512x4096xf32> to tensor<512x4096xf8E4M3FNUZ>",
+            "return %3 : tensor<512x4096xf8E4M3FNUZ>",
+        ],
+    )
+
+
 def test_n_t_bf16_f32_bf16():
     # From 'llama70bmemory'
     cfg = GemmConfig(
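
The new case can be run on its own with pytest's -k name filter. A minimal sketch, assuming pytest is installed and the working directory is the repository root (-k f8 selects any test whose name contains "f8"):

import sys

import pytest

# Run only the FP8 MLIR-generation test(s) from this file.
sys.exit(pytest.main(["-k", "f8", "tests/test_gemmbench_mlir_gen.py"]))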

0 commit comments
