4
4
# This source code is licensed under the BSD-style license found in the
5
5
# LICENSE file in the root directory of this source tree.
6
6
7
+ import logging
7
8
from abc import ABC , abstractmethod
8
9
from typing import Any , Callable , Dict , List , Optional , Sequence , Tuple , Union
9
10
18
19
)
19
20
from executorch .exir .program ._program import _transform
20
21
from executorch .exir .schema import Program
22
+ from executorch .export .recipe import QuantizationRecipe
21
23
from executorch .extension .export_util .utils import save_pte_program
22
24
from executorch .runtime import Runtime , Verification
23
25
from tabulate import tabulate
26
28
from torch ._export .pass_base import PassType
27
29
from torch .export import ExportedProgram
28
30
from torchao .quantization import quantize_
29
- from torchao .quantization .pt2e import allow_exported_model_train_eval
30
31
from torchao .quantization .pt2e .quantize_pt2e import convert_pt2e , prepare_pt2e
31
32
32
33
from torchao .quantization .pt2e .quantizer import ComposableQuantizer
@@ -360,8 +361,8 @@ class QuantizeStage(Stage):
360
361
def __init__ (self , quantizers : Any ) -> None :
361
362
self ._quantizers = quantizers
362
363
self ._quantized_models : Dict [str , nn .Module ] = {}
364
+ self ._exported_programs : Dict [str , ExportedProgram ] = {}
363
365
self ._model_dict : Dict [str , nn .Module ] = {}
364
- self ._exported_program_dict : Dict [str , ExportedProgram ] = {}
365
366
self ._example_inputs_dict : Dict [str , List [tuple [torch .Tensor , ...]]] = {}
366
367
367
368
@property
@@ -370,20 +371,20 @@ def name(self) -> str:
370
371
371
372
def run (
372
373
self ,
373
- exported_program_data : Dict [str , Any ],
374
+ models : Dict [str , nn . Module ],
374
375
calibration_config : Optional [Dict [str , Any ]] = None ,
375
376
** kwargs ,
376
377
) -> None :
377
378
"""
378
- Perform post-training quantization on the exported program .
379
+ Perform post-training quantization on the model .
379
380
380
381
Args:
381
- exported_program_data : Dictionary containing exported programs
382
+ models : Dictionary mapping each method name to the model to quantize
382
383
calibration_config: Configuration containing example inputs for calibration
383
384
**kwargs: Additional keyword arguments (not used)
384
385
"""
385
386
# Store inputs
386
- self ._exported_program_dict = exported_program_data [ "exported_program" ]
387
+ self ._model_dict = models
387
388
388
389
# Initialize with empty dictionaries
389
390
self ._example_inputs_dict = {}
@@ -392,7 +393,7 @@ def run(
392
393
self ._example_inputs_dict = calibration_config .get ("example_inputs" , {})
393
394
394
395
# Process inputs
395
- for method_name , exported_program in self ._exported_program_dict .items ():
396
+ for method_name , model in self ._model_dict .items ():
396
397
# Check if method_name exists in example_inputs and has at least one element
397
398
if (
398
399
method_name not in self ._example_inputs_dict
@@ -402,23 +403,21 @@ def run(
402
403
f"Example inputs for method { method_name } not found or empty."
403
404
)
404
405
405
- # Get the module from the exported program
406
- model = exported_program .module ()
406
+ # Export the model (strict mode) with the first example input to get a captured graph module
407
+ inputs = self ._example_inputs_dict [method_name ][0 ]
408
+ captured_graph = torch .export .export (model , inputs , strict = True ).module ()
407
409
408
410
# Prepare the model for quantization
409
411
composed_quantizer = ComposableQuantizer (self ._quantizers )
410
- prepared_model = prepare_pt2e (model , composed_quantizer ) # type: ignore
411
-
412
- # Allow the model to switch between train and eval modes
413
- allow_exported_model_train_eval (prepared_model )
412
+ prepared_model = prepare_pt2e (captured_graph , composed_quantizer ) # type: ignore
414
413
415
414
# Calibrate the model with the provided calibration data
416
415
for calibration_input in self ._example_inputs_dict [method_name ]: # type: ignore
417
416
prepared_model (* calibration_input )
418
417
419
418
# Convert the prepared model to a quantized model
420
419
quantized_model = convert_pt2e (prepared_model )
421
- self ._quantized_models [method_name ] = quantized_model # type: ignore
420
+ self ._quantized_models [method_name ] = quantized_model
422
421
423
422
def get_artifacts (self ) -> Dict [str , nn .Module ]:
424
423
"""
@@ -541,29 +540,37 @@ def __init__(
541
540
self ._artifact_dir = artifact_dir
542
541
self ._export_recipe = export_recipe
543
542
543
+ self ._quant_recipe : Optional [QuantizationRecipe ] = (
544
+ self ._export_recipe .quantization_recipe
545
+ )
546
+
544
547
# Initialize pipeline as a list of stages
545
548
self ._pipeline = []
546
549
547
550
# Create the source transform stage if the quantization recipe provides an ao_base_config
548
- if self ._export_recipe . quantization_recipe is not None :
551
+ if self ._quant_recipe is not None and self . _quant_recipe . ao_base_config :
549
552
source_transform_stage = SourceTransformStage (
550
553
quantization_recipe = self ._export_recipe .quantization_recipe
551
554
)
552
555
self ._pipeline .append (source_transform_stage )
553
556
554
- # Create the export stage
555
- export_stage = ExportStage (
556
- pre_edge_transform_passes = self ._export_recipe .pre_edge_transform_passes
557
+ enable_quantize_stage = (
558
+ self ._quant_recipe is not None and self ._quant_recipe .quantizers
557
559
)
558
- self ._pipeline .append (export_stage )
559
560
560
561
# Create the quantize stage if a quantizer is provided
561
- if self . _export_recipe . quantization_recipe is not None :
562
- quantizers = self . _export_recipe . quantization_recipe . get_quantizers ()
563
- if quantizers is not None :
562
+ if enable_quantize_stage :
563
+ # pyre-ignore
564
+ if quantizers := self . _quant_recipe . quantizers :
564
565
quantize_stage = QuantizeStage (quantizers = quantizers )
565
566
self ._pipeline .append (quantize_stage )
566
567
568
+ # Create the export stage
569
+ export_stage = ExportStage (
570
+ pre_edge_transform_passes = self ._export_recipe .pre_edge_transform_passes ,
571
+ )
572
+ self ._pipeline .append (export_stage )
573
+
567
574
# Create the edge transform and lower stage
568
575
edge_transform_and_lower_stage = EdgeTransformAndLowerStage (
569
576
partitioners = self ._export_recipe .partitioners ,
@@ -597,16 +604,16 @@ def _run_pipeline(self) -> None:
597
604
# Process each stage in the pipeline
598
605
for stage in self ._pipeline :
599
606
stage_name = stage .name
607
+ logging .info (f"Executing stage: { stage_name } " )
600
608
# Configure inputs for the current stage
601
609
if stage_name == "source_transform" :
602
610
# Run the source transform stage
603
611
stage .run (self ._model , {})
604
612
self ._model = stage .get_artifacts ()
605
613
elif stage_name == "quantize" :
606
614
# Run the quantize stage
607
- exported_program_data = {"exported_program" : self ._exported_program }
608
615
config_params = {"example_inputs" : self ._example_inputs }
609
- stage .run (exported_program_data , config_params )
616
+ stage .run (self . _model , config_params )
610
617
self ._model = stage .get_artifacts ()
611
618
elif stage_name == "export" :
612
619
# Run the export stage
0 commit comments