@@ -1060,6 +1060,8 @@ def maybe_mark_profile(*args, **kwargs):
10601060 frozen_model_iter_fn = export_nativert (model , example_inputs )
10611061 elif args .torchscript_jit_trace :
10621062 frozen_model_iter_fn = torchscript_jit_trace (model , example_inputs )
1063+ elif args .aot_precompile :
1064+ frozen_model_iter_fn = aot_precompile (model , example_inputs )
10631065 else :
10641066 if kwargs ["hf_llm" ]:
10651067 # If it's an llm, we want to optimize model.forward, and use
@@ -1495,6 +1497,37 @@ def opt_export(_, example_inputs):
14951497 return opt_export
14961498
14971499
def aot_precompile(model, example_inputs):
    """AOT-precompile ``model.forward``, round-trip the artifact through disk,
    and return a benchmark iteration function running the reloaded artifact.

    Steps:
      1. Compile ``model.forward`` with ``torch.compile`` under
         ``enable_aot_compile`` and serialize the result to a temp ``.pt`` file.
      2. ``torch._dynamo.reset()`` so the reload below measures a cold start
         rather than reusing in-memory compilation state.
      3. Load the artifact back with ``torch.compiler.load_compiled_function``
         (load latency is printed) and close over it in the returned iter fn.

    Args:
        model: the eager model whose ``forward`` is compiled.
        example_inputs: benchmark inputs; normalized to (args, kwargs) via
            ``_normalize_bench_inputs``.

    Returns:
        ``opt_aot_precompile(_, example_inputs, collect_outputs=False)`` — a
        callable matching the benchmark iter-fn signature that invokes the
        loaded compiled function.
    """
    import os  # local import: file-level import block is outside this view

    example_args, example_kwargs = _normalize_bench_inputs(example_inputs)

    # delete=False so the path outlives the `with`; we are responsible for
    # removing the file ourselves (the original version leaked it every run).
    with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f:
        save_path = f.name

    try:
        with fresh_cache(), torch._dynamo.config.patch("enable_aot_compile", True):
            # Drop every guard: the benchmark replays identical inputs, so
            # guard evaluation would only add per-iteration overhead.
            compiled_fn = torch.compile(
                model,
                fullgraph=True,
                options={"guard_filter_fn": lambda guards: [False for _ in guards]},
            ).forward.aot_compile((example_args, example_kwargs))

            compiled_fn.save_compiled_function(save_path)

        torch._dynamo.reset()
        with open(save_path, "rb") as f:
            load_start_time = time.perf_counter()
            loaded_fn = torch.compiler.load_compiled_function(f)
            load_end_time = time.perf_counter()
            print(
                f"AOT Precompile loading time: {load_end_time - load_start_time} seconds"
            )
    finally:
        # Fix: always remove the delete=False temp artifact, including on
        # compile/save/load failure. Best-effort — ignore races/missing file.
        try:
            os.remove(save_path)
        except OSError:
            pass

    def opt_aot_precompile(_, example_inputs, collect_outputs=False):
        # collect_outputs is accepted for signature parity with the other
        # benchmark iter fns; the loaded fn's result is returned either way.
        example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
        return loaded_fn(model, *example_args, **example_kwargs)

    return opt_aot_precompile
1529+
1530+
14981531def export_nativert (model , example_inputs ):
14991532 optimized = NativeRTCache .load (model , example_inputs )
15001533
@@ -2274,6 +2307,7 @@ def record_status(accuracy_status, dynamo_start_stats):
22742307 or self .args .export_aot_inductor
22752308 or self .args .export_nativert
22762309 or self .args .torchscript_jit_trace
2310+ or self .args .aot_precompile
22772311 ):
22782312 # apply export on module directly
22792313 # no need for n iterations
@@ -2729,6 +2763,7 @@ def warmup(fn, model, example_inputs, mode, niters=5):
27292763 self .args .export_aot_inductor
27302764 or self .args .export_nativert
27312765 or self .args .torchscript_jit_trace
2766+ or self .args .aot_precompile
27322767 ):
27332768 optimized_model_iter_fn = optimize_ctx
27342769 else :
@@ -3505,6 +3540,11 @@ def get_example_inputs(self):
35053540 action = "store_true" ,
35063541 help = "Measure pass rate with Export+AOTInductor" ,
35073542 )
3543+ group .add_argument (
3544+ "--aot-precompile" ,
3545+ action = "store_true" ,
3546+ help = "Measure pass rate with AOT Precompile" ,
3547+ )
35083548 group .add_argument (
35093549 "--export-nativert" ,
35103550 action = "store_true" ,
@@ -3935,6 +3975,10 @@ def run(runner, args, original_dir=None):
39353975 optimize_ctx = export
39363976 experiment = speedup_experiment
39373977 output_filename = "export.csv"
3978+ elif args .aot_precompile :
3979+ optimize_ctx = aot_precompile
3980+ experiment = speedup_experiment
3981+ output_filename = "aot_precompile.csv"
39383982 elif args .export_nativert :
39393983 optimize_ctx = export_nativert
39403984 experiment = speedup_experiment
0 commit comments