Skip to content

Commit fe7a40f

Browse files
committed
Formatting 2
1 parent 49ef122 commit fe7a40f

File tree

1 file changed

+64
-27
lines changed
  • models/turbine_models/custom_models/torchbench

1 file changed

+64
-27
lines changed

models/turbine_models/custom_models/torchbench/export.py

Lines changed: 64 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,13 @@
2626
from turbine_models.model_runner import vmfbRunner
2727

2828
from pytorch.benchmarks.dynamo.common import parse_args
29-
from pytorch.benchmarks.dynamo.torchbench import TorchBenchmarkRunner, setup_torchbench_cwd
29+
from pytorch.benchmarks.dynamo.torchbench import (
30+
TorchBenchmarkRunner,
31+
setup_torchbench_cwd,
32+
)
3033

3134
import csv
35+
3236
torchbench_models_dict = {
3337
# "BERT_pytorch": {
3438
# "dim": 128,
@@ -45,10 +49,7 @@
4549
# "densenet121": {
4650
# "dim": 64,
4751
# },
48-
"hf_Albert": {
49-
"dim": 32,
50-
"buffer_prefix": "albert"
51-
},
52+
"hf_Albert": {"dim": 32, "buffer_prefix": "albert"},
5253
# "hf_Bart": {
5354
# "dim": 16,
5455
# },
@@ -118,6 +119,7 @@
118119
# },
119120
}
120121

122+
121123
# Adapted from pytorch.benchmarks.dynamo.common.main()
122124
def get_runner(tb_dir, tb_args):
123125
if tb_dir:
@@ -134,7 +136,7 @@ def get_model_and_inputs(model_id, batch_size, tb_dir, tb_args):
134136
return runner.load_model(
135137
"cuda:0",
136138
model_id,
137-
batch_size = batch_size,
139+
batch_size=batch_size,
138140
)
139141

140142

@@ -185,9 +187,10 @@ def export_torchbench_model(
185187
)
186188
return vmfb_path
187189

190+
_, model_name, model, forward_args, _ = get_model_and_inputs(
191+
model_id, batch_size, tb_dir, tb_args
192+
)
188193

189-
_, model_name, model, forward_args, _ = get_model_and_inputs(model_id, batch_size, tb_dir, tb_args)
190-
191194
if dtype == torch.float16:
192195
model = model.half()
193196
model.to("cuda:0")
@@ -196,42 +199,48 @@ def export_torchbench_model(
196199
forward_args = [i.type(dtype) for i in forward_args]
197200
for idx, i in enumerate(forward_args):
198201
np.save(
199-
os.path.join("generated", f"{model_id}_input{idx}"), i.clone().detach().cpu())
202+
os.path.join("generated", f"{model_id}_input{idx}"),
203+
i.clone().detach().cpu(),
204+
)
200205
else:
201206
for idx, i in enumerate(forward_args.values()):
202207
np.save(f"{model_id}_input{idx}", i.clone().detach().cpu())
203208

204-
205209
mapper = {}
206-
if (external_weights_dir is not None):
210+
if external_weights_dir is not None:
207211
if not os.path.exists(external_weights_dir):
208212
os.mkdir(external_weights_dir)
209-
external_weight_path = os.path.join(external_weights_dir, f"{model_id}_{precision}.irpa")
213+
external_weight_path = os.path.join(
214+
external_weights_dir, f"{model_id}_{precision}.irpa"
215+
)
210216
else:
211217
external_weight_path = None
212218

213219
decomp_list = [torch.ops.aten.reflection_pad2d]
214220
if decomp_attn == True or torchbench_models_dict[model_id].get("decomp_attn"):
215221
print("decomposing attention for: " + model_id)
216-
decomp_list.extend([
217-
torch.ops.aten._scaled_dot_product_flash_attention_for_cpu,
218-
torch.ops.aten._scaled_dot_product_flash_attention.default,
219-
torch.ops.aten._scaled_dot_product_flash_attention,
220-
torch.ops.aten.scaled_dot_product_attention,
221-
])
222+
decomp_list.extend(
223+
[
224+
torch.ops.aten._scaled_dot_product_flash_attention_for_cpu,
225+
torch.ops.aten._scaled_dot_product_flash_attention.default,
226+
torch.ops.aten._scaled_dot_product_flash_attention,
227+
torch.ops.aten.scaled_dot_product_attention,
228+
]
229+
)
222230
with decompositions.extend_aot_decompositions(
223231
from_current=True,
224232
add_ops=decomp_list,
225233
):
226234
if "hf" in model_id:
235+
227236
class HF_M(torch.nn.Module):
228237
def __init__(self, model):
229238
super().__init__()
230239
self.mod = model
231-
240+
232241
def forward(self, inp):
233242
return self.mod(**inp)
234-
243+
235244
if "Bart" not in model_id:
236245
# In some transformers models, the position ids buffer is registered as non-persistent,
237246
# which makes it fail to globalize in the FX import.
@@ -244,15 +253,18 @@ def forward(self, inp):
244253
persistent=True,
245254
)
246255
fxb = FxProgramsBuilder(HF_M(model))
256+
247257
@fxb.export_program(args=(forward_args,))
248258
def _forward(module: HF_M(model), inputs):
249259
return module(inputs)
260+
250261
else:
251262
fxb = FxProgramsBuilder(model)
263+
252264
@fxb.export_program(args=(forward_args,))
253265
def _forward(module, inputs):
254266
return module(*inputs)
255-
267+
256268
class CompiledTorchbenchModel(CompiledModule):
257269
main = _forward
258270

@@ -284,7 +296,10 @@ def _run_iter(runner, inputs):
284296
res = runner.ctx.modules.compiled_torchbench_model["main"](*inputs)
285297
return res, time.time() - start
286298

287-
def run_benchmark(device, vmfb_path, weights_path, example_args, model_id, csv_path, iters):
299+
300+
def run_benchmark(
301+
device, vmfb_path, weights_path, example_args, model_id, csv_path, iters
302+
):
288303
if "rocm" in device:
289304
device = "hip" + device.split("rocm")[-1]
290305
mod_runner = vmfbRunner(device, vmfb_path, weights_path)
@@ -301,7 +316,13 @@ def run_benchmark(device, vmfb_path, weights_path, example_args, model_id, csv_p
301316
needs_header = False
302317
with open(csv_path, "a") as csvfile:
303318
fieldnames = ["model", "avg_latency", "avg_iter_per_sec"]
304-
data = [{"model": model_id, "avg_latency": avg_latency, "avg_iter_per_sec": it_per_sec}]
319+
data = [
320+
{
321+
"model": model_id,
322+
"avg_latency": avg_latency,
323+
"avg_iter_per_sec": it_per_sec,
324+
}
325+
]
305326
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
306327
if needs_header:
307328
writer.writeheader()
@@ -311,11 +332,18 @@ def run_benchmark(device, vmfb_path, weights_path, example_args, model_id, csv_p
311332

312333
def torch_to_iree(iree_runner, example_args):
313334
if isinstance(example_args, dict):
314-
iree_args = [ireert.asdevicearray(iree_runner.config.device, i.clone().detach().cpu()) for i in example_args.values()]
335+
iree_args = [
336+
ireert.asdevicearray(iree_runner.config.device, i.clone().detach().cpu())
337+
for i in example_args.values()
338+
]
315339
else:
316-
iree_args = [ireert.asdevicearray(iree_runner.config.device, i.clone().detach().cpu()) for i in example_args]
340+
iree_args = [
341+
ireert.asdevicearray(iree_runner.config.device, i.clone().detach().cpu())
342+
for i in example_args
343+
]
317344
return iree_args
318345

346+
319347
def run_main(model_id, args, tb_dir, tb_args):
320348
print(f"exporting {model_id}")
321349
mod_str, weights_path, example_args = export_torchbench_model(
@@ -343,16 +371,25 @@ def run_main(model_id, args, tb_dir, tb_args):
343371
f.write(mod_str)
344372
print("Saved to", safe_name + ".mlir")
345373
elif args.run_benchmark:
346-
run_benchmark(args.device, mod_str, weights_path, example_args, model_id, args.output_csv, args.num_iters)
374+
run_benchmark(
375+
args.device,
376+
mod_str,
377+
weights_path,
378+
example_args,
379+
model_id,
380+
args.output_csv,
381+
args.num_iters,
382+
)
347383

348384
gc.collect()
349385

386+
350387
if __name__ == "__main__":
351388
from turbine_models.custom_models.torchbench.cmd_opts import args, unknown
389+
352390
tb_dir = setup_torchbench_cwd()
353391
if args.model_id.lower() == "all":
354392
for name in torchbench_models_dict.keys():
355393
run_main(name, args, tb_dir, unknown)
356394
else:
357395
run_main(args.model_id, args, tb_dir, unknown)
358-

0 commit comments

Comments
 (0)