@@ -148,9 +148,11 @@ def get_model_and_inputs(model_id, batch_size, tb_dir, tb_args, get_baseline=Fal
148148 return model_name , model , forward_args
149149
150150
151- '''
151+ """
152152Imports models from torchbench model tooling, exports them with turbine AOT, and does simple benchmarking.
153- '''
153+ """
154+
155+
154156@torch .no_grad ()
155157def benchmark_torchbench_model (
156158 model_id ,
@@ -199,7 +201,7 @@ def benchmark_torchbench_model(
199201 )
200202 return vmfb_path
201203
202- if compare_vs_eager :
204+ if compare_vs_eager :
203205 model_name , model , forward_args , golden , baseline = get_model_and_inputs (
204206 model_id , batch_size , tb_dir , tb_args , get_baseline = True
205207 )
@@ -316,13 +318,28 @@ def _run_iter(runner, inputs):
316318 res = runner .ctx .modules .compiled_torchbench_model ["main" ](* inputs )
317319 return res , time .time () - start
318320
321+
319322def do_compare (shark_results , shark_latency , golden_results , golden_latency ):
320- numerics_pass_fail = np .allclose (shark_results .to_host (), golden_results .clone ().cpu ().numpy (), rtol = 1e-4 , atol = 1e-4 )
323+ numerics_pass_fail = np .allclose (
324+ shark_results .to_host (),
325+ golden_results .clone ().cpu ().numpy (),
326+ rtol = 1e-4 ,
327+ atol = 1e-4 ,
328+ )
321329 speedup = golden_latency / shark_latency
322330 return speedup , numerics_pass_fail
323331
332+
324333def run_benchmark (
325- device , vmfb_path , weights_path , example_args , model_id , csv_path , iters , golden = None , baseline = None ,
334+ device ,
335+ vmfb_path ,
336+ weights_path ,
337+ example_args ,
338+ model_id ,
339+ csv_path ,
340+ iters ,
341+ golden = None ,
342+ baseline = None ,
326343):
327344 if "rocm" in device :
328345 device = "hip" + device .split ("rocm" )[- 1 ]
@@ -344,7 +361,13 @@ def run_benchmark(
344361 if os .path .exists (csv_path ):
345362 needs_header = False
346363 with open (csv_path , "a" ) as csvfile :
347- fieldnames = ["model" , "avg_latency" , "avg_iter_per_sec" , "speedup_over_eager" , "numerics" ]
364+ fieldnames = [
365+ "model" ,
366+ "avg_latency" ,
367+ "avg_iter_per_sec" ,
368+ "speedup_over_eager" ,
369+ "numerics" ,
370+ ]
348371 data = [
349372 {
350373 "model" : model_id ,
@@ -422,7 +445,7 @@ def run_main(model_id, args, tb_dir, tb_args):
422445 from turbine_models .custom_models .torchbench .cmd_opts import args , unknown
423446 import json
424447
425- torchbench_models_dict = json .load (args .model_list_json
448+ torchbench_models_dict = json .load (args .model_list_json )
426449 for list in args .model_lists :
427450 torchbench_models_dict = json .load (list )
428451 with open (args .models_json , "r" ) as f :
0 commit comments