44# This source code is licensed under the BSD-style license found in the
55# LICENSE file in the root directory of this source tree.
66
7- from typing import Any , List , Optional , Tuple
7+ from typing import Any , List , Optional , Sequence , Tuple
88
99import executorch
1010import executorch .backends .test .harness .stages as BaseStages
1111
1212import torch
1313from executorch .backends .qualcomm ._passes .qnn_pass_manager import QnnPassManager
1414from executorch .backends .qualcomm .partition .qnn_partitioner import QnnPartitioner
15+ from executorch .backends .qualcomm .quantizer .quantizer import QnnQuantizer
1516from executorch .backends .qualcomm .utils .utils import (
1617 generate_htp_compiler_spec ,
1718 generate_qnn_executorch_compiler_spec ,
2425from torch .export import ExportedProgram
2526
2627
class Quantize(BaseStages.Quantize):
    """Quantize stage specialized for the Qualcomm (QNN) test flow.

    Thin wrapper around ``BaseStages.Quantize`` that requires a
    ``QnnQuantizer`` and always invokes the base constructor with
    ``set_global=False``.
    """

    def __init__(
        self,
        quantizer: QnnQuantizer,
        quantization_config: Optional[Any] = None,
        calibrate: bool = True,
        calibration_samples: Optional[Sequence[Any]] = None,
        is_qat: Optional[bool] = False,
    ):
        """Forward the quantization settings to the base stage.

        Args:
            quantizer: The ``QnnQuantizer`` instance handed to the base stage.
            quantization_config: Accepted for signature compatibility; it is
                not forwarded to the base constructor.
            calibrate: Passed through unchanged to the base stage.
            calibration_samples: Passed through unchanged to the base stage.
            is_qat: Passed through unchanged to the base stage.
        """
        # NOTE(review): `quantization_config` is intentionally left out of the
        # forwarded arguments, exactly as before. Presumably the QnnQuantizer
        # arrives already configured, which would also explain
        # `set_global=False` -- confirm against BaseStages.Quantize.
        base_kwargs = {
            "quantizer": quantizer,
            "calibrate": calibrate,
            "calibration_samples": calibration_samples,
            "is_qat": is_qat,
            "set_global": False,
        }
        super().__init__(**base_kwargs)
44+
45+
2746class Partition (BaseStages .Partition ):
2847 def __init__ (self , partitioner : Optional [Partitioner ] = None ):
2948 super ().__init__ (
@@ -37,8 +56,9 @@ def __init__(
3756 partitioners : Optional [List [Partitioner ]] = None ,
3857 edge_compile_config : Optional [EdgeCompileConfig ] = None ,
3958 soc_model : str = "SM8650" ,
59+ use_fp16 : bool = True ,
4060 ):
41- backend_options = generate_htp_compiler_spec (use_fp16 = True )
61+ backend_options = generate_htp_compiler_spec (use_fp16 = use_fp16 )
4262 self .chipset = get_soc_to_chipset_map ()[soc_model ]
4363 self .compiler_specs = generate_qnn_executorch_compiler_spec (
4464 soc_model = self .chipset ,
@@ -73,15 +93,17 @@ def __init__(
7393 module : torch .nn .Module ,
7494 example_inputs : Tuple [torch .Tensor ],
7595 dynamic_shapes : Optional [Tuple [Any ]] = None ,
96+ use_fp16 : bool = True ,
7697 ):
98+ def create_to_edge_transform_and_lower (* args , ** kwargs ):
99+ kwargs ["use_fp16" ] = use_fp16
100+ return ToEdgeTransformAndLower (* args , ** kwargs )
101+
77102 # Specialize for Qualcomm
78- stage_classes = (
79- executorch .backends .test .harness .Tester .default_stage_classes ()
80- | {
81- StageType .PARTITION : Partition ,
82- StageType .TO_EDGE_TRANSFORM_AND_LOWER : ToEdgeTransformAndLower ,
83- }
84- )
103+ stage_classes = executorch .backends .test .harness .Tester .default_stage_classes () | {
104+ StageType .PARTITION : Partition ,
105+ StageType .TO_EDGE_TRANSFORM_AND_LOWER : create_to_edge_transform_and_lower ,
106+ }
85107
86108 super ().__init__ (
87109 module = module ,
0 commit comments