@@ -144,8 +144,7 @@ import torch
 
 from executorch.exir import EdgeCompileConfig, to_edge
 from torch.nn.attention import sdpa_kernel, SDPBackend
-from torch._export import capture_pre_autograd_graph
-from torch.export import export
+from torch.export import export, export_for_training
 
 from model import GPT
 
@@ -170,7 +169,7 @@ dynamic_shape = (
 # Trace the model, converting it to a portable intermediate representation.
 # The torch.no_grad() call tells PyTorch to exclude training-specific logic.
 with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
-    m = capture_pre_autograd_graph(model, example_inputs, dynamic_shapes=dynamic_shape)
+    m = export_for_training(model, example_inputs, dynamic_shapes=dynamic_shape).module()
     traced_model = export(m, example_inputs, dynamic_shapes=dynamic_shape)
 
 # Convert the model into a runnable ExecuTorch program.
@@ -462,7 +461,7 @@ from executorch.exir import EdgeCompileConfig, to_edge
 import torch
 from torch.export import export
 from torch.nn.attention import sdpa_kernel, SDPBackend
-from torch._export import capture_pre_autograd_graph
+from torch.export import export_for_training
 
 from model import GPT
 
@@ -489,7 +488,7 @@ dynamic_shape = (
 # Trace the model, converting it to a portable intermediate representation.
 # The torch.no_grad() call tells PyTorch to exclude training-specific logic.
 with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
-    m = capture_pre_autograd_graph(model, example_inputs, dynamic_shapes=dynamic_shape)
+    m = export_for_training(model, example_inputs, dynamic_shapes=dynamic_shape).module()
     traced_model = export(m, example_inputs, dynamic_shapes=dynamic_shape)
 
 # Convert the model into a runnable ExecuTorch program.
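
The two tracing hunks above make the same substitution: capture_pre_autograd_graph(...) is replaced by export_for_training(...).module(). As a point of reference, here is a minimal, self-contained sketch of the post-change export flow. Assumptions beyond this diff: the nanoGPT-style GPT.from_pretrained constructor, the fixed input shape (the tutorial's dynamic_shape machinery is omitted for brevity), the EdgeCompileConfig flag, and the output file name all follow the surrounding tutorial rather than these hunks.

import torch
from executorch.exir import EdgeCompileConfig, to_edge
from torch.export import export, export_for_training
from torch.nn.attention import sdpa_kernel, SDPBackend

from model import GPT

# Assumed setup: a pretrained nanoGPT model and one batch of token IDs.
model = GPT.from_pretrained("gpt2")
example_inputs = (torch.randint(0, 100, (1, 8), dtype=torch.long),)

with sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
    # export_for_training returns an ExportedProgram; .module() recovers the
    # traced torch.nn.Module that torch.export.export expects as its input.
    m = export_for_training(model, example_inputs).module()
    traced_model = export(m, example_inputs)

# Lower the traced program to the Edge dialect, then serialize it.
edge_manager = to_edge(
    traced_model, compile_config=EdgeCompileConfig(_check_ir_validity=False)
)
et_program = edge_manager.to_executorch()

with open("nanogpt.pte", "wb") as f:
    f.write(et_program.buffer)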
@@ -635,7 +634,7 @@ xnnpack_quant_config = get_symmetric_quantization_config(
 xnnpack_quantizer = XNNPACKQuantizer()
 xnnpack_quantizer.set_global(xnnpack_quant_config)
 
-m = capture_pre_autograd_graph(model, example_inputs)
+m = export_for_training(model, example_inputs).module()
 
 # Annotate the model for quantization. This prepares the model for calibration.
 m = prepare_pt2e(m, xnnpack_quantizer)
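
For the quantization hunk, a self-contained sketch of the updated PT2E flow follows. Assumptions beyond this diff: the torch.ao import paths, the get_symmetric_quantization_config arguments, the convert_pt2e step, and reusing example_inputs for calibration are taken from the tutorial this change edits, and `model` and `example_inputs` are assumed to be defined as earlier in that script.

import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)
from torch.export import export_for_training

# Configure symmetric, per-channel dynamic quantization for XNNPACK.
xnnpack_quant_config = get_symmetric_quantization_config(
    is_per_channel=True, is_dynamic=True
)
xnnpack_quantizer = XNNPACKQuantizer()
xnnpack_quantizer.set_global(xnnpack_quant_config)

# Trace for quantization: export_for_training(...).module() replaces the old
# capture_pre_autograd_graph(...) call as the PT2E entry point.
m = export_for_training(model, example_inputs).module()

# Insert observers so calibration can record value ranges.
m = prepare_pt2e(m, xnnpack_quantizer)

# Calibrate by running representative inputs through the observed model.
m(*example_inputs)

# Swap observed ops for their quantized equivalents.
m = convert_pt2e(m, fold_quantize=False)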