1313import contextlib
1414import logging
1515from enum import Enum
16- from typing import Any , Callable , Dict , List , Optional , Tuple
16+ from typing import Any , Callable , Dict , List , Optional , Tuple , Union
1717from unittest .mock import patch
1818
1919import torch
3535
3636from executorch .extension .llm .export .export_passes import RemoveRedundantTransposes
3737from pytorch_tokenizers import get_tokenizer
38- from torch .ao .quantization .quantizer import Quantizer
39- from torch .ao .quantization .quantizer .composable_quantizer import ComposableQuantizer
38+ from torch .ao .quantization .quantizer import TorchQuantizer
39+ from torch .ao .quantization .quantizer .composable_quantizer import (
40+ TorchComposableQuantizer ,
41+ )
42+
4043from torch .export import export_for_training , ExportedProgram
4144from torch .nn .attention import SDPBackend
4245from torchao .quantization .pt2e .quantize_pt2e import convert_pt2e , prepare_pt2e
46+ from torchao .quantization .pt2e .quantizer import ComposableQuantizer , Quantizer
4347from torchao .utils import unwrap_tensor_subclass
4448
4549FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
@@ -350,7 +354,9 @@ def calibrate_template(
350354 print (f"{ task } : { res } " )
351355 logging .info ("Calibration finish..." )
352356
353- def pt2e_quantize (self , quantizers : Optional [List [Quantizer ]]) -> "LLMEdgeManager" :
357+ def pt2e_quantize (
358+ self , quantizers : Optional [List [Union [Quantizer , TorchQuantizer ]]]
359+ ) -> "LLMEdgeManager" :
354360 """
355361 Quantize the model via pt2e flow and retrieve LLMEdgeManager including the quantized model.
356362 Args:
@@ -367,7 +373,12 @@ def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManage
367373 with torch .nn .attention .sdpa_kernel ([SDPBackend .MATH ]), torch .no_grad ():
368374 if self .verbose :
369375 logging .info (f"Applied quantizers: { quantizers } " )
370- composed_quantizer = ComposableQuantizer (quantizers )
376+
377+ if any (isinstance (q , Quantizer ) for q in quantizers ):
378+ composed_quantizer = ComposableQuantizer (quantizers )
379+ else :
380+ composed_quantizer = TorchComposableQuantizer (quantizers )
381+
371382 assert (
372383 self .pre_autograd_graph_module is not None
373384 ), "Please run export() first"
0 commit comments