Skip to content

Commit 102b788

Browse files
yiming0416 authored and pytorchmergebot committed
Add option to run AOT Precompile in benchmark (pytorch#164906)
Use the existing benchmark infrastructure to get signals for the AOT precompile pass rate on OSS models. We also measure and log the loading time.

```
python ./benchmarks/dynamo/huggingface.py --accuracy --inference --aot-precompile
python ./benchmarks/dynamo/timm_models.py --accuracy --inference --aot-precompile
python ./benchmarks/dynamo/torchbench.py --accuracy --inference --aot-precompile
```

Pull Request resolved: pytorch#164906
Approved by: https://github.com/zhxchen17
1 parent 382d04a commit 102b788

File tree

1 file changed

+44
-0
lines changed

1 file changed

+44
-0
lines changed

benchmarks/dynamo/common.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1060,6 +1060,8 @@ def maybe_mark_profile(*args, **kwargs):
10601060
frozen_model_iter_fn = export_nativert(model, example_inputs)
10611061
elif args.torchscript_jit_trace:
10621062
frozen_model_iter_fn = torchscript_jit_trace(model, example_inputs)
1063+
elif args.aot_precompile:
1064+
frozen_model_iter_fn = aot_precompile(model, example_inputs)
10631065
else:
10641066
if kwargs["hf_llm"]:
10651067
# If it's an llm, we want to optimize model.forward, and use
@@ -1495,6 +1497,37 @@ def opt_export(_, example_inputs):
14951497
return opt_export
14961498

14971499

1500+
def aot_precompile(model, example_inputs):
    """Compile ``model.forward`` ahead of time, round-trip it through disk, and
    return a benchmark callable backed by the re-loaded artifact.

    Steps performed:
      1. AOT-compile ``model.forward`` for ``example_inputs`` with
         ``fullgraph=True`` under a fresh cache.
      2. Save the compiled function to a temporary ``.pt`` file.
      3. Reset dynamo, load the compiled function back from that file, and
         print how long the load took.
      4. Delete the temporary file (it is only needed for the round trip).

    Args:
        model: The module whose ``forward`` is compiled. It is also passed as
            the first positional argument to the loaded function on each call.
        example_inputs: Benchmark inputs; split into positional/keyword parts
            by ``_normalize_bench_inputs``.

    Returns:
        A callable ``opt_aot_precompile(_, example_inputs, collect_outputs=False)``
        that ignores its first argument and ``collect_outputs``, and runs the
        loaded compiled function on the closed-over ``model`` with the given
        inputs.
    """
    import os

    example_args, example_kwargs = _normalize_bench_inputs(example_inputs)

    # Only reserve a unique path here: the handle is closed immediately so the
    # file can be written and re-opened by name below. ``delete=False`` means
    # we are responsible for removing it ourselves (see ``finally``).
    with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f:
        save_path = f.name

    try:
        with fresh_cache(), torch._dynamo.config.patch("enable_aot_compile", True):
            compiled_fn = torch.compile(
                model,
                fullgraph=True,
                # The filter returns False for every guard it is given.
                # NOTE(review): presumably this drops all guards from the
                # saved artifact -- confirm against guard_filter_fn semantics.
                options={"guard_filter_fn": lambda guards: [False for _ in guards]},
            ).forward.aot_compile((example_args, example_kwargs))

        compiled_fn.save_compiled_function(save_path)

        # Clear dynamo state so the benchmark exercises the loaded artifact
        # rather than anything left over from the compile step above.
        torch._dynamo.reset()
        with open(save_path, "rb") as f:
            load_start_time = time.perf_counter()
            loaded_fn = torch.compiler.load_compiled_function(f)
            load_end_time = time.perf_counter()
            print(
                f"AOT Precompile loading time: {load_end_time - load_start_time} seconds"
            )
    finally:
        # Previously the temp file was leaked on every run. Cleanup is
        # best-effort: a failure to delete must not mask a compile/load error
        # or fail an otherwise successful benchmark.
        try:
            os.remove(save_path)
        except OSError:
            pass

    def opt_aot_precompile(_, example_inputs, collect_outputs=False):
        example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
        return loaded_fn(model, *example_args, **example_kwargs)

    return opt_aot_precompile
1529+
1530+
14981531
def export_nativert(model, example_inputs):
14991532
optimized = NativeRTCache.load(model, example_inputs)
15001533

@@ -2274,6 +2307,7 @@ def record_status(accuracy_status, dynamo_start_stats):
22742307
or self.args.export_aot_inductor
22752308
or self.args.export_nativert
22762309
or self.args.torchscript_jit_trace
2310+
or self.args.aot_precompile
22772311
):
22782312
# apply export on module directly
22792313
# no need for n iterations
@@ -2729,6 +2763,7 @@ def warmup(fn, model, example_inputs, mode, niters=5):
27292763
self.args.export_aot_inductor
27302764
or self.args.export_nativert
27312765
or self.args.torchscript_jit_trace
2766+
or self.args.aot_precompile
27322767
):
27332768
optimized_model_iter_fn = optimize_ctx
27342769
else:
@@ -3505,6 +3540,11 @@ def get_example_inputs(self):
35053540
action="store_true",
35063541
help="Measure pass rate with Export+AOTInductor",
35073542
)
3543+
group.add_argument(
3544+
"--aot-precompile",
3545+
action="store_true",
3546+
help="Measure pass rate with AOT Precompile",
3547+
)
35083548
group.add_argument(
35093549
"--export-nativert",
35103550
action="store_true",
@@ -3935,6 +3975,10 @@ def run(runner, args, original_dir=None):
39353975
optimize_ctx = export
39363976
experiment = speedup_experiment
39373977
output_filename = "export.csv"
3978+
elif args.aot_precompile:
3979+
optimize_ctx = aot_precompile
3980+
experiment = speedup_experiment
3981+
output_filename = "aot_precompile.csv"
39383982
elif args.export_nativert:
39393983
optimize_ctx = export_nativert
39403984
experiment = speedup_experiment

0 commit comments

Comments
 (0)