
Commit 19c832f

Add TK Wave kernels to conv benchmark (#35)
* Add an option to test TKW-based conv kernels in convbench.
* Only a limited subset of datatypes is supported for now (only `f16xf16xf32`).
* Needs the latest iree-turbine main.

Signed-off-by: Ivan Butygin <[email protected]>
1 parent c3bdf8e commit 19c832f

File tree

4 files changed (+146, -7 lines changed)

* .github/workflows/run_bench.yml
* convbench/conv_bench.py
* convbench/conv_utils.py
* convbench/wave_conv_utils.py


.github/workflows/run_bench.yml
Lines changed: 9 additions & 2 deletions

@@ -37,11 +37,16 @@ jobs:
           source bench_venv/bin/activate
           python convbench/conv_bench.py
 
+      - name: TK Convolutions
+        run: |
+          source bench_venv/bin/activate
+          python convbench/conv_bench.py --tk
+
       - name: Attention
         run: |
           source bench_venv/bin/activate
           python attentionbench/attention_bench.py
-
+
       - name: TK GEMM
         run: |
           source bench_venv/bin/activate
@@ -57,11 +62,13 @@
           source bench_venv/bin/activate
           python convbench/conv_bench.py --roofline results/iree_conv.csv --plot results/iree_conv_i8.png --dtype i8
           python convbench/conv_bench.py --roofline results/iree_conv.csv --plot results/iree_conv_f16.png --dtype f16
+          python convbench/conv_bench.py --roofline results/iree_conv_tk.csv --plot results/iree_conv_tk_i8.png --dtype i8
+          python convbench/conv_bench.py --roofline results/iree_conv_tk.csv --plot results/iree_conv_tk_f16.png --dtype f16
           python convbench/conv_bench.py --roofline results/iree_attention.csv --plot results/iree_attention_fp16.png --dtype f16
           python convbench/conv_bench.py --roofline results/iree_attention.csv --plot results/iree_attention_fp8.png --dtype f8E4M3FNUZ
           python convbench/conv_bench.py --roofline results/iree_gemm.csv --plot results/iree_gemm.png
           python convbench/conv_bench.py --roofline results/iree_gemm_tk.csv --plot results/iree_gemm_tk.png
-          python convbench/conv_bench.py --roofline results/iree_gemm.csv,results/iree_gemm_tk.csv,results/iree_attention.csv,results/iree_conv.csv --plot results/combined.png
+          python convbench/conv_bench.py --roofline results/iree_gemm.csv,results/iree_gemm_tk.csv,results/iree_attention.csv,results/iree_conv.csv,results/iree_conv_tk.csv --plot results/combined.png
 
       - name: Upload benchmark results
         uses: actions/upload-artifact@v4

convbench/conv_bench.py
Lines changed: 15 additions & 4 deletions

@@ -11,11 +11,15 @@
 from conv_utils import *
 from problems import get_conv_configs, get_conv_test_configs
 
+from wave_conv_utils import compile_wave_conv_config
 
-def compile_conv(tag, config, kernel_dir, vmfb_dir, extra_compiler_args):
+def compile_conv_iree(tag, config, kernel_dir, vmfb_dir, extra_compiler_args):
     mlir_file, vmfb_file, dump_path = compile_conv_config(config, kernel_dir, vmfb_dir, extra_compiler_args)
     return (tag, config, mlir_file, vmfb_file, dump_path)
 
+def compile_conv_wave(tag, config, kernel_dir, vmfb_dir, extra_compiler_args):
+    mlir_file, vmfb_file, dump_path = compile_wave_conv_config(config, kernel_dir, vmfb_dir, extra_compiler_args)
+    return (tag, config, mlir_file, vmfb_file, dump_path)
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Config file updater.")
@@ -42,6 +46,7 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir, extra_compiler_args):
     parser.add_argument("--batch", help="roofline on certain batch", type=int, default=None)
     parser.add_argument("--dtype", help="roofline on certain dtype", default=None)
     parser.add_argument("--model", help="roofline on certain model", default=None)
+    parser.add_argument('--tk', help="Run conv kernels using Turbine Kernels", action=argparse.BooleanOptionalAction)
 
     args = parser.parse_args()
     logging.basicConfig(level=args.log_level)
@@ -71,6 +76,7 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir, extra_compiler_args):
     compile_args = itertools.starmap(
         lambda tag, config: (tag, config, kernel_dir, vmfb_dir, extra_compiler_args), configs
     )
+    compile_conv = compile_conv_wave if args.tk else compile_conv_iree
     with Pool(num_cpus) as pool:
         compilation_results = list(tqdm(pool.starmap(compile_conv, list(compile_args))))
 
@@ -88,7 +94,8 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir, extra_compiler_args):
 
     results = []
     index = 0
-    output_csv = "results/iree_conv.csv"
+    output_csv = "results/iree_conv_tk.csv" if args.tk else "results/iree_conv.csv"
+    entrypoint = "isolated_benchmark" if args.tk else "main"
     csv_dir = os.path.dirname(output_csv)
     if not os.path.exists(csv_dir):
         os.makedirs(csv_dir)
@@ -105,12 +112,16 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir, extra_compiler_args):
             f"--device={device}",
             "--device_allocator=caching",
             f"--module={vmfb_filename}",
-            "--function=main",
+            f"--function={entrypoint}",
+            "--benchmark_repetitions=3",
             f"--input={image_shape}",
             f"--input={filter_shape}",
-            "--benchmark_repetitions=3",
         ]
 
+        if args.tk:
+            out_shape = config.get_out_shape()
+            exec_args.append(f"--input={out_shape}")
+
         print(f"Running {vmfb_filename}...")
         # iree benchmark kernels
         ret_value, cmd_out, cmd_stderr = run_iree_command(exec_args)
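
For context on the new `--tk` path: the benchmark switches the entrypoint from `main` to `isolated_benchmark` and appends a third `--input` carrying the output shape. Below is a minimal sketch of the argument list that gets built, assuming hypothetical shapes, module path, and device; none of these concrete values come from this commit, and the leading tool name is an assumption (the diff only shows the flags passed to run_iree_command).

# Sketch only: every concrete value here is a placeholder.
device = "hip"
vmfb_filename = "results/conv/example.vmfb"   # hypothetical module path
image_shape = "2x66x66x640xf16"               # hypothetical input tensor spec
filter_shape = "3x3x640x640xf16"              # hypothetical filter tensor spec
out_shape = "2x32x32x640xf32"                 # what config.get_out_shape() would return for such a config

exec_args = [
    "iree-benchmark-module",                  # assumed prefix; not shown in the diff
    f"--device={device}",
    "--device_allocator=caching",
    f"--module={vmfb_filename}",
    "--function=isolated_benchmark",          # entrypoint used when --tk is set ("main" otherwise)
    "--benchmark_repetitions=3",
    f"--input={image_shape}",
    f"--input={filter_shape}",
    f"--input={out_shape}",                   # extra input appended only when --tk is set
]

The resulting list is then handed to run_iree_command, exactly as in the unchanged code path.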

convbench/conv_utils.py
Lines changed: 13 additions & 1 deletion

@@ -56,7 +56,19 @@ def get_kernel_shape(self) -> str:
             return str(self.P) + "x" + str(self.Q) + "x" + str(self.C) + "x" + str(self.F) + "x" + self.input_dtype
         if "nchw" in self.OP:
             return str(self.F) + "x" + str(self.C) + "x" + str(self.P) + "x" + str(self.Q) + "x" + self.input_dtype
-
+
+    def get_out_shape(self) -> str:
+        padding = 0
+        in_h = self.H * self.S + self.P - 1
+        in_w = self.W * self.S + self.Q - 1
+        h_out = (in_h + 2 * padding - self.P) // self.S + 1
+        w_out = (in_w + 2 * padding - self.Q) // self.S + 1
+        n = self.N
+        nf = self.F
+        if "nhwc" in self.OP:
+            return str(n) + "x" + str(h_out) + "x" + str(w_out) + "x" + str(nf) + "x" + self.output_dtype
+        if "nchw" in self.OP:
+            return str(n) + "x" + str(nf) + "x" + str(h_out) + "x" + str(w_out) + "x" + self.output_dtype
 
     def get_byte_count(self) -> int:
         dtype_bits_map = {
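
A quick worked check of the new get_out_shape arithmetic, with hypothetical config values (illustration only, not taken from the benchmark's problem list). Since in_h = H*S + P - 1, the floor division cancels and h_out comes back equal to H:

# Hypothetical values; padding is fixed to 0 as in the method above.
H, W, S = 32, 32, 2          # config.H, config.W, config.S
P, Q = 3, 3                  # filter height/width (config.P, config.Q)
padding = 0

in_h = H * S + P - 1                        # 32*2 + 3 - 1 = 66
h_out = (in_h + 2 * padding - P) // S + 1   # (66 - 3) // 2 + 1 = 32
in_w = W * S + Q - 1                        # 66
w_out = (in_w + 2 * padding - Q) // S + 1   # 32
assert (h_out, w_out) == (H, W)

# For an nhwc op with N=2, F=640 and output_dtype "f32", get_out_shape()
# would therefore produce the string "2x32x32x640xf32".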

convbench/wave_conv_utils.py
Lines changed: 109 additions & 0 deletions

@@ -0,0 +1,109 @@
+from utils import *
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Optional
+from conv_utils import ConvConfig
+import traceback
+
+try:
+    import iree.turbine.kernel as tk
+    import iree.turbine.kernel.lang as tkl
+    from iree.turbine.kernel.wave.templates.conv import get_igemm_conv2d
+    from iree.turbine.kernel.wave.utils import (
+        get_default_arch,
+        get_default_run_config,
+        get_default_compile_config,
+        device_randn,
+        device_randint,
+        device_randperm,
+        device_zeros,
+    )
+except ImportError:
+    TURBINE_AVAILABLE = False
+else:
+    TURBINE_AVAILABLE = True
+
+
+def compile_wave_conv_config(
+    config: ConvConfig, kernel_dir: Path, vmfb_dir: Path, extra_compiler_args: list[str]
+) -> tuple[Path, Optional[Path]]:
+    if not TURBINE_AVAILABLE:
+        raise ValueError("iree.turbine package is not available")
+
+    mlir_file = kernel_dir / (config.get_name() + ".mlir")
+    vmfb_file = vmfb_dir / (config.get_name() + ".vmfb")
+    files_path = vmfb_dir / config.get_name()
+
+    try:
+        _compile_conv(config, mlir_file, vmfb_file)
+    except Exception as e:
+        error_file = vmfb_dir / (config.get_name() + "_error.txt")
+        print(f"Failed to compile {config.get_name()}. Error dumped in {error_file}")
+        with open(error_file, "w") as f:
+            f.write(str(e))
+            f.write(traceback.format_exc())
+        return mlir_file, None, None
+
+    return mlir_file, vmfb_file, files_path
+
+
+def _decode_op(op: str) -> tuple[str, str]:
+    if op.startswith("conv_2d_"):
+        return "conv_2d", op[len("conv_2d_") :]
+
+    raise ValueError(f"Unsupported op: {op}")
+
+
+def _convert_dtype(dtype: str):
+    dtypes = {
+        "i8": tkl.i8,
+        "i16": tkl.i16,
+        "i32": tkl.i32,
+        "i64": tkl.i64,
+        "f16": tkl.f16,
+        "f32": tkl.f32,
+        "f64": tkl.f64,
+        "bf16": tkl.bf16,
+    }
+    return dtypes[dtype]
+
+
+def _compile_conv(config: ConvConfig, mlir_file: Path, vmfb_file: Path):
+    print("Compile TKW kernel", config.OP)
+    op_type, layout = _decode_op(config.OP)
+
+    in_h = config.H * config.S + config.P - 1
+    in_w = config.W * config.S + config.Q - 1
+    if op_type == "conv_2d":
+        conv, hyperparams = get_igemm_conv2d(
+            layout=layout,
+            n=config.N,
+            h=in_h,
+            w=in_w,
+            c=config.C,
+            hf=config.P,
+            wf=config.Q,
+            nf=config.F,
+            stride=config.S,
+            input_dtype=_convert_dtype(config.input_dtype),
+            output_dtype=_convert_dtype(config.output_dtype),
+        )
+    else:
+        raise ValueError(f"Unsupported op_type: {op_type}")
+
+    # config = get_default_run_config()
+    config = {"backend": "rocm", "device": "hip", "target": "gfx942"}
+
+    with tk.gen.TestLaunchContext(
+        hyperparams,
+        canonicalize=True,
+        create_vmfb_file=vmfb_file,
+        run_config=config,
+        schedule=False,
+        inline=False,
+    ):
+        mod = conv().module_op  # This will generate vmfb file
+        with open(mlir_file, "w") as f:
+            f.write(str(mod))
+
+    print(f"Successfully compiled to {vmfb_file}")

0 commit comments
