Skip to content

Commit 4621947

Browse files
authored
Add SDXL conv shapes, extra iree flags option, tool to plot roofline percentages (#19)
- Adds the SDXL convolution shapes to convbench - Adds the option to pass Xiree_compile flags in convbench - Adds percentage of roofline to the collected conv benchmark metrics - Adds a tool to plot roofline percents against kernel parameters given the benchmarks and kernel stats - Renames `shark_conv.py` to `conv_bench.py` to match gemm and attention formats --------- Signed-off-by: Max Dawkins <[email protected]>
1 parent 982eb72 commit 4621947

File tree

8 files changed

+201
-32
lines changed

8 files changed

+201
-32
lines changed

.github/workflows/run_bench.yml

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ jobs:
3535
- name: Convolutions
3636
run: |
3737
source bench_venv/bin/activate
38-
python convbench/shark_conv.py
38+
python convbench/conv_bench.py
3939
4040
- name: Attention
4141
run: |
@@ -55,13 +55,13 @@ jobs:
5555
- name: Roofline Plots
5656
run: |
5757
source bench_venv/bin/activate
58-
python convbench/shark_conv.py --roofline results/iree_conv.csv --plot results/iree_conv_i8.png --dtype i8
59-
python convbench/shark_conv.py --roofline results/iree_conv.csv --plot results/iree_conv_f32.png --dtype f32
60-
python convbench/shark_conv.py --roofline results/iree_attention.csv --plot results/iree_attention_fp16.png --dtype f16
61-
python convbench/shark_conv.py --roofline results/iree_attention.csv --plot results/iree_attention_fp8.png --dtype f8E4M3FNUZ
62-
python convbench/shark_conv.py --roofline results/iree_gemm.csv --plot results/iree_gemm.png
63-
python convbench/shark_conv.py --roofline results/iree_gemm_tk.csv --plot results/iree_gemm_tk.png
64-
python convbench/shark_conv.py --roofline results/iree_gemm.csv,results/iree_gemm_tk.csv,results/iree_attention.csv,results/iree_conv.csv --plot results/combined.png
58+
python convbench/conv_bench.py --roofline results/iree_conv.csv --plot results/iree_conv_i8.png --dtype i8
59+
python convbench/conv_bench.py --roofline results/iree_conv.csv --plot results/iree_conv_f16.png --dtype f16
60+
python convbench/conv_bench.py --roofline results/iree_attention.csv --plot results/iree_attention_fp16.png --dtype f16
61+
python convbench/conv_bench.py --roofline results/iree_attention.csv --plot results/iree_attention_fp8.png --dtype f8E4M3FNUZ
62+
python convbench/conv_bench.py --roofline results/iree_gemm.csv --plot results/iree_gemm.png
63+
python convbench/conv_bench.py --roofline results/iree_gemm_tk.csv --plot results/iree_gemm_tk.png
64+
python convbench/conv_bench.py --roofline results/iree_gemm.csv,results/iree_gemm_tk.csv,results/iree_attention.csv,results/iree_conv.csv --plot results/combined.png
6565
6666
- name: Upload benchmark results
6767
uses: actions/upload-artifact@v4

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ Refer to the respective problems.py file in the folder to see which shapes are b
2424
### Convolution Benchmarking
2525

2626
```
27-
python convbench/shark_conv.py
27+
python convbench/conv_bench.py
2828
```
2929

3030
### GEMM Benchmarking
@@ -50,7 +50,7 @@ python attentionbench/attention_bench.py
5050
If you want to generate a roofline plot, you can call any of the suites for now with the --roofline option (provide a commma seperated list if you want to generate for multiple benchmarks combined):
5151

5252
```
53-
python convbench/shark_conv.py --roofline results/iree_conv.csv,results/iree_attention.csv --plot results/attn_conv.png
53+
python convbench/conv_bench.py --roofline results/iree_conv.csv,results/iree_attention.csv --plot results/attn_conv.png
5454
```
5555

5656
If you want to generate a roofline plot for a certain data type, model, or batch size you can do:

common_tools/kernel_stats.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ class KernelStats:
9898
@staticmethod
9999
def get_csv_header() -> list[str]:
100100
return (
101-
["Name"] + IsaStats.get_csv_header() + ConfiguredMlirStats.get_csv_header()
101+
["name"] + IsaStats.get_csv_header() + ConfiguredMlirStats.get_csv_header()
102102
)
103103

104104
def get_values(self):
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import argparse
2+
import pandas as pd
3+
import matplotlib.pyplot as plt
4+
5+
def plot_roofline_vs_column(kernel_stat_path, benchmark_stat_path, out_path, param_name, boxplot):
6+
kernel_df = pd.read_csv(kernel_stat_path)
7+
benchmark_df = pd.read_csv(benchmark_stat_path)
8+
if param_name not in kernel_df.columns and param_name not in benchmark_df.columns:
9+
print(f"`{param_name}` column not found in {kernel_stat_path} or {benchmark_stat_path}.\n")
10+
return False
11+
if "roofline_percent" not in benchmark_df.columns:
12+
print(f"`roofline_percent` column not found in {benchmark_stat_path}.\n")
13+
return False
14+
if "name" not in benchmark_df.columns or "name" not in kernel_df.columns:
15+
print(f"`name` column not found in {kernel_stat_path} and {benchmark_stat_path}.\n")
16+
return False
17+
df = kernel_df.merge(benchmark_df, on="name")
18+
if boxplot:
19+
axes = df[[param_name, "roofline_percent"]].boxplot(
20+
by=param_name,
21+
figsize=(12,12)
22+
)
23+
else:
24+
axes = df.plot(
25+
param_name,
26+
"roofline_percent",
27+
kind="scatter",
28+
figsize=(12,12)
29+
)
30+
plt.xlabel(param_name)
31+
plt.ylabel("roofline_percent")
32+
plt.savefig(out_path, dpi=300, bbox_inches='tight')
33+
plt.close()
34+
return True
35+
36+
37+
if __name__ == "__main__":
38+
parser = argparse.ArgumentParser(
39+
description="Plotting tool to correlate kernel parameters with roofline percentages."
40+
)
41+
parser.add_argument(
42+
"--kernel_stats_csv",
43+
help="The path to the input csv containing kernel metrics.",
44+
type=str,
45+
default=None
46+
)
47+
parser.add_argument(
48+
"--benchmark_csv",
49+
help="The path to the input csv containing benchmarks.",
50+
type=str,
51+
default=None
52+
)
53+
parser.add_argument(
54+
"--out_path",
55+
help="The path to save the resulting plot image.",
56+
type=str,
57+
default=None
58+
)
59+
parser.add_argument(
60+
"--parameter",
61+
help="The name of the column with the parameter to use as the x-axis.",
62+
type=str,
63+
default=None
64+
)
65+
parser.add_argument(
66+
"--boxplot",
67+
help="Use a boxplot graph, with one boxplot per parameter value.",
68+
action=argparse.BooleanOptionalAction,
69+
type=bool,
70+
default=False
71+
)
72+
args = parser.parse_args()
73+
74+
succeeded = plot_roofline_vs_column(
75+
args.kernel_stats_csv, args.benchmark_csv, args.out_path, args.parameter, args.boxplot
76+
)
77+
if succeeded:
78+
print(f"Plot saved to {args.out_path}\n")
79+
else:
80+
print(f"Failed to generate plot.\n")

common_tools/utils/bench_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def roofline(results=None, out=None, batch=None, dtype=None, model=None, **kwarg
144144
with open(result_file.strip(), mode='r') as csvfile:
145145
reader = csv.DictReader(csvfile)
146146
for row in reader:
147-
row = {k: float(v) if k in ['index', 'mean_microseconds', 'arithmetic_intensity', 'tflops'] else v for k, v in row.items()}
147+
row = {k: float(v) if k in ['index', 'mean_microseconds', 'arithmetic_intensity', 'tflops', 'roofline_tflops', 'roofline_percent'] else v for k, v in row.items()}
148148
row['ok'] = True if 'ok' not in row else row['ok'] == 'True'
149149
data.append(row)
150150
if batch:

convbench/shark_conv.py renamed to convbench/conv_bench.py

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@
99
import sys
1010
from utils import *
1111
from conv_utils import *
12-
from problems import get_conv_configs
12+
from problems import get_conv_configs, get_conv_test_configs
1313

1414

15-
def compile_conv(tag, config, kernel_dir, vmfb_dir):
16-
mlir_file, vmfb_file = compile_conv_config(config, kernel_dir, vmfb_dir)
17-
return (tag, config, mlir_file, vmfb_file)
15+
def compile_conv(tag, config, kernel_dir, vmfb_dir, extra_compiler_args):
16+
mlir_file, vmfb_file, dump_path = compile_conv_config(config, kernel_dir, vmfb_dir, extra_compiler_args)
17+
return (tag, config, mlir_file, vmfb_file, dump_path)
1818

1919

2020
if __name__ == "__main__":
@@ -27,6 +27,12 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir):
2727
help="Set the logging level",
2828
)
2929
parser.add_argument("--device", help="The IREE device to execute benchmarks on", type=str, default="hip")
30+
parser.add_argument(
31+
"--Xiree_compile",
32+
nargs='+',
33+
default=[],
34+
help="Extra command line arguments passed to the IREE compiler. The flags need to be specified without the `--` or `-`."
35+
)
3036
parser.add_argument(
3137
"--roofline",
3238
help="Comma seperated csv file list to generate roofline plot with",
@@ -44,6 +50,7 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir):
4450
roofline(args.roofline, args.plot, args.batch, args.dtype, args.model)
4551
sys.exit()
4652

53+
# configs = get_conv_test_configs()
4754
configs = get_conv_configs()
4855
print(f"Generated {len(configs)} conv configs.")
4956

@@ -60,16 +67,17 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir):
6067
vmfb_dir.mkdir(parents=True, exist_ok=True)
6168
device = args.device
6269

70+
extra_compiler_args = ['--' + x for x in list(args.Xiree_compile)]
6371
compile_args = itertools.starmap(
64-
lambda tag, config: (tag, config, kernel_dir, vmfb_dir), configs
72+
lambda tag, config: (tag, config, kernel_dir, vmfb_dir, extra_compiler_args), configs
6573
)
6674
with Pool(num_cpus) as pool:
6775
compilation_results = list(tqdm(pool.starmap(compile_conv, list(compile_args))))
6876

6977
error_count = 0
70-
for tag, config, mlir_file, vmfb_file in compilation_results:
78+
for tag, config, mlir_file, vmfb_file, dump_path in compilation_results:
7179
if vmfb_file:
72-
vmfb_dict[vmfb_file] = (tag, config)
80+
vmfb_dict[vmfb_file] = (tag, config, dump_path)
7381
else:
7482
error_count += 1
7583
print(
@@ -86,7 +94,7 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir):
8694
os.makedirs(csv_dir)
8795

8896
for vmfb_filename, value in vmfb_dict.items():
89-
tag, config = value
97+
tag, config, dump_path = value
9098
name = config.get_name()
9199

92100
image_shape = config.get_img_shape()
@@ -103,17 +111,29 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir):
103111
"--benchmark_repetitions=3",
104112
]
105113

114+
print(f"Running {vmfb_filename}...")
106115
# iree benchmark kernels
107116
ret_value, cmd_out, cmd_stderr = run_iree_command(exec_args)
108117
ok = ret_value == 0
109-
benchmark_gemm_mean_time_ms = bench_summary_process(ret_value, cmd_out)
110-
benchmark_gemm_mean_time_us = benchmark_gemm_mean_time_ms * 1000
118+
benchmark_conv_mean_time_ms = bench_summary_process(ret_value, cmd_out)
119+
benchmark_conv_mean_time_us = benchmark_conv_mean_time_ms * 1000
111120

112121
flops = config.get_flops()
113122
byte_count = config.get_byte_count()
114123

115124
arithmetic_intensity = flops / byte_count
116-
tflops_per_second = (flops / 1e12) / (benchmark_gemm_mean_time_us / 1e6)
125+
tflops_per_second = (flops / 1e12) / (benchmark_conv_mean_time_us / 1e6)
126+
127+
# Compute percentage of the roofline.
128+
# TODO: Make this target specific and move to common utils.
129+
tflops_map = {
130+
"f32": 653.7,
131+
"f16": 1307.4,
132+
"bf16": 1307.4,
133+
"f8E4M3FNUZ": 2614.9,
134+
"i8": 2614.9,
135+
}
136+
roofline_tflops = tflops_map[config.input_dtype]
117137

118138
results.append(
119139
(
@@ -130,9 +150,11 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir):
130150
config.S,
131151
config.input_dtype,
132152
config.output_dtype,
133-
round(benchmark_gemm_mean_time_us, 4),
153+
round(benchmark_conv_mean_time_us, 4),
134154
round(arithmetic_intensity, 4),
135155
round(tflops_per_second, 4),
156+
roofline_tflops,
157+
round(tflops_per_second / roofline_tflops, 4),
136158
ok,
137159
)
138160
)
@@ -155,6 +177,8 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir):
155177
"mean_microseconds",
156178
"arithmetic_intensity",
157179
"tflops",
180+
"roofline_tflops",
181+
"roofline_percent",
158182
"ok",
159183
]
160184

convbench/conv_utils.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -162,11 +162,12 @@ def generate_mlir(config: ConvConfig):
162162

163163

164164
def compile_conv_config(
165-
config: ConvConfig, kernel_dir: Path, vmfb_dir: Path
165+
config: ConvConfig, kernel_dir: Path, vmfb_dir: Path, extra_compiler_args: list[str]
166166
) -> tuple[Path, Optional[Path]]:
167167
mlir_file = kernel_dir / (config.get_name() + ".mlir")
168168
vmfb_file = vmfb_dir / (config.get_name() + ".vmfb")
169169
dump_file = kernel_dir / (config.get_name() + ".stderr.mlir")
170+
files_path = vmfb_dir / config.get_name()
170171

171172
# Generate mlir content
172173
mlir_content = generate_mlir(config)
@@ -188,7 +189,8 @@ def compile_conv_config(
188189
"--iree-hal-target-device=hip",
189190
# Device: MI300x
190191
"--iree-hip-target=gfx942",
191-
]
192+
f"--iree-hal-dump-executable-files-to={files_path}",
193+
] + extra_compiler_args
192194

193195
print(" ".join(exec_args))
194196

@@ -203,6 +205,6 @@ def compile_conv_config(
203205
print(f"Failed to compile {mlir_file}. Error dumped in {error_file}")
204206
with open(error_file, "w") as f:
205207
f.write(stderr.decode("utf-8"))
206-
return mlir_file, None
208+
return mlir_file, None, None
207209

208-
return mlir_file, vmfb_file
210+
return mlir_file, vmfb_file, files_path

convbench/problems.py

Lines changed: 67 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,41 @@
11
from conv_utils import ConvConfig
22

33

4+
def unet_sweep(op: str, input_dtype: str, output_dtype: str) -> list[ConvConfig]:
5+
configs = []
6+
for B in [1, 2, 4, 8]:
7+
configs.append(ConvConfig(B, 128, 128, 16, 3, 3, 320, 1, op, input_dtype, output_dtype))
8+
configs.append(ConvConfig(B, 128, 128, 320, 3, 3, 320, 1, op, input_dtype, output_dtype))
9+
configs.append(ConvConfig(B, 64, 64, 320, 3, 3, 320, 2, op, input_dtype, output_dtype))
10+
configs.append(ConvConfig(B, 64, 64, 320, 3, 3, 640, 1, op, input_dtype, output_dtype))
11+
configs.append(ConvConfig(B, 64, 64, 640, 3, 3, 640, 1, op, input_dtype, output_dtype))
12+
configs.append(ConvConfig(B, 64, 64, 320, 1, 1, 640, 1, op, input_dtype, output_dtype))
13+
configs.append(ConvConfig(B, 32, 32, 640, 3, 3, 640, 2, op, input_dtype, output_dtype))
14+
configs.append(ConvConfig(B, 32, 32, 640, 3, 3, 1280, 1, op, input_dtype, output_dtype))
15+
configs.append(ConvConfig(B, 32, 32, 1280, 3, 3, 1280, 1, op, input_dtype, output_dtype))
16+
configs.append(ConvConfig(B, 32, 32, 640, 1, 1, 1280, 1, op, input_dtype, output_dtype))
17+
configs.append(ConvConfig(B, 32, 32, 2560, 3, 3, 1280, 1, op, input_dtype, output_dtype))
18+
configs.append(ConvConfig(B, 32, 32, 2560, 1, 1, 1280, 1, op, input_dtype, output_dtype))
19+
configs.append(ConvConfig(B, 32, 32, 1920, 3, 3, 1280, 1, op, input_dtype, output_dtype))
20+
configs.append(ConvConfig(B, 32, 32, 1920, 1, 1, 1280, 1, op, input_dtype, output_dtype))
21+
configs.append(ConvConfig(B, 64, 64, 1280, 3, 3, 1280, 1, op, input_dtype, output_dtype))
22+
configs.append(ConvConfig(B, 64, 64, 1920, 3, 3, 640, 1, op, input_dtype, output_dtype))
23+
configs.append(ConvConfig(B, 64, 64, 1920, 1, 1, 640, 1, op, input_dtype, output_dtype))
24+
configs.append(ConvConfig(B, 64, 64, 1280, 3, 3, 640, 1, op, input_dtype, output_dtype))
25+
configs.append(ConvConfig(B, 64, 64, 1280, 1, 1, 640, 1, op, input_dtype, output_dtype))
26+
configs.append(ConvConfig(B, 64, 64, 960, 3, 3, 640, 1, op, input_dtype, output_dtype))
27+
configs.append(ConvConfig(B, 64, 64, 960, 1, 1, 640, 1, op, input_dtype, output_dtype))
28+
configs.append(ConvConfig(B, 128, 128, 640, 3, 3, 640, 1, op, input_dtype, output_dtype))
29+
configs.append(ConvConfig(B, 128, 128, 960, 3, 3, 320, 1, op, input_dtype, output_dtype))
30+
configs.append(ConvConfig(B, 128, 128, 960, 1, 1, 320, 1, op, input_dtype, output_dtype))
31+
configs.append(ConvConfig(B, 128, 128, 640, 3, 3, 320, 1, op, input_dtype, output_dtype))
32+
configs.append(ConvConfig(B, 128, 128, 640, 1, 1, 320, 1, op, input_dtype, output_dtype))
33+
configs.append(ConvConfig(B, 128, 128, 320, 3, 3, 16, 1, op, input_dtype, output_dtype))
34+
return configs
35+
436
def resnet_sweep(op: str, input_dtype: str, output_dtype: str) -> list[ConvConfig]:
537
configs = []
6-
for B in [1, 2, 4, 8, 16, 32, 48]:
38+
for B in [1, 2, 4, 8]:
739
configs.append(ConvConfig(B, 112, 112, 64, 7, 7, 3, 2, op, input_dtype, output_dtype))
840
configs.append(ConvConfig(B, 56, 56, 64, 3, 3, 64, 1, op, input_dtype, output_dtype))
941
configs.append(ConvConfig(B, 28, 28, 128, 3, 3, 128, 2, op, input_dtype, output_dtype))
@@ -19,9 +51,40 @@ def resnet_sweep(op: str, input_dtype: str, output_dtype: str) -> list[ConvConfi
1951

2052
def get_conv_configs() -> list[tuple[str, ConvConfig]]:
2153
configs: list[tuple[str, ConvConfig]] = []
22-
resnet_configs = resnet_sweep("conv_2d_nchw_fchw", "f32", "f32")
23-
resnet_configs += resnet_sweep("conv_2d_nhwc_hwcf_q", "i8", "i32")
2454

25-
configs += [("resnet_sweep", x) for x in resnet_configs]
55+
# Resnet
56+
resnet_configs = []
57+
resnet_configs += resnet_sweep("conv_2d_nhwc_hwcf", "f16", "f32")
58+
resnet_configs += resnet_sweep("conv_2d_nhwc_hwcf", "i8", "i32")
59+
resnet_configs += resnet_sweep("conv_2d_nchw_fchw", "f16", "f32")
60+
resnet_configs += resnet_sweep("conv_2d_nchw_fchw", "i8", "i32")
61+
configs += [("resnet", x) for x in resnet_configs]
62+
63+
# Unet
64+
unet_configs = []
65+
unet_configs += unet_sweep("conv_2d_nhwc_hwcf", "f16", "f32")
66+
unet_configs += unet_sweep("conv_2d_nhwc_hwcf", "i8", "i32")
67+
unet_configs += unet_sweep("conv_2d_nchw_fchw", "f16", "f32")
68+
unet_configs += unet_sweep("conv_2d_nchw_fchw", "i8", "i32")
69+
configs += [("unet", x) for x in unet_configs]
70+
71+
return configs
72+
73+
# Test function to run only a few chosen shapes
74+
def get_conv_test_configs() -> list[tuple[str, ConvConfig]]:
75+
configs: list[tuple[str, ConvConfig]] = []
76+
77+
resnet_configs = []
78+
# resnet_configs += resnet_sweep("conv_2d_nhwc_hwcf", "f16", "f32")
79+
# resnet_configs += resnet_sweep("conv_2d_nhwc_hwcf", "i8", "i32")
80+
# resnet_configs += resnet_sweep("conv_2d_nchw_fchw", "f16", "f32")
81+
# resnet_configs += resnet_sweep("conv_2d_nchw_fchw", "i8", "i32")
82+
configs += [("resnet", x) for x in resnet_configs]
83+
84+
unet_configs = []
85+
# unet_configs.append(ConvConfig(1,128,128,16,3,3,320,1, "conv_2d_nhwc_hwcf_q", "i8", "i32"))
86+
# unet_configs.append(ConvConfig(1,32,32,640,1,1,1280,1, "conv_2d_nhwc_hwcf_q", "i8", "i32"))
87+
88+
configs += [("unet", x) for x in unet_configs]
2689

2790
return configs

0 commit comments

Comments
 (0)