@@ -36,19 +36,19 @@
 torchbench_models_dict = {
     # "BERT_pytorch": {
     # "dim": 128,
-    # },
+    # },  # Dynamo Export Issue
     # "Background_Matting": {
     # "dim": 16,
-    # },
-    # "LearningToPaint": {
-    # "dim": 1024,
-    # },
+    # },  # Transpose Bubbling Pattern Failed
+    "LearningToPaint": {
+        "dim": 1024,
+    },
     "alexnet": {
         "dim": 1024,
     },
-    # "densenet121": {
-    # "dim": 64,
-    # },
+    "densenet121": {
+        "dim": 64,
+    },
     # "hf_Albert": {"dim": 32, "buffer_prefix": "albert"},
     # "hf_Bart": {
     # "dim": 16,
@@ -131,17 +131,28 @@ def get_runner(tb_dir, tb_args):
     return runner


-def get_model_and_inputs(model_id, batch_size, tb_dir, tb_args):
+def get_model_and_inputs(model_id, batch_size, tb_dir, tb_args, get_baseline=False):
     runner = get_runner(tb_dir, tb_args)
-    return runner.load_model(
+    _, model_name, model, forward_args, _ = runner.load_model(
         "cuda:0",
         model_id,
         batch_size=batch_size,
     )
-
-
+    match get_baseline:
+        case True:
+            start_t = time.time()
+            res = runner.forward_pass(model, forward_args, collect_outputs=True)
+            baseline = time.time() - start_t
+            return model_name, model, forward_args, res, baseline
+        case False:
+            return model_name, model, forward_args
+
+
+'''
+Imports models from torchbench model tooling, exports them with turbine AOT, and does simple benchmarking.
+'''
 @torch.no_grad()
-def export_torchbench_model(
+def benchmark_torchbench_model(
     model_id,
     tb_dir,
     tb_args,
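One caveat on the eager baseline taken in `get_model_and_inputs`: on a CUDA device, wrapping `time.time()` around an eager forward pass can measure kernel dispatch rather than completion. A hedged variant, assuming the same `runner.forward_pass` API and a CUDA device:

```python
# Sketch: synchronize so the wall-clock delta covers all queued CUDA work.
torch.cuda.synchronize()
start_t = time.time()
res = runner.forward_pass(model, forward_args, collect_outputs=True)
torch.cuda.synchronize()
baseline = time.time() - start_t
```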
@@ -159,6 +170,7 @@ def export_torchbench_model(
     input_mlir=None,
     weights_only=False,
     upload_ir=False,
+    compare_vs_eager=False,
 ):
     static_dim = torchbench_models_dict[model_id]["dim"]
     dtype = torch.float16 if precision == "fp16" else torch.float32
@@ -187,9 +199,16 @@ def export_torchbench_model(
         )
         return vmfb_path

-    _, model_name, model, forward_args, _ = get_model_and_inputs(
-        model_id, batch_size, tb_dir, tb_args
-    )
+    if compare_vs_eager:
+        model_name, model, forward_args, golden, baseline = get_model_and_inputs(
+            model_id, batch_size, tb_dir, tb_args, get_baseline=True
+        )
+    else:
+        model_name, model, forward_args = get_model_and_inputs(
+            model_id, batch_size, tb_dir, tb_args
+        )
+        golden = None
+        baseline = None

     if dtype == torch.float16:
         model = model.half()
@@ -275,7 +294,8 @@ class CompiledTorchbenchModel(CompiledModule):
     inst = CompiledTorchbenchModel(context=Context(), import_to="IMPORT")

     module = CompiledModule.get_mlir_module(inst)
-
+    model.to("cpu")
+    del model
     if compile_to != "vmfb":
         return str(module)
     else:
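Moving the eager model to CPU and deleting it frees GPU memory before IREE compilation. A fuller cleanup (hypothetical, assuming a CUDA device) could also release the allocator's cached blocks:

```python
# Sketch: reclaim GPU memory held by the eager model before compilation.
model.to("cpu")
del model
gc.collect()
torch.cuda.empty_cache()
```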
@@ -288,17 +308,21 @@ class CompiledTorchbenchModel(CompiledModule):
             return_path=not exit_on_vmfb,
             attn_spec=attn_spec,
         )
-        return vmfb_path, external_weight_path, forward_args
+        return vmfb_path, external_weight_path, forward_args, golden, baseline


 def _run_iter(runner, inputs):
     start = time.time()
     res = runner.ctx.modules.compiled_torchbench_model["main"](*inputs)
     return res, time.time() - start

+def do_compare(shark_results, shark_latency, golden_results, golden_latency):
+    numerics_pass_fail = np.allclose(shark_results.to_host(), golden_results.clone().cpu().numpy(), rtol=1e-4, atol=1e-4)
+    speedup = golden_latency / shark_latency
+    return speedup, numerics_pass_fail

 def run_benchmark(
-    device, vmfb_path, weights_path, example_args, model_id, csv_path, iters
+    device, vmfb_path, weights_path, example_args, model_id, csv_path, iters, golden=None, baseline=None,
 ):
     if "rocm" in device:
         device = "hip" + device.split("rocm")[-1]
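`do_compare` treats the eager results as golden: numerics pass when every element is within `rtol=1e-4`/`atol=1e-4`, and the speedup is the ratio of eager latency to compiled latency. A hypothetical call (names and latencies illustrative):

```python
# shark_results: an IREE DeviceArray from the compiled module;
# golden: the torch.Tensor returned by the eager forward pass.
speedup, numerics_ok = do_compare(shark_results, 0.021, golden, 0.047)
print(f"{speedup:.2f}x over eager, numerics pass: {numerics_ok}")
```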
@@ -311,16 +335,23 @@ def run_benchmark(
     avg_latency = sum(iter_latencies) / len(iter_latencies)
     it_per_sec = 1 / avg_latency

+    if golden is not None and baseline is not None:
+        speedup, numerics_pass_fail = do_compare(results, avg_latency, golden, baseline)
+    else:
+        speedup, numerics_pass_fail = ("N/A", "N/A")
+
     needs_header = True
     if os.path.exists(csv_path):
         needs_header = False
     with open(csv_path, "a") as csvfile:
-        fieldnames = ["model", "avg_latency", "avg_iter_per_sec"]
+        fieldnames = ["model", "avg_latency", "avg_iter_per_sec", "speedup_over_eager", "numerics"]
         data = [
             {
                 "model": model_id,
                 "avg_latency": avg_latency,
                 "avg_iter_per_sec": it_per_sec,
+                "speedup_over_eager": speedup,
+                "numerics": numerics_pass_fail,
             }
         ]
         writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
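With the two new fields, each CSV row now carries the comparison results alongside the latency figures, e.g. (values illustrative) `alexnet,0.0123,81.3,1.87,True` under the header `model,avg_latency,avg_iter_per_sec,speedup_over_eager,numerics`. When the comparison is disabled, both new columns hold `N/A`.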
@@ -346,7 +377,7 @@ def torch_to_iree(iree_runner, example_args):

 def run_main(model_id, args, tb_dir, tb_args):
     print(f"exporting {model_id}")
-    mod_str, weights_path, example_args = export_torchbench_model(
+    mod_str, weights_path, example_args, golden, baseline = benchmark_torchbench_model(
         model_id,
         tb_dir,
         tb_args,
@@ -361,6 +392,7 @@ def run_main(model_id, args, tb_dir, tb_args):
         decomp_attn=args.decomp_attn,
         attn_spec=args.attn_spec,
         input_mlir=args.input_mlir,
+        compare_vs_eager=args.compare_vs_torch,
     )
     if args.compile_to in ["torch", "mlir"]:
         safe_name = utils.create_safe_name(
@@ -379,6 +411,8 @@ def run_main(model_id, args, tb_dir, tb_args):
         model_id,
         args.output_csv,
         args.num_iters,
+        golden,
+        baseline,
     )

     gc.collect()
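Assuming the CLI exposes underscore-style argparse flags matching the attributes used in `run_main` (`compare_vs_torch`, `num_iters`, `output_csv`; the script path below is also an assumption), a run exercising the new comparison path might look like:

```python
# Hypothetical invocation; flag and path names are assumptions, not confirmed
# by this diff.
# python torchbench_model_benchmark.py --compare_vs_torch \
#     --num_iters 10 --output_csv torchbench_results.csv
```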