
Commit d524e10

Add average latency / iteration count options
1 parent 63e80e9 commit d524e10

File tree

2 files changed: +46 additions, -25 deletions

models/turbine_models/custom_models/torchbench/cmd_opts.py

Lines changed: 5 additions & 0 deletions

@@ -77,6 +77,11 @@ def is_valid_file(arg):
     type=bool,
     default=True,
 )
+p.add_argument(
+    "--num_iters",
+    type=int,
+    default=10,
+)
 p.add_argument(
     "--output_csv",
     type=str,
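For context, a minimal self-contained sketch of what the new option does once parsed; the parser name p mirrors cmd_opts.py, while the argument values and the print are illustrative only.

import argparse

# Stand-alone parser carrying only the option added in this commit.
p = argparse.ArgumentParser()
p.add_argument(
    "--num_iters",
    type=int,
    default=10,  # ten timed iterations unless overridden on the command line
)

args = p.parse_args(["--num_iters", "25"])
print(args.num_iters)  # 25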

models/turbine_models/custom_models/torchbench/export.py

Lines changed: 41 additions & 25 deletions
@@ -36,18 +36,15 @@
     # "Background_Matting": {
     #     "dim": 16,
     # },
-    "LearningToPaint": {
-        "dim": 1024,
-    },
+    # "LearningToPaint": {
+    #     "dim": 1024,
+    # },
     "alexnet": {
         "dim": 1024,
     },
-    "dcgan": {
-        "dim": 1024,
-    },
-    "densenet121": {
-        "dim": 64,
-    },
+    # "densenet121": {
+    #     "dim": 64,
+    # },
     "hf_Albert": {
         "dim": 32,
         "buffer_prefix": "albert"
@@ -109,15 +106,16 @@
     "timm_resnest": {
         "dim": 256,
     },
-    "timm_vision_transformer": {
-        "dim": 256,
-    },
+    # "timm_vision_transformer": {
+    #     "dim": 256,
+    #     "decomp_attn": True,
+    # },
     "timm_vovnet": {
         "dim": 128,
     },
-    "vgg16": {
-        "dim": 128,
-    },
+    # "vgg16": {
+    #     "dim": 128,
+    # },
 }
 
 # Adapted from pytorch.benchmarks.dynamo.common.main()
@@ -213,10 +211,12 @@ def export_torchbench_model(
         external_weight_path = None
 
     decomp_list = [torch.ops.aten.reflection_pad2d]
-    if decomp_attn == True:
+    if decomp_attn == True or torchbench_models_dict[model_id].get("decomp_attn"):
+        print("decomposing attention for: " + model_id)
         decomp_list.extend([
             torch.ops.aten._scaled_dot_product_flash_attention_for_cpu,
             torch.ops.aten._scaled_dot_product_flash_attention.default,
+            torch.ops.aten._scaled_dot_product_flash_attention,
             torch.ops.aten.scaled_dot_product_attention,
         ])
     with decompositions.extend_aot_decompositions(
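The reworked condition decomposes attention either when the caller passes decomp_attn or when the model's entry in torchbench_models_dict carries a "decomp_attn" flag (as the commented-out timm_vision_transformer entry above illustrates). A small sketch of that lookup, using an illustrative two-entry table rather than the real one from export.py:

# Illustrative subset of the model table; only keys that appear in the diff are used.
torchbench_models_dict = {
    "hf_Albert": {"dim": 32},
    "timm_vision_transformer": {"dim": 256, "decomp_attn": True},
}

def wants_attention_decomposition(model_id, decomp_attn=False):
    # Decompose when requested explicitly or when configured per model.
    return decomp_attn or bool(torchbench_models_dict[model_id].get("decomp_attn"))

print(wants_attention_decomposition("hf_Albert"))                # False
print(wants_attention_decomposition("timm_vision_transformer"))  # True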
@@ -278,21 +278,37 @@ class CompiledTorchbenchModel(CompiledModule):
     )
     return vmfb_path, external_weight_path, forward_args
 
-def run_benchmark(device, vmfb_path, weights_path, example_args, model_id, csv_path):
+
+def _run_iter(runner, inputs):
+    start = time.time()
+    res = runner.ctx.modules.compiled_torchbench_model["main"](*inputs)
+    return res, time.time() - start
+
+def run_benchmark(device, vmfb_path, weights_path, example_args, model_id, csv_path, iters):
     if "rocm" in device:
         device = "hip" + device.split("rocm")[-1]
     mod_runner = vmfbRunner(device, vmfb_path, weights_path)
-    inputs = [ireert.asdevicearray(mod_runner.config.device, i.clone().detach().cpu()) for i in example_args]
-    start = time.time()
-    results = runner.ctx.modules.compiled_torchbench_model["main"](*inputs)
-    latency = time.time() - start
-    with open(csv_path, "a") as csvfile:
-        fieldnames = ["model", "latency"]
-        data = [{"model": model_id, "latency": latency}]
+    inputs = torch_to_iree(mod_runner, example_args)
+    iter_latencies = []
+    for i in range(iters):
+        results, iter_latency = _run_iter(mod_runner, inputs)
+        iter_latencies.append(iter_latency)
+    avg_latency = sum(iter_latencies) / len(iter_latencies)
+    with open(csv_path, "w") as csvfile:
+        fieldnames = ["model", "avg_latency"]
+        data = [{"model": model_id, "avg_latency": avg_latency}]
         writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
+        writer.writeheader()
         writer.writerows(data)
 
 
+def torch_to_iree(iree_runner, example_args):
+    if isinstance(example_args, dict):
+        iree_args = [ireert.asdevicearray(iree_runner.config.device, i.clone().detach().cpu()) for i in example_args.values()]
+    else:
+        iree_args = [ireert.asdevicearray(iree_runner.config.device, i.clone().detach().cpu()) for i in example_args]
+    return iree_args
+
 def run_main(model_id, args, tb_dir, tb_args):
     print(f"exporting {model_id}")
     mod_str, weights_path, example_args = export_torchbench_model(
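run_benchmark now times iters iterations and records the mean latency instead of a single measurement. A self-contained sketch of that pattern, with a plain Python callable standing in for the compiled module's "main" entry point and a throwaway file name:

import csv
import time

def average_latency(fn, inputs, iters=10):
    # Time each call and return the mean wall-clock latency in seconds.
    latencies = []
    for _ in range(iters):
        start = time.time()
        fn(*inputs)
        latencies.append(time.time() - start)
    return sum(latencies) / len(latencies)

avg = average_latency(lambda n: sum(range(n)), (100_000,), iters=10)
with open("bench.csv", "w") as csvfile:
    writer = csv.DictWriter(csvfile, fieldnames=["model", "avg_latency"])
    writer.writeheader()
    writer.writerows([{"model": "toy_model", "avg_latency": avg}])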
@@ -320,7 +336,7 @@ def run_main(model_id, args, tb_dir, tb_args):
         f.write(mod_str)
         print("Saved to", safe_name + ".mlir")
     elif args.run_benchmark:
-        run_benchmark(args.device, mod_str, weights_path, example_args, model_id, args.output_csv)
+        run_benchmark(args.device, mod_str, weights_path, example_args, model_id, args.output_csv, args.num_iters)
 
     gc.collect()
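Because the CSV is now opened with "w" and gets a header row, each benchmark run produces a small self-describing file; reading it back is a short loop with csv.DictReader (file name taken from the sketch above, not from the diff):

import csv

with open("bench.csv") as csvfile:
    for row in csv.DictReader(csvfile):
        print(row["model"], float(row["avg_latency"]))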
