4
4
# This source code is licensed under the BSD-style license found in the
5
5
# LICENSE file in the root directory of this source tree.
6
6
7
- from typing import Any , List , Optional , Tuple
7
+ from typing import Any , List , Optional , Sequence , Tuple
8
8
9
9
import executorch
10
10
import executorch .backends .test .harness .stages as BaseStages
11
11
12
12
import torch
13
13
from executorch .backends .qualcomm ._passes .qnn_pass_manager import QnnPassManager
14
14
from executorch .backends .qualcomm .partition .qnn_partitioner import QnnPartitioner
15
+ from executorch .backends .qualcomm .quantizer .quantizer import QnnQuantizer
15
16
from executorch .backends .qualcomm .utils .utils import (
16
17
generate_htp_compiler_spec ,
17
18
generate_qnn_executorch_compiler_spec ,
24
25
from torch .export import ExportedProgram
25
26
26
27
28
class Quantize(BaseStages.Quantize):
    """Quantize stage that wraps the base harness stage for a ``QnnQuantizer``.

    The constructor delegates to ``BaseStages.Quantize`` and always passes
    ``set_global=False``.
    """

    def __init__(
        self,
        quantizer: QnnQuantizer,
        quantization_config: Optional[Any] = None,
        calibrate: bool = True,
        calibration_samples: Optional[Sequence[Any]] = None,
        is_qat: Optional[bool] = False,
    ):
        """Initialize the stage by forwarding the arguments to the base class.

        Args:
            quantizer: The ``QnnQuantizer`` handed to the base stage.
            quantization_config: Accepted but not forwarded to the base stage.
                NOTE(review): presumably kept so the signature lines up with
                the base stage -- confirm.
            calibrate: Forwarded unchanged to the base stage.
            calibration_samples: Forwarded unchanged to the base stage.
            is_qat: Forwarded unchanged to the base stage.
        """
        base_stage_kwargs = {
            "quantizer": quantizer,
            "calibrate": calibrate,
            "calibration_samples": calibration_samples,
            "is_qat": is_qat,
            # Always disabled here; callers cannot override it through this class.
            "set_global": False,
        }
        super().__init__(**base_stage_kwargs)
44
+
45
+
27
46
class Partition (BaseStages .Partition ):
28
47
def __init__ (self , partitioner : Optional [Partitioner ] = None ):
29
48
super ().__init__ (
@@ -37,8 +56,9 @@ def __init__(
37
56
partitioners : Optional [List [Partitioner ]] = None ,
38
57
edge_compile_config : Optional [EdgeCompileConfig ] = None ,
39
58
soc_model : str = "SM8650" ,
59
+ use_fp16 : bool = True ,
40
60
):
41
- backend_options = generate_htp_compiler_spec (use_fp16 = True )
61
+ backend_options = generate_htp_compiler_spec (use_fp16 = use_fp16 )
42
62
self .chipset = get_soc_to_chipset_map ()[soc_model ]
43
63
self .compiler_specs = generate_qnn_executorch_compiler_spec (
44
64
soc_model = self .chipset ,
@@ -73,15 +93,17 @@ def __init__(
73
93
module : torch .nn .Module ,
74
94
example_inputs : Tuple [torch .Tensor ],
75
95
dynamic_shapes : Optional [Tuple [Any ]] = None ,
96
+ use_fp16 : bool = True ,
76
97
):
98
def create_to_edge_transform_and_lower(*args, **kwargs):
    """Build a ``ToEdgeTransformAndLower`` stage with ``use_fp16`` pinned.

    All positional and keyword arguments are passed through. The ``use_fp16``
    keyword is always set from the enclosing scope's ``use_fp16`` value,
    replacing any value supplied by the caller.
    """
    stage_kwargs = {**kwargs, "use_fp16": use_fp16}
    return ToEdgeTransformAndLower(*args, **stage_kwargs)
101
+
77
102
# Specialize for Qualcomm
78
- stage_classes = (
79
- executorch .backends .test .harness .Tester .default_stage_classes ()
80
- | {
81
- StageType .PARTITION : Partition ,
82
- StageType .TO_EDGE_TRANSFORM_AND_LOWER : ToEdgeTransformAndLower ,
83
- }
84
- )
103
+ stage_classes = executorch .backends .test .harness .Tester .default_stage_classes () | {
104
+ StageType .PARTITION : Partition ,
105
+ StageType .TO_EDGE_TRANSFORM_AND_LOWER : create_to_edge_transform_and_lower ,
106
+ }
85
107
86
108
super ().__init__ (
87
109
module = module ,
0 commit comments