44# This source code is licensed under the BSD-style license found in the
55# LICENSE file in the root directory of this source tree.
66
7- from typing import Any , List , Optional , Tuple
7+ from typing import Any , List , Optional , Sequence , Tuple
88
99import executorch
1010import executorch .backends .test .harness .stages as BaseStages
1111
1212import torch
1313from executorch .backends .qualcomm ._passes .qnn_pass_manager import QnnPassManager
1414from executorch .backends .qualcomm .partition .qnn_partitioner import QnnPartitioner
15+ from executorch .backends .qualcomm .quantizer .quantizer import QnnQuantizer
1516from executorch .backends .qualcomm .utils .utils import (
1617 generate_htp_compiler_spec ,
1718 generate_qnn_executorch_compiler_spec ,
2122from executorch .backends .test .harness .stages import StageType
2223from executorch .exir import EdgeCompileConfig , to_edge_transform_and_lower
2324from executorch .exir .backend .partitioner import Partitioner
25+ from torch .ao .quantization .quantize_pt2e import (
26+ convert_pt2e ,
27+ prepare_pt2e ,
28+ prepare_qat_pt2e ,
29+ )
2430from torch .export import ExportedProgram
2531
2632
class Quantize(BaseStages.Quantize):
    """Quantize stage of the test harness, specialised for a ``QnnQuantizer``.

    This subclass only customises construction: it forwards its arguments to
    ``BaseStages.Quantize`` and always passes ``set_global=False``.
    """

    def __init__(
        self,
        quantizer: QnnQuantizer,
        quantization_config: Optional[Any] = None,
        calibrate: bool = True,
        calibration_samples: Optional[Sequence[Any]] = None,
        is_qat: Optional[bool] = False,
    ):
        """Initialise the stage by delegating to the base ``Quantize`` stage.

        Args:
            quantizer: The QNN quantizer handed to the base stage.
            quantization_config: Accepted but not forwarded to the base
                initializer.
                NOTE(review): presumably kept for signature compatibility
                with the base stage -- confirm that dropping it is intended.
            calibrate: Forwarded unchanged to the base stage.
            calibration_samples: Forwarded unchanged to the base stage.
            is_qat: Forwarded unchanged to the base stage.
        """
        # Collect the forwarded arguments in one place; ``set_global`` is
        # pinned to False for this backend rather than exposed to callers.
        base_kwargs = {
            "quantizer": quantizer,
            "calibrate": calibrate,
            "calibration_samples": calibration_samples,
            "is_qat": is_qat,
            "set_global": False,
        }
        super().__init__(**base_kwargs)
49+
50+
2751class Partition (BaseStages .Partition ):
2852 def __init__ (self , partitioner : Optional [Partitioner ] = None ):
2953 super ().__init__ (
@@ -37,8 +61,9 @@ def __init__(
3761 partitioners : Optional [List [Partitioner ]] = None ,
3862 edge_compile_config : Optional [EdgeCompileConfig ] = None ,
3963 soc_model : str = "SM8650" ,
64+ use_fp16 : bool = True ,
4065 ):
41- backend_options = generate_htp_compiler_spec (use_fp16 = True )
66+ backend_options = generate_htp_compiler_spec (use_fp16 = use_fp16 )
4267 self .chipset = get_soc_to_chipset_map ()[soc_model ]
4368 self .compiler_specs = generate_qnn_executorch_compiler_spec (
4469 soc_model = self .chipset ,
@@ -73,15 +98,17 @@ def __init__(
7398 module : torch .nn .Module ,
7499 example_inputs : Tuple [torch .Tensor ],
75100 dynamic_shapes : Optional [Tuple [Any ]] = None ,
101+ use_fp16 : bool = True ,
76102 ):
103+ def create_to_edge_transform_and_lower (* args , ** kwargs ):
104+ kwargs ["use_fp16" ] = use_fp16
105+ return ToEdgeTransformAndLower (* args , ** kwargs )
106+
77107 # Specialize for Qualcomm
78- stage_classes = (
79- executorch .backends .test .harness .Tester .default_stage_classes ()
80- | {
81- StageType .PARTITION : Partition ,
82- StageType .TO_EDGE_TRANSFORM_AND_LOWER : ToEdgeTransformAndLower ,
83- }
84- )
108+ stage_classes = executorch .backends .test .harness .Tester .default_stage_classes () | {
109+ StageType .PARTITION : Partition ,
110+ StageType .TO_EDGE_TRANSFORM_AND_LOWER : create_to_edge_transform_and_lower ,
111+ }
85112
86113 super ().__init__ (
87114 module = module ,
0 commit comments