From 727be1a713d5261e4ba8d3f7f98b589d82a75c6c Mon Sep 17 00:00:00 2001 From: Matthias Cremon Date: Fri, 27 Sep 2024 13:28:49 -0700 Subject: [PATCH] Migrate from capture_pre_autograd_graph to torch.export.export_for_training Summary: As titled. The `capture_pre_autograd_graph` API is deprecated. Differential Revision: D63541800 --- backends/cadence/aot/compiler.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/backends/cadence/aot/compiler.py b/backends/cadence/aot/compiler.py index e1494f8d20d..fe8fc721245 100644 --- a/backends/cadence/aot/compiler.py +++ b/backends/cadence/aot/compiler.py @@ -30,7 +30,6 @@ ) from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge -from torch._export import capture_pre_autograd_graph from torch.ao.quantization.pt2e.export_utils import model_is_exported from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e @@ -58,7 +57,7 @@ def convert_pt2( """ # Export with dynamo - model_gm = capture_pre_autograd_graph(model, inputs) + model_gm = torch.export.export_for_training(model, inputs).module() if model_gm_has_SDPA(model_gm): # pyre-fixme[6] # Decompose SDPA