2626from turbine_models .model_runner import vmfbRunner
2727
2828from pytorch .benchmarks .dynamo .common import parse_args
29- from pytorch .benchmarks .dynamo .torchbench import TorchBenchmarkRunner , setup_torchbench_cwd
29+ from pytorch .benchmarks .dynamo .torchbench import (
30+ TorchBenchmarkRunner ,
31+ setup_torchbench_cwd ,
32+ )
3033
3134import csv
35+
3236torchbench_models_dict = {
3337 # "BERT_pytorch": {
3438 # "dim": 128,
4549 # "densenet121": {
4650 # "dim": 64,
4751 # },
48- "hf_Albert" : {
49- "dim" : 32 ,
50- "buffer_prefix" : "albert"
51- },
52+ "hf_Albert" : {"dim" : 32 , "buffer_prefix" : "albert" },
5253 # "hf_Bart": {
5354 # "dim": 16,
5455 # },
118119 # },
119120}
120121
122+
121123# Adapted from pytorch.benchmarks.dynamo.common.main()
122124def get_runner (tb_dir , tb_args ):
123125 if tb_dir :
@@ -134,7 +136,7 @@ def get_model_and_inputs(model_id, batch_size, tb_dir, tb_args):
134136 return runner .load_model (
135137 "cuda:0" ,
136138 model_id ,
137- batch_size = batch_size ,
139+ batch_size = batch_size ,
138140 )
139141
140142
@@ -185,9 +187,10 @@ def export_torchbench_model(
185187 )
186188 return vmfb_path
187189
190+ _ , model_name , model , forward_args , _ = get_model_and_inputs (
191+ model_id , batch_size , tb_dir , tb_args
192+ )
188193
189- _ , model_name , model , forward_args , _ = get_model_and_inputs (model_id , batch_size , tb_dir , tb_args )
190-
191194 if dtype == torch .float16 :
192195 model = model .half ()
193196 model .to ("cuda:0" )
@@ -196,42 +199,48 @@ def export_torchbench_model(
196199 forward_args = [i .type (dtype ) for i in forward_args ]
197200 for idx , i in enumerate (forward_args ):
198201 np .save (
199- os .path .join ("generated" , f"{ model_id } _input{ idx } " ), i .clone ().detach ().cpu ())
202+ os .path .join ("generated" , f"{ model_id } _input{ idx } " ),
203+ i .clone ().detach ().cpu (),
204+ )
200205 else :
201206 for idx , i in enumerate (forward_args .values ()):
202207 np .save (f"{ model_id } _input{ idx } " , i .clone ().detach ().cpu ())
203208
204-
205209 mapper = {}
206- if ( external_weights_dir is not None ) :
210+ if external_weights_dir is not None :
207211 if not os .path .exists (external_weights_dir ):
208212 os .mkdir (external_weights_dir )
209- external_weight_path = os .path .join (external_weights_dir , f"{ model_id } _{ precision } .irpa" )
213+ external_weight_path = os .path .join (
214+ external_weights_dir , f"{ model_id } _{ precision } .irpa"
215+ )
210216 else :
211217 external_weight_path = None
212218
213219 decomp_list = [torch .ops .aten .reflection_pad2d ]
214220 if decomp_attn == True or torchbench_models_dict [model_id ].get ("decomp_attn" ):
215221 print ("decomposing attention for: " + model_id )
216- decomp_list .extend ([
217- torch .ops .aten ._scaled_dot_product_flash_attention_for_cpu ,
218- torch .ops .aten ._scaled_dot_product_flash_attention .default ,
219- torch .ops .aten ._scaled_dot_product_flash_attention ,
220- torch .ops .aten .scaled_dot_product_attention ,
221- ])
222+ decomp_list .extend (
223+ [
224+ torch .ops .aten ._scaled_dot_product_flash_attention_for_cpu ,
225+ torch .ops .aten ._scaled_dot_product_flash_attention .default ,
226+ torch .ops .aten ._scaled_dot_product_flash_attention ,
227+ torch .ops .aten .scaled_dot_product_attention ,
228+ ]
229+ )
222230 with decompositions .extend_aot_decompositions (
223231 from_current = True ,
224232 add_ops = decomp_list ,
225233 ):
226234 if "hf" in model_id :
235+
227236 class HF_M (torch .nn .Module ):
228237 def __init__ (self , model ):
229238 super ().__init__ ()
230239 self .mod = model
231-
240+
232241 def forward (self , inp ):
233242 return self .mod (** inp )
234-
243+
235244 if "Bart" not in model_id :
236245 # In some transformers models, the position ids buffer is registered as non-persistent,
237246 # which makes it fail to globalize in the FX import.
@@ -244,15 +253,18 @@ def forward(self, inp):
244253 persistent = True ,
245254 )
246255 fxb = FxProgramsBuilder (HF_M (model ))
256+
247257 @fxb .export_program (args = (forward_args ,))
248258 def _forward (module : HF_M (model ), inputs ):
249259 return module (inputs )
260+
250261 else :
251262 fxb = FxProgramsBuilder (model )
263+
252264 @fxb .export_program (args = (forward_args ,))
253265 def _forward (module , inputs ):
254266 return module (* inputs )
255-
267+
256268 class CompiledTorchbenchModel (CompiledModule ):
257269 main = _forward
258270
@@ -284,7 +296,10 @@ def _run_iter(runner, inputs):
284296 res = runner .ctx .modules .compiled_torchbench_model ["main" ](* inputs )
285297 return res , time .time () - start
286298
287- def run_benchmark (device , vmfb_path , weights_path , example_args , model_id , csv_path , iters ):
299+
300+ def run_benchmark (
301+ device , vmfb_path , weights_path , example_args , model_id , csv_path , iters
302+ ):
288303 if "rocm" in device :
289304 device = "hip" + device .split ("rocm" )[- 1 ]
290305 mod_runner = vmfbRunner (device , vmfb_path , weights_path )
@@ -301,7 +316,13 @@ def run_benchmark(device, vmfb_path, weights_path, example_args, model_id, csv_p
301316 needs_header = False
302317 with open (csv_path , "a" ) as csvfile :
303318 fieldnames = ["model" , "avg_latency" , "avg_iter_per_sec" ]
304- data = [{"model" : model_id , "avg_latency" : avg_latency , "avg_iter_per_sec" : it_per_sec }]
319+ data = [
320+ {
321+ "model" : model_id ,
322+ "avg_latency" : avg_latency ,
323+ "avg_iter_per_sec" : it_per_sec ,
324+ }
325+ ]
305326 writer = csv .DictWriter (csvfile , fieldnames = fieldnames )
306327 if needs_header :
307328 writer .writeheader ()
@@ -311,11 +332,18 @@ def run_benchmark(device, vmfb_path, weights_path, example_args, model_id, csv_p
311332
def torch_to_iree(iree_runner, example_args):
    """Convert example inputs into IREE device arrays for ``iree_runner``.

    Args:
        iree_runner: Object exposing ``config.device``; that device is passed
            as the first argument to ``ireert.asdevicearray``.
        example_args: Either a dict, in which case only its values are used
            (in the dict's iteration order), or any other iterable of inputs.
            Every input must support ``.clone().detach().cpu()``.

    Returns:
        A list holding the result of ``ireert.asdevicearray`` for each input,
        in iteration order. An empty ``example_args`` yields an empty list.
    """
    # Dict inputs are flattened to their values; the keys are not forwarded.
    if isinstance(example_args, dict):
        inputs = example_args.values()
    else:
        inputs = example_args
    return [
        # Each input is copied, detached and moved to host before conversion.
        ireert.asdevicearray(iree_runner.config.device, t.clone().detach().cpu())
        for t in inputs
    ]
318345
346+
319347def run_main (model_id , args , tb_dir , tb_args ):
320348 print (f"exporting { model_id } " )
321349 mod_str , weights_path , example_args = export_torchbench_model (
@@ -343,16 +371,25 @@ def run_main(model_id, args, tb_dir, tb_args):
343371 f .write (mod_str )
344372 print ("Saved to" , safe_name + ".mlir" )
345373 elif args .run_benchmark :
346- run_benchmark (args .device , mod_str , weights_path , example_args , model_id , args .output_csv , args .num_iters )
374+ run_benchmark (
375+ args .device ,
376+ mod_str ,
377+ weights_path ,
378+ example_args ,
379+ model_id ,
380+ args .output_csv ,
381+ args .num_iters ,
382+ )
347383
348384 gc .collect ()
349385
386+
if __name__ == "__main__":
    # Imported inside the guard so the options module is only loaded when this
    # file is executed as a script.
    from turbine_models.custom_models.torchbench.cmd_opts import args, unknown

    tb_dir = setup_torchbench_cwd()

    # "all" (case-insensitive) selects every key of torchbench_models_dict;
    # any other value is treated as a single model id.
    if args.model_id.lower() == "all":
        model_ids = list(torchbench_models_dict)
    else:
        model_ids = [args.model_id]

    for model_id in model_ids:
        run_main(model_id, args, tb_dir, unknown)
0 commit comments