Commit f695efc

gemmbench: Add support for dynamic dims, use for llama8b_prefill (#80)
Dynamic dims seem to be appropriate for llama8b_prefill based on samples from https://github.com/nod-ai/playbook/issues/63
1 parent 1ca0654 commit f695efc

4 files changed: +196 additions, -14 deletions

iree_kernel_benchmark/gemmbench/__main__.py

Lines changed: 2 additions & 0 deletions
@@ -252,6 +252,7 @@ def compile_gemm(
         config.operand_element_type,
         config.tA,
         config.tB,
+        f"D={config.runtime_dim}" if config.runtime_dim is not None else "",
         round(benchmark_gemm_mean_time_us, 4),
         round(arithmetic_intensity, 4),
         round(tflops_per_second, 4),
@@ -271,6 +272,7 @@ def compile_gemm(
         "dtype",
         "tA",
         "tB",
+        "runtime_dim",
         "mean_microseconds",
         "arithmetic_intensity",
         "tflops",

iree_kernel_benchmark/gemmbench/gemm_utils.py

Lines changed: 62 additions & 11 deletions
@@ -1,6 +1,6 @@
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Optional
+from typing import Optional, Tuple
 
 try:
     import iree.turbine.kernel as tk
@@ -24,6 +24,8 @@
 from iree.compiler import ir
 from iree.compiler.dialects import arith, func, linalg, tensor
 
+kDynamic = ir.ShapedType.get_dynamic_size()
+
 
 def num_bytes(dtype: str) -> int:
     dtype_to_bytes = {
@@ -42,6 +44,7 @@ def num_bytes(dtype: str) -> int:
 
 @dataclass
 class GemmConfig:
+    # Note that M, N and K may be set to kDynamic, a special value
     M: int
     N: int
     K: int
@@ -50,37 +53,62 @@ class GemmConfig:
     operand_element_type: str
     accumulator_element_type: str
     result_element_type: str
+    # runtime_dim substitutes for any dynamic dims when executing.
+    # TODO: It would be better if we could execute the same compiled dynamic
+    # kernel for a series of different sizes, rather than duplicating the
+    # GemmConfig. The current design's advantage is that no changes have
+    # to be made to the execution logic (looks just like a static shape).
+    runtime_dim: Optional[int] = None
 
     def get_name(self) -> str:
-        name = f"gemm_{self.M}_{self.N}_{self.K}_{self.operand_element_type}_{self.accumulator_element_type}"
+        M = self.M if self.M != kDynamic else "D"
+        N = self.N if self.N != kDynamic else "D"
+        K = self.K if self.K != kDynamic else "D"
+        name = f"gemm_{M}_{N}_{K}_{self.operand_element_type}_{self.accumulator_element_type}"
         if self.tA == "T":
             name += "_tA"
         elif self.tB == "T":
             name += "_tB"
+        if self.runtime_dim is not None:
+            name += f"_D={self.runtime_dim}"
         return name
 
+    def get_runtime_dims(self) -> Tuple[int, int, int]:
+        """
+        Get concrete dims to use when executing this kernel.
+        """
+        M = self.M if self.M != kDynamic else self.runtime_dim
+        N = self.N if self.N != kDynamic else self.runtime_dim
+        K = self.K if self.K != kDynamic else self.runtime_dim
+        return M, N, K
+
     def get_inp1(self) -> str:
+        M, N, K = self.get_runtime_dims()
         if self.tA == "T":
-            return f"{self.K}x{self.M}x{self.operand_element_type}"
-        return f"{self.M}x{self.K}x{self.operand_element_type}"
+            return f"{K}x{M}x{self.operand_element_type}"
+        return f"{M}x{K}x{self.operand_element_type}"
 
     def get_inp2(self) -> str:
+        M, N, K = self.get_runtime_dims()
         if self.tB == "T":
-            return f"{self.N}x{self.K}x{self.operand_element_type}"
-        return f"{self.K}x{self.N}x{self.operand_element_type}"
+            return f"{N}x{K}x{self.operand_element_type}"
+        return f"{K}x{N}x{self.operand_element_type}"
 
     def get_out(self) -> str:
-        return f"{self.M}x{self.N}x{self.result_element_type}"
+        M, N, K = self.get_runtime_dims()
+        return f"{M}x{N}x{self.result_element_type}"
 
     def get_byte_count(self) -> int:
         operand_bytes_per_element = num_bytes(self.operand_element_type)
         result_bytes_per_element = num_bytes(self.result_element_type)
-        byte_count_input = (self.M + self.N) * self.K * operand_bytes_per_element
-        byte_count_output = (self.M * self.N) * result_bytes_per_element
+        M, N, K = self.get_runtime_dims()
+        byte_count_input = (M + N) * K * operand_bytes_per_element
+        byte_count_output = (M * N) * result_bytes_per_element
         return byte_count_input + byte_count_output
 
     def get_flops(self) -> int:
-        flops = 2 * self.M * self.N * self.K
+        M, N, K = self.get_runtime_dims()
+        flops = 2 * M * N * K
         return flops
 
 
@@ -123,16 +151,22 @@ def generate_mlir(config: GemmConfig):
     # Transpose A
     if tA == "T":
         arg0_type = ir.RankedTensorType.get([K, M], operand_element_type)
+        arg0_M_idx = 1
         arg1_type = ir.RankedTensorType.get([K, N], operand_element_type)
+        arg1_N_idx = 1
    # Transpose B
    elif tB == "T":
         arg0_type = ir.RankedTensorType.get([M, K], operand_element_type)
+        arg0_M_idx = 0
         arg1_type = ir.RankedTensorType.get([N, K], operand_element_type)
+        arg1_N_idx = 0
     # "Normal" path (can't transpose both)
     else:
         assert tA == "N" and tB == "N"
         arg0_type = ir.RankedTensorType.get([M, K], operand_element_type)
+        arg0_M_idx = 0
         arg1_type = ir.RankedTensorType.get([K, N], operand_element_type)
+        arg1_N_idx = 1
     result_type = ir.RankedTensorType.get([M, N], result_element_type)
 
     module = ir.Module.create()
@@ -143,7 +177,24 @@ def main(arg0, arg1):
             zero_element = arith.constant(
                 value=literal_zero, result=acc_element_type
             )
-            empty_tensor = tensor.empty(element_type=acc_element_type, sizes=[M, N])
+            if M == kDynamic:
+                M_dynamic_dim_idx = arith.constant(
+                    value=arg0_M_idx, result=ir.IndexType.get()
+                )
+                M_dynamic_dim = tensor.dim(arg0, M_dynamic_dim_idx)
+            if N == kDynamic:
+                N_dynamic_dim_idx = arith.constant(
+                    value=arg1_N_idx, result=ir.IndexType.get()
+                )
+                N_dynamic_dim = tensor.dim(arg1, N_dynamic_dim_idx)
+
+            empty_tensor = tensor.empty(
+                element_type=acc_element_type,
+                sizes=[
+                    M_dynamic_dim if M == kDynamic else M,
+                    N_dynamic_dim if N == kDynamic else N,
+                ],
+            )
             filled_tensor = linalg.fill(zero_element, outs=[empty_tensor])
 
             if tA == "T":
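
For orientation (not part of the diff above), here is a minimal sketch of how the new fields and helpers interact for a dynamic-M config; the shape values are borrowed from the llama8b_prefill test added below, and the import path assumes the module layout shown in this commit:

    from iree_kernel_benchmark.gemmbench.gemm_utils import GemmConfig, kDynamic

    # M is left dynamic at compile time; runtime_dim supplies the concrete value
    # that get_runtime_dims() substitutes when the kernel is executed.
    cfg = GemmConfig(
        M=kDynamic,
        N=14336,
        K=4096,
        tA="N",
        tB="T",
        operand_element_type="f16",
        accumulator_element_type="f32",
        result_element_type="f16",
        runtime_dim=512,
    )

    print(cfg.get_name())          # gemm_D_14336_4096_f16_f32_tB_D=512
    print(cfg.get_runtime_dims())  # (512, 14336, 4096)
    print(cfg.get_inp1())          # 512x4096xf16 (M x K, since tA == "N")
    print(cfg.get_inp2())          # 14336x4096xf16 (N x K, since tB == "T")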

iree_kernel_benchmark/gemmbench/problems.py

Lines changed: 3 additions & 2 deletions
@@ -4,7 +4,7 @@
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-from .gemm_utils import GemmConfig, num_bytes
+from .gemm_utils import GemmConfig, num_bytes, kDynamic
 
 import re
 
@@ -704,14 +704,15 @@ def llama8b_prefill(dtype: str, raw_accumulators: bool) -> list[GemmConfig]:
     if model == "8b_prefill":
         configs.append(
             GemmConfig(
-                m,
+                kDynamic,
                 n,
                 k,
                 "N",
                 "T",
                 dtype,
                 get_default_accumulator_element_type(dtype),
                 get_default_result_element_type(dtype, raw_accumulators),
+                runtime_dim=m,
             )
         )
     return configs
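
To make the effect of this change concrete, a sketch (not part of the commit; the m value below is an illustrative placeholder, while n and k match the 8b_prefill shapes used in the tests): each sampled triple still produces its own GemmConfig, but the kernel is now compiled with a dynamic M, and the sampled m is only applied at execution time via runtime_dim.

    from iree_kernel_benchmark.gemmbench.gemm_utils import GemmConfig, kDynamic

    m, n, k = 128, 14336, 4096  # m is a placeholder sample value

    # Before this commit: M was baked into the compiled kernel.
    static_cfg = GemmConfig(m, n, k, "N", "T", "f16", "f32", "f16")
    # After: M is dynamic when compiling; m is substituted only when executing.
    dynamic_cfg = GemmConfig(kDynamic, n, k, "N", "T", "f16", "f32", "f16", runtime_dim=m)

    # Execution sees the same concrete dims either way.
    assert static_cfg.get_runtime_dims() == dynamic_cfg.get_runtime_dims() == (m, n, k)
    assert dynamic_cfg.get_name() == "gemm_D_14336_4096_f16_f32_tB_D=128"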

tests/test_gemmbench_mlir_gen.py

Lines changed: 129 additions & 1 deletion
@@ -1,4 +1,8 @@
-from iree_kernel_benchmark.gemmbench.gemm_utils import GemmConfig, generate_mlir
+from iree_kernel_benchmark.gemmbench.gemm_utils import (
+    GemmConfig,
+    generate_mlir,
+    kDynamic,
+)
 from .utils import match_lines
 from iree.compiler import ir
 import pytest
@@ -69,6 +73,130 @@ def test_n_t_f8_f32_f8():
     )
 
 
+def test_n_t_f16_f32_f16_dynamic_dim_M():
+    # From 'llama8b_prefill'
+    cfg = GemmConfig(
+        M=kDynamic,
+        N=14336,
+        K=4096,
+        tA="N",
+        tB="T",
+        operand_element_type="f16",
+        accumulator_element_type="f32",
+        result_element_type="f16",
+        runtime_dim=512,  # Unused, included for correctness
+    )
+    mlir = generate_mlir(cfg)
+    match_lines(
+        mlir,
+        [
+            "module {",
+            "func.func @main(%arg0: tensor<?x4096xf16>, %arg1: tensor<14336x4096xf16>) -> tensor<?x14336xf16> {",
+            "%cst = arith.constant 0.000000e+00 : f32",
+            "%c0 = arith.constant 0 : index",
+            "%dim = tensor.dim %arg0, %c0 : tensor<?x4096xf16>",
+            "%0 = tensor.empty(%dim) : tensor<?x14336xf32>",
+            "%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<?x14336xf32>) -> tensor<?x14336xf32>",
+            "%2 = linalg.matmul_transpose_b {cast = #linalg.type_fn<cast_signed>} ins(%arg0, %arg1 : tensor<?x4096xf16>, tensor<14336x4096xf16>) outs(%1 : tensor<?x14336xf32>) -> tensor<?x14336xf32>",
+            "%3 = arith.truncf %2 : tensor<?x14336xf32> to tensor<?x14336xf16>",
+            "return %3 : tensor<?x14336xf16>",
+        ],
+    )
+
+
+def test_t_n_f16_f32_f16_dynamic_dim_N():
+    # Synthetic example (modified from test_n_t_f16_f32_f16_dynamic_dim_M)
+    cfg = GemmConfig(
+        M=512,
+        N=kDynamic,
+        K=4096,
+        tA="T",
+        tB="N",
+        operand_element_type="f16",
+        accumulator_element_type="f32",
+        result_element_type="f16",
+        runtime_dim=14366,  # Unused, included for correctness
+    )
+    mlir = generate_mlir(cfg)
+    match_lines(
+        mlir,
+        [
+            "module {",
+            "func.func @main(%arg0: tensor<4096x512xf16>, %arg1: tensor<4096x?xf16>) -> tensor<512x?xf16> {",
+            "%cst = arith.constant 0.000000e+00 : f32",
+            "%c1 = arith.constant 1 : index",
+            "%dim = tensor.dim %arg1, %c1 : tensor<4096x?xf16>",
+            "%0 = tensor.empty(%dim) : tensor<512x?xf32>",
+            "%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<512x?xf32>) -> tensor<512x?xf32>",
+            "%2 = linalg.matmul_transpose_a {cast = #linalg.type_fn<cast_signed>} ins(%arg0, %arg1 : tensor<4096x512xf16>, tensor<4096x?xf16>) outs(%1 : tensor<512x?xf32>) -> tensor<512x?xf32>",
+            "%3 = arith.truncf %2 : tensor<512x?xf32> to tensor<512x?xf16>",
+            "return %3 : tensor<512x?xf16>",
+        ],
+    )
+
+
+def test_n_n_f16_f32_f16_dynamic_dim_K():
+    # Synthetic example (modified from test_n_t_f16_f32_f16_dynamic_dim_M)
+    cfg = GemmConfig(
+        M=512,
+        N=14366,
+        K=kDynamic,
+        tA="N",
+        tB="N",
+        operand_element_type="f16",
+        accumulator_element_type="f32",
+        result_element_type="f16",
+        runtime_dim=4096,  # Unused, included for correctness
+    )
+    mlir = generate_mlir(cfg)
+    match_lines(
+        mlir,
+        [
+            "module {",
+            "func.func @main(%arg0: tensor<512x?xf16>, %arg1: tensor<?x14366xf16>) -> tensor<512x14366xf16> {",
+            "%cst = arith.constant 0.000000e+00 : f32",
+            "%0 = tensor.empty() : tensor<512x14366xf32>",
+            "%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<512x14366xf32>) -> tensor<512x14366xf32>",
+            "%2 = linalg.matmul ins(%arg0, %arg1 : tensor<512x?xf16>, tensor<?x14366xf16>) outs(%1 : tensor<512x14366xf32>) -> tensor<512x14366xf32>",
+            "%3 = arith.truncf %2 : tensor<512x14366xf32> to tensor<512x14366xf16>",
+            "return %3 : tensor<512x14366xf16>",
+        ],
+    )
+
+
+def test_n_n_f16_f32_f16_dynamic_dim_M_N():
+    # Synthetic example (modified from test_n_t_f16_f32_f16_dynamic_dim_M)
+    cfg = GemmConfig(
+        M=kDynamic,
+        N=kDynamic,
+        K=4096,
+        tA="N",
+        tB="N",
+        operand_element_type="f16",
+        accumulator_element_type="f32",
+        result_element_type="f16",
+        runtime_dim=512,  # Unused, included for correctness
+    )
+    mlir = generate_mlir(cfg)
+    match_lines(
+        mlir,
+        [
+            "module {",
+            "func.func @main(%arg0: tensor<?x4096xf16>, %arg1: tensor<4096x?xf16>) -> tensor<?x?xf16> {",
+            "%cst = arith.constant 0.000000e+00 : f32",
+            "%c0 = arith.constant 0 : index",
+            "%dim = tensor.dim %arg0, %c0 : tensor<?x4096xf16>",
+            "%c1 = arith.constant 1 : index",
+            "%dim_0 = tensor.dim %arg1, %c1 : tensor<4096x?xf16>",
+            "%0 = tensor.empty(%dim, %dim_0) : tensor<?x?xf32>",
+            "%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<?x?xf32>) -> tensor<?x?xf32>",
+            "%2 = linalg.matmul ins(%arg0, %arg1 : tensor<?x4096xf16>, tensor<4096x?xf16>) outs(%1 : tensor<?x?xf32>) -> tensor<?x?xf32>",
+            "%3 = arith.truncf %2 : tensor<?x?xf32> to tensor<?x?xf16>",
+            "return %3 : tensor<?x?xf16>",
+        ],
+    )
+
+
 def test_n_t_bf16_f32_bf16():
     # From 'llama70bmemory'
     cfg = GemmConfig(
