@@ -4,16 +4,19 @@
 import torch
 from executorch.devtools.backend_debug import get_delegation_info
 from executorch.exir._warnings import experimental
+from executorch.exir.backend.backend_api import validation_disabled
 from executorch.exir.program import (
     EdgeProgramManager,
     ExecutorchProgramManager,
     to_edge_transform_and_lower,
 )
 from executorch.exir.schema import Program
+from executorch.extension.export_util.utils import save_pte_program
 from executorch.runtime import Runtime, Verification
 from tabulate import tabulate
 from torch import nn
 from torch.ao.quantization import allow_exported_model_train_eval
+from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
 from torch.export import ExportedProgram
 from torchao.quantization import quantize_
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
@@ -145,15 +148,15 @@ def run(
                 model,
                 self._example_inputs_dict[method_name][0],
                 dynamic_shapes=dynamic_shapes,
+                strict=True,
             )
 
             # Apply pre-edge transform passes if available
             if self._pre_edge_transform_passes is not None:
-                self._exported_program[method_name] = (
-                    self._pre_edge_transform_passes(
+                for pre_edge_transform_pass in self._pre_edge_transform_passes:
+                    self._exported_program[method_name] = pre_edge_transform_pass(
                         self._exported_program[method_name]
                     )
-                )
 
     def get_artifacts(self) -> Dict[str, ExportedProgram]:
         """
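
The pre-edge transform passes are now applied as an ordered sequence rather than a single callable. A minimal sketch of what such a pass list could look like (the pass names below are illustrative assumptions, not part of this change; each pass simply takes and returns an ExportedProgram):

    from torch.export import ExportedProgram

    def strip_stack_traces(ep: ExportedProgram) -> ExportedProgram:
        # Hypothetical pass: drop debug metadata from every node in place.
        for node in ep.graph.nodes:
            node.meta.pop("stack_trace", None)
        return ep

    def report_node_count(ep: ExportedProgram) -> ExportedProgram:
        # Hypothetical pass: purely observational, returns the program unchanged.
        print(f"exported graph has {len(list(ep.graph.nodes))} nodes")
        return ep

    # Supplied as a list; the stage above loops over it in order.
    pre_edge_transform_passes = [strip_stack_traces, report_node_count]
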
@@ -210,13 +213,14 @@ def run(
         self._constant_methods = transform_config.get("constant_methods", None)
 
         # Process inputs
-        self._edge_program_manager = to_edge_transform_and_lower(
-            self._exported_program,
-            partitioner=self._partitioners,
-            transform_passes=self._transform_passes,
-            constant_methods=self._constant_methods,
-            compile_config=self._compile_config,
-        )
+        with validation_disabled():
+            self._edge_program_manager = to_edge_transform_and_lower(
+                self._exported_program,
+                partitioner=self._partitioners,
+                transform_passes=self._transform_passes,
+                constant_methods=self._constant_methods,
+                compile_config=self._compile_config,
+            )
         self._delegation_info = get_delegation_info(
             self._edge_program_manager.exported_program().graph_module
         )
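
For context, validation_disabled() is a context manager from executorch.exir.backend.backend_api that turns off partitioner output validation while lowering. A rough standalone sketch of the same pattern (the toy model is an assumption; partitioners are omitted here, whereas the stage above passes the recipe-supplied ones):

    import torch
    from executorch.devtools.backend_debug import get_delegation_info
    from executorch.exir.backend.backend_api import validation_disabled
    from executorch.exir.program import to_edge_transform_and_lower

    class TinyModel(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.relu(x)

    exported = torch.export.export(TinyModel(), (torch.randn(2, 4),))

    # Validation is skipped only inside this context.
    with validation_disabled():
        edge_manager = to_edge_transform_and_lower(exported)

    print(get_delegation_info(edge_manager.exported_program().graph_module).get_summary())
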
@@ -345,8 +349,8 @@ class QuantizeStage(Stage):
     Optional stage: Perform post-training quantization on the model.
     """
 
-    def __init__(self, quantizer: Any) -> None:
-        self._quantizer = quantizer
+    def __init__(self, quantizers: Any) -> None:
+        self._quantizers = quantizers
         self._quantized_models: Dict[str, nn.Module] = {}
         self._model_dict: Dict[str, nn.Module] = {}
         self._exported_program_dict: Dict[str, ExportedProgram] = {}
@@ -394,7 +398,8 @@ def run(
             model = exported_program.module()
 
             # Prepare the model for quantization
-            prepared_model = prepare_pt2e(model, self._quantizer)  # type: ignore
+            composed_quantizer = ComposableQuantizer(self._quantizers)
+            prepared_model = prepare_pt2e(model, composed_quantizer)  # type: ignore
 
             # Allow the model to switch between train and eval modes
             allow_exported_model_train_eval(prepared_model)
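
ComposableQuantizer wraps a list of PT2E quantizers so a single prepare/convert pass can apply all of them. A rough end-to-end sketch of that flow (the XNNPACKQuantizer choice, toy model, and calibration data are assumptions for illustration, not taken from this change):

    import torch
    from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
        XNNPACKQuantizer,
        get_symmetric_quantization_config,
    )
    from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
    from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

    class TinyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, x):
            return self.linear(x)

    example_inputs = (torch.randn(1, 4),)
    model = torch.export.export(TinyModel(), example_inputs).module()

    # One or more quantizers, composed and applied in a single pass.
    quantizers = [XNNPACKQuantizer().set_global(get_symmetric_quantization_config())]
    prepared = prepare_pt2e(model, ComposableQuantizer(quantizers))
    prepared(*example_inputs)  # calibration run
    quantized = convert_pt2e(prepared)
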
@@ -546,9 +551,9 @@ def __init__(
 
         # Create the quantize stage if a quantizer is provided
         if self._export_recipe.quantization_recipe is not None:
-            quantizer = self._export_recipe.quantization_recipe.get_quantizer()
-            if quantizer is not None:
-                quantize_stage = QuantizeStage(quantizer=quantizer)
+            quantizers = self._export_recipe.quantization_recipe.get_quantizers()
+            if quantizers is not None:
+                quantize_stage = QuantizeStage(quantizers=quantizers)
                 self._pipeline.append(quantize_stage)
 
         # Create the edge transform and lower stage
@@ -661,6 +666,22 @@ def get_executorch_program(self) -> Program:
             )
         return self._executorch_program_manager.executorch_program
 
+    def get_executorch_program_manager(self) -> ExecutorchProgramManager:
+        """
+        Get the ExecutorchProgramManager.
+
+        Returns:
+            The ExecutorchProgramManager
+
+        Raises:
+            RuntimeError: If the executorch program manager is not initialized
+        """
+        if self._executorch_program_manager is None:
+            raise RuntimeError(
+                "Executorch program manager is not initialized. Run export() first."
+            )
+        return self._executorch_program_manager
+
     def get_pte_buffer(self) -> bytes:
         """
         Get the PTE buffer as bytes.
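
The new accessor exposes the underlying ExecutorchProgramManager directly. A hypothetical usage sketch, assuming `session` is an instance of this class whose export pipeline has already run (the variable name is an assumption):

    manager = session.get_executorch_program_manager()

    # Same serialized bytes as get_pte_buffer(), plus the in-memory program.
    pte_bytes = manager.buffer
    program = manager.executorch_program

    with open("model.pte", "wb") as f:
        f.write(pte_bytes)
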
@@ -677,6 +698,20 @@ def get_pte_buffer(self) -> bytes:
             )
         return self._executorch_program_manager.buffer
 
+    def save_to_pte(self, output_name: str) -> None:
+        """
+        Save the model to a .pte file.
+
+        Args:
+            output_name (Optional[str]): The name of the .pte file.
+        """
+        assert output_name, "Need a valid output name"
+        if self._executorch_program_manager is None:
+            raise RuntimeError(
+                "Executorch program manager is not initialized. Run export() first."
+            )
+        save_pte_program(self._executorch_program_manager, output_name)
+
     def get_example_input(
         self, method_name: str = "forward"
     ) -> Tuple[torch.Tensor, ...]:
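
save_to_pte delegates to save_pte_program from executorch.extension.export_util.utils, which conventionally writes the output with a .pte suffix. A hypothetical call on the same `session` object as above (the name and file name are assumptions):

    # Typically lands in the working directory as "my_model.pte".
    session.save_to_pte("my_model")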