Commit 7f0d1e8

Add benchmarking minimally, comment out a few more models
1 parent e6072e1

File tree

3 files changed: +52 -8 lines changed

models/turbine_models/custom_models/torchbench/README.md

Lines changed: 6 additions & 0 deletions
@@ -1,5 +1,11 @@
 # SHARK torchbench exports and benchmarks
 
+## Overview
+
+This directory holds scripts and utilities for running a suite of benchmarked inference tasks, demonstrating functionality and performance parity between SHARK/IREE and native torch.compile workflows. It is under active development, and benchmark numbers should not be treated as the best results achievable with the current state of IREE compiler optimizations.
+
+Eventually, we want this process to plug into the upstream torchbench flow by exposing the IREE methodology shown here as a compile/runtime backend for the torch benchmark classes. For now, it gives developers a way to get preliminary results and achieve blanket functionality for the models listed in export.py.
+
 ### Setup
 
 - pip install torch+rocm packages:
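
The "compile/runtime backend" plan described in the Overview is not implemented in this commit. As a rough illustration of the hook it would use: a custom torch.compile backend is simply a callable that receives an FX GraphModule plus example inputs and returns a callable. The sketch below is hypothetical; the `iree_backend` name and the eager fallback are placeholders, not SHARK/IREE code.

```python
# Hypothetical sketch of the torch.compile backend hook mentioned in the
# README overview; not part of this commit.
import torch


def iree_backend(gm: torch.fx.GraphModule, example_inputs):
    # A real integration would lower the FX graph through IREE here and
    # return a wrapper around the compiled vmfb. As a placeholder, fall
    # back to eager execution of the captured graph.
    return gm.forward


model = torch.nn.Linear(8, 8)
compiled = torch.compile(model, backend=iree_backend)
out = compiled(torch.randn(4, 8))
```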

models/turbine_models/custom_models/torchbench/cmd_opts.py

Lines changed: 10 additions & 0 deletions
@@ -72,6 +72,16 @@ def is_valid_file(arg):
     choices=["safetensors", "irpa", "gguf", None],
     help="Externalizes model weights from the torch dialect IR and its successors",
 )
+p.add_argument(
+    "--run_benchmark",
+    type=bool,
+    default=True,
+)
+p.add_argument(
+    "--output_csv",
+    type=str,
+    default="./benchmark_results.csv",
+)
 
 ##############################################################################
 # Modeling and Export Options
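
One caveat about the `--run_benchmark` flag added above: argparse applies `type=bool` to the raw string, so any non-empty value (including `--run_benchmark False`) parses as True; only the default can express False. If an on/off switch is intended, the usual pattern looks like the sketch below, which is illustrative only and not part of this commit.

```python
# Illustrative alternative for a boolean CLI flag (not part of this commit).
import argparse

p = argparse.ArgumentParser()

# argparse.BooleanOptionalAction (Python 3.9+) generates both
# --run_benchmark and --no-run_benchmark; action="store_true" with a
# False default is the older equivalent for a simple on/off switch.
p.add_argument(
    "--run_benchmark",
    action=argparse.BooleanOptionalAction,
    default=True,
)

args = p.parse_args(["--no-run_benchmark"])
print(args.run_benchmark)  # False
```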

models/turbine_models/custom_models/torchbench/export.py

Lines changed: 36 additions & 8 deletions
@@ -9,6 +9,7 @@
 import gc
 
 from iree.compiler.ir import Context
+from iree import runtime as ireert
 import numpy as np
 from shark_turbine.aot import *
 from shark_turbine.dynamo.passes import (
@@ -21,10 +22,12 @@
 from safetensors import safe_open
 import argparse
 from turbine_models.turbine_tank import turbine_tank
+from turbine_models.model_runner import vmfbRunner
 
 from pytorch.benchmarks.dynamo.common import parse_args
 from pytorch.benchmarks.dynamo.torchbench import TorchBenchmarkRunner, setup_torchbench_cwd
 
+import csv
 torchbench_models_dict = {
     # "BERT_pytorch": {
     #     "dim": 128,
@@ -84,7 +87,7 @@
     "resnet50": {
         "dim": 128,
     },
-    "resnet50_32x4d": {
+    "resnext50_32x4d": {
         "dim": 128,
     },
     "shufflenet_v2_x1_0": {
@@ -93,9 +96,9 @@
     "squeezenet1_1": {
         "dim": 512,
     },
-    "timm_nfnet": {
-        "dim": 256,
-    },
+    # "timm_nfnet": {
+    #     "dim": 256,
+    # },
     "timm_efficientnet": {
         "dim": 128,
     },
@@ -163,9 +166,13 @@ def export_torchbench_model(
         model_id,
         f"_{static_dim}_{precision}",
     )
+    safe_name = os.path.join("generated", safe_name)
     if decomp_attn:
         safe_name += "_decomp_attn"
 
+    if not os.path.exists("generated"):
+        os.mkdir("generated")
+
     if input_mlir:
         vmfb_path = utils.compile_to_vmfb(
             input_mlir,
@@ -179,6 +186,7 @@ def export_torchbench_model(
         )
         return vmfb_path
 
+
     _, model_name, model, forward_args, _ = get_model_and_inputs(model_id, batch_size, tb_dir, tb_args)
 
     if dtype == torch.float16:
@@ -188,7 +196,8 @@ def export_torchbench_model(
     if not isinstance(forward_args, dict):
         forward_args = [i.type(dtype) for i in forward_args]
         for idx, i in enumerate(forward_args):
-            np.save(f"{model_id}_input{idx}", i.clone().detach().cpu())
+            np.save(
+                os.path.join("generated", f"{model_id}_input{idx}"), i.clone().detach().cpu())
     else:
         for idx, i in enumerate(forward_args.values()):
             np.save(f"{model_id}_input{idx}", i.clone().detach().cpu())
@@ -199,7 +208,8 @@ def export_torchbench_model(
         if not os.path.exists(external_weights_dir):
             os.mkdir(external_weights_dir)
         external_weight_path = os.path.join(external_weights_dir, f"{model_id}_{precision}.irpa")
-
+    else:
+        external_weight_path = None
 
     decomp_list = [torch.ops.aten.reflection_pad2d]
     if decomp_attn == True:
@@ -265,11 +275,26 @@ class CompiledTorchbenchModel(CompiledModule):
         return_path=not exit_on_vmfb,
         attn_spec=attn_spec,
     )
-    return vmfb_path
+    return vmfb_path, external_weight_path, forward_args
+
+def run_benchmark(device, vmfb_path, weights_path, example_args, model_id, csv_path):
+    if "rocm" in device:
+        device = "hip" + device.split("rocm")[-1]
+    mod_runner = vmfbRunner(device, vmfb_path, weights_path)
+    inputs = [ireert.asdevicearray(mod_runner.config.device, i) for i in example_args]
+    start = time.time()
+    results = mod_runner.ctx.modules.compiled_torchbench_model["main"](*inputs)
+    latency = time.time() - start
+    with open(csv_path, "a") as csvfile:
+        fieldnames = ["model", "latency"]
+        data = [{"model": model_id, "latency": latency}]
+        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
+        writer.writerows(data)
+
 
 def run_main(model_id, args, tb_dir, tb_args):
     print(f"exporting {model_id}")
-    mod_str = export_torchbench_model(
+    mod_str, weights_path, example_args = export_torchbench_model(
         model_id,
         tb_dir,
         tb_args,
@@ -293,6 +318,9 @@ def run_main(model_id, args, tb_dir, tb_args):
         with open(f"{safe_name}.mlir", "w+") as f:
             f.write(mod_str)
         print("Saved to", safe_name + ".mlir")
+    elif args.run_benchmark:
+        run_benchmark(args.device, mod_str, weights_path, example_args, model_id, args.output_csv)
+
     gc.collect()
 
 if __name__ == "__main__":
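
A few caveats about the new run_benchmark helper: it relies on `time`, which does not appear among the imports added in this diff; csv.DictWriter never emits a header row, so the appended CSV has no column names; and a single cold invocation folds one-time setup cost into the reported latency. A possible variant is sketched below. It is illustrative only, the `run_benchmark_averaged` name is made up, and it assumes the same `vmfbRunner`/`ireert` attributes used in the commit.

```python
# Illustrative variant of the benchmarking helper (not part of this commit):
# writes a CSV header when the file is first created and averages several
# warm runs instead of timing a single cold call.
import csv
import os
import time


def run_benchmark_averaged(mod_runner, inputs, model_id, csv_path, iters=10):
    main = mod_runner.ctx.modules.compiled_torchbench_model["main"]
    main(*inputs)  # warm-up call, excluded from timing
    start = time.time()
    for _ in range(iters):
        main(*inputs)
    latency = (time.time() - start) / iters

    write_header = not os.path.exists(csv_path)
    with open(csv_path, "a", newline="") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=["model", "latency"])
        if write_header:
            writer.writeheader()
        writer.writerow({"model": model_id, "latency": latency})
    return latency
```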
