44# This source code is licensed under the BSD-style license found in the
55# LICENSE file in the root directory of this source tree.
66
7+ import logging
78from abc import ABC , abstractmethod
89from typing import Any , Callable , Dict , List , Optional , Sequence , Tuple , Union
910
1819)
1920from executorch .exir .program ._program import _transform
2021from executorch .exir .schema import Program
22+ from executorch .export .recipe import QuantizationRecipe
2123from executorch .extension .export_util .utils import save_pte_program
2224from executorch .runtime import Runtime , Verification
2325from tabulate import tabulate
2628from torch ._export .pass_base import PassType
2729from torch .export import ExportedProgram
2830from torchao .quantization import quantize_
29- from torchao .quantization .pt2e import allow_exported_model_train_eval
3031from torchao .quantization .pt2e .quantize_pt2e import convert_pt2e , prepare_pt2e
3132
3233from torchao .quantization .pt2e .quantizer import ComposableQuantizer
@@ -360,8 +361,8 @@ class QuantizeStage(Stage):
360361 def __init__ (self , quantizers : Any ) -> None :
361362 self ._quantizers = quantizers
362363 self ._quantized_models : Dict [str , nn .Module ] = {}
364+ self ._exported_programs : Dict [str , ExportedProgram ] = {}
363365 self ._model_dict : Dict [str , nn .Module ] = {}
364- self ._exported_program_dict : Dict [str , ExportedProgram ] = {}
365366 self ._example_inputs_dict : Dict [str , List [tuple [torch .Tensor , ...]]] = {}
366367
367368 @property
@@ -370,20 +371,20 @@ def name(self) -> str:
370371
371372 def run (
372373 self ,
373- exported_program_data : Dict [str , Any ],
374+ models : Dict [str , nn . Module ],
374375 calibration_config : Optional [Dict [str , Any ]] = None ,
375376 ** kwargs ,
376377 ) -> None :
377378 """
378- Perform post-training quantization on the exported program .
379+ Perform post-training quantization on the model .
379380
380381 Args:
381- exported_program_data : Dictionary containing exported programs
382+ models : Dictionary containing models to quantize
382383 calibration_config: Configuration containing example inputs for calibration
383384 **kwargs: Additional keyword arguments (not used)
384385 """
385386 # Store inputs
386- self ._exported_program_dict = exported_program_data [ "exported_program" ]
387+ self ._model_dict = models
387388
388389 # Initialize with empty dictionaries
389390 self ._example_inputs_dict = {}
@@ -392,7 +393,7 @@ def run(
392393 self ._example_inputs_dict = calibration_config .get ("example_inputs" , {})
393394
394395 # Process inputs
395- for method_name , exported_program in self ._exported_program_dict .items ():
396+ for method_name , model in self ._model_dict .items ():
396397 # Check if method_name exists in example_inputs and has at least one element
397398 if (
398399 method_name not in self ._example_inputs_dict
@@ -402,23 +403,21 @@ def run(
402403 f"Example inputs for method { method_name } not found or empty."
403404 )
404405
405- # Get the module from the exported program
406- model = exported_program .module ()
406+ # Export the model with its first example input to obtain a captured graph module
407+ inputs = self ._example_inputs_dict [method_name ][0 ]
408+ captured_graph = torch .export .export (model , inputs , strict = True ).module ()
407409
408410 # Prepare the model for quantization
409411 composed_quantizer = ComposableQuantizer (self ._quantizers )
410- prepared_model = prepare_pt2e (model , composed_quantizer ) # type: ignore
411-
412- # Allow the model to switch between train and eval modes
413- allow_exported_model_train_eval (prepared_model )
412+ prepared_model = prepare_pt2e (captured_graph , composed_quantizer ) # type: ignore
414413
415414 # Calibrate the model with the provided calibration data
416415 for calibration_input in self ._example_inputs_dict [method_name ]: # type: ignore
417416 prepared_model (* calibration_input )
418417
419418 # Convert the prepared model to a quantized model
420419 quantized_model = convert_pt2e (prepared_model )
421- self ._quantized_models [method_name ] = quantized_model # type: ignore
420+ self ._quantized_models [method_name ] = quantized_model
422421
423422 def get_artifacts (self ) -> Dict [str , nn .Module ]:
424423 """
@@ -541,29 +540,37 @@ def __init__(
541540 self ._artifact_dir = artifact_dir
542541 self ._export_recipe = export_recipe
543542
543+ self ._quant_recipe : Optional [QuantizationRecipe ] = (
544+ self ._export_recipe .quantization_recipe
545+ )
546+
544547 # Initialize pipeline as a list of stages
545548 self ._pipeline = []
546549
547550 # Create the source transform stage if the quantization recipe provides an ao_base_config
548- if self ._export_recipe . quantization_recipe is not None :
551+ if self ._quant_recipe is not None and self . _quant_recipe . ao_base_config :
549552 source_transform_stage = SourceTransformStage (
550553 quantization_recipe = self ._export_recipe .quantization_recipe
551554 )
552555 self ._pipeline .append (source_transform_stage )
553556
554- # Create the export stage
555- export_stage = ExportStage (
556- pre_edge_transform_passes = self ._export_recipe .pre_edge_transform_passes
557+ enable_quantize_stage = (
558+ self ._quant_recipe is not None and self ._quant_recipe .quantizers
557559 )
558- self ._pipeline .append (export_stage )
559560
560561 # Create the quantize stage if a quantizer is provided
561- if self . _export_recipe . quantization_recipe is not None :
562- quantizers = self . _export_recipe . quantization_recipe . get_quantizers ()
563- if quantizers is not None :
562+ if enable_quantize_stage :
563+ # pyre-ignore
564+ if quantizers := self . _quant_recipe . quantizers :
564565 quantize_stage = QuantizeStage (quantizers = quantizers )
565566 self ._pipeline .append (quantize_stage )
566567
568+ # Create the export stage
569+ export_stage = ExportStage (
570+ pre_edge_transform_passes = self ._export_recipe .pre_edge_transform_passes ,
571+ )
572+ self ._pipeline .append (export_stage )
573+
567574 # Create the edge transform and lower stage
568575 edge_transform_and_lower_stage = EdgeTransformAndLowerStage (
569576 partitioners = self ._export_recipe .partitioners ,
@@ -597,16 +604,16 @@ def _run_pipeline(self) -> None:
597604 # Process each stage in the pipeline
598605 for stage in self ._pipeline :
599606 stage_name = stage .name
607+ logging .info (f"Executing stage: { stage_name } " )
600608 # Configure inputs for the current stage
601609 if stage_name == "source_transform" :
602610 # Run the source transform stage
603611 stage .run (self ._model , {})
604612 self ._model = stage .get_artifacts ()
605613 elif stage_name == "quantize" :
606614 # Run the quantize stage
607- exported_program_data = {"exported_program" : self ._exported_program }
608615 config_params = {"example_inputs" : self ._example_inputs }
609- stage .run (exported_program_data , config_params )
616+ stage .run (self . _model , config_params )
610617 self ._model = stage .get_artifacts ()
611618 elif stage_name == "export" :
612619 # Run the export stage
0 commit comments