diff --git a/backends/cadence/aot/compiler_funcs.py b/backends/cadence/aot/compiler_funcs.py index 00653bc925b..5d5523ba31d 100644 --- a/backends/cadence/aot/compiler_funcs.py +++ b/backends/cadence/aot/compiler_funcs.py @@ -35,9 +35,9 @@ def trace( decomp_table = torch.export.default_decompositions() # pyre-fixme[6]: For 1st argument expected `Dict[typing.Callable[..., typing.Any remove_decompositions(decomp_table, ops_to_keep) - program = torch.export.export_for_training( - model, inputs, strict=strict - ).run_decompositions(decomp_table) + program = torch.export.export(model, inputs, strict=strict).run_decompositions( + decomp_table + ) return program