4
4
# This source code is licensed under the BSD-style license found in the
5
5
# LICENSE file in the root directory of this source tree.
6
6
7
- from typing import Any , List , Optional , Tuple
7
+ from typing import Any , List , Optional , Sequence , Tuple
8
8
9
9
import executorch
10
10
import executorch .backends .test .harness .stages as BaseStages
11
11
12
12
import torch
13
13
from executorch .backends .qualcomm ._passes .qnn_pass_manager import QnnPassManager
14
14
from executorch .backends .qualcomm .partition .qnn_partitioner import QnnPartitioner
15
+ from executorch .backends .qualcomm .quantizer .quantizer import QnnQuantizer
15
16
from executorch .backends .qualcomm .utils .utils import (
16
17
generate_htp_compiler_spec ,
17
18
generate_qnn_executorch_compiler_spec ,
21
22
from executorch .backends .test .harness .stages import StageType
22
23
from executorch .exir import EdgeCompileConfig , to_edge_transform_and_lower
23
24
from executorch .exir .backend .partitioner import Partitioner
25
+ from torch .ao .quantization .quantize_pt2e import (
26
+ convert_pt2e ,
27
+ prepare_pt2e ,
28
+ prepare_qat_pt2e ,
29
+ )
24
30
from torch .export import ExportedProgram
25
31
26
32
33
class Quantize(BaseStages.Quantize):
    """Quantize test stage configured for the Qualcomm ``QnnQuantizer``.

    Delegates construction to ``BaseStages.Quantize`` and always passes
    ``set_global=False`` to it.
    """

    def __init__(
        self,
        quantizer: QnnQuantizer,
        quantization_config: Optional[Any] = None,
        calibrate: bool = True,
        calibration_samples: Optional[Sequence[Any]] = None,
        is_qat: Optional[bool] = False,
    ):
        """Initialise the stage by forwarding the settings to the base class.

        Args:
            quantizer: Quantizer instance handed to the base stage.
            quantization_config: Accepted by this constructor but not
                passed on to the base class.
            calibrate: Forwarded unchanged to the base stage.
            calibration_samples: Forwarded unchanged to the base stage.
            is_qat: Forwarded unchanged to the base stage.
        """
        # NOTE(review): ``quantization_config`` is intentionally left out of
        # the forwarded arguments to mirror the existing behaviour -- confirm
        # with the base class whether it should be passed through.
        base_kwargs = {
            "quantizer": quantizer,
            "calibrate": calibrate,
            "calibration_samples": calibration_samples,
            "is_qat": is_qat,
            "set_global": False,
        }
        super().__init__(**base_kwargs)
49
+
50
+
27
51
class Partition (BaseStages .Partition ):
28
52
def __init__ (self , partitioner : Optional [Partitioner ] = None ):
29
53
super ().__init__ (
@@ -37,8 +61,9 @@ def __init__(
37
61
partitioners : Optional [List [Partitioner ]] = None ,
38
62
edge_compile_config : Optional [EdgeCompileConfig ] = None ,
39
63
soc_model : str = "SM8650" ,
64
+ use_fp16 : bool = True ,
40
65
):
41
- backend_options = generate_htp_compiler_spec (use_fp16 = True )
66
+ backend_options = generate_htp_compiler_spec (use_fp16 = use_fp16 )
42
67
self .chipset = get_soc_to_chipset_map ()[soc_model ]
43
68
self .compiler_specs = generate_qnn_executorch_compiler_spec (
44
69
soc_model = self .chipset ,
@@ -73,15 +98,17 @@ def __init__(
73
98
module : torch .nn .Module ,
74
99
example_inputs : Tuple [torch .Tensor ],
75
100
dynamic_shapes : Optional [Tuple [Any ]] = None ,
101
+ use_fp16 : bool = True ,
76
102
):
103
+ def create_to_edge_transform_and_lower (* args , ** kwargs ):
104
+ kwargs ["use_fp16" ] = use_fp16
105
+ return ToEdgeTransformAndLower (* args , ** kwargs )
106
+
77
107
# Specialize for Qualcomm
78
- stage_classes = (
79
- executorch .backends .test .harness .Tester .default_stage_classes ()
80
- | {
81
- StageType .PARTITION : Partition ,
82
- StageType .TO_EDGE_TRANSFORM_AND_LOWER : ToEdgeTransformAndLower ,
83
- }
84
- )
108
+ stage_classes = executorch .backends .test .harness .Tester .default_stage_classes () | {
109
+ StageType .PARTITION : Partition ,
110
+ StageType .TO_EDGE_TRANSFORM_AND_LOWER : create_to_edge_transform_and_lower ,
111
+ }
85
112
86
113
super ().__init__ (
87
114
module = module ,
0 commit comments