44# This source code is licensed under the BSD-style license found in the
55# LICENSE file in the root directory of this source tree.
66
7- from typing import Any , List , Optional , Tuple
7+ import functools
8+ from typing import Any , List , Optional , Sequence , Tuple
89
10+ import coremltools as ct
911import executorch
1012import executorch .backends .test .harness .stages as BaseStages
11-
1213import torch
14+
15+ from executorch .backends .apple .coreml .compiler import CoreMLBackend
1316from executorch .backends .apple .coreml .partition import CoreMLPartitioner
17+ from executorch .backends .apple .coreml .quantizer import CoreMLQuantizer
1418from executorch .backends .test .harness import Tester as TesterBase
1519from executorch .backends .test .harness .stages import StageType
1620from executorch .exir import EdgeCompileConfig
1721from executorch .exir .backend .partitioner import Partitioner
1822
1923
def _create_default_partitioner(
    minimum_deployment_target: Any = ct.target.iOS15,
) -> CoreMLPartitioner:
    """Build a ``CoreMLPartitioner`` for the given deployment target.

    Args:
        minimum_deployment_target: Value forwarded unchanged to
            ``CoreMLBackend.generate_compile_specs``. Defaults to
            ``ct.target.iOS15``.

    Returns:
        A ``CoreMLPartitioner`` constructed with the generated compile specs.
    """
    specs = CoreMLBackend.generate_compile_specs(
        minimum_deployment_target=minimum_deployment_target
    )
    return CoreMLPartitioner(compile_specs=specs)
32+
33+
def _get_static_int8_linear_qconfig():
    """Return the linear quantizer config used as the default for ``Quantize``.

    The global module config uses the "symmetric" quantization scheme,
    ``torch.quint8`` activations, ``torch.qint8`` weights, and per-channel
    weight quantization.

    Returns:
        A ``LinearQuantizerConfig`` from ``ct.optimize.torch.quantization``.
    """
    quantization = ct.optimize.torch.quantization
    # Build the per-module settings first, then wrap them as the global config.
    module_config = quantization.ModuleLinearQuantizerConfig(
        quantization_scheme="symmetric",
        activation_dtype=torch.quint8,
        weight_dtype=torch.qint8,
        weight_per_channel=True,
    )
    return quantization.LinearQuantizerConfig(global_config=module_config)
43+
44+
class Quantize(BaseStages.Quantize):
    """Quantize stage that falls back to a ``CoreMLQuantizer`` by default."""

    def __init__(
        self,
        quantizer: Optional[CoreMLQuantizer] = None,
        quantization_config: Optional[Any] = None,
        calibrate: bool = True,
        calibration_samples: Optional[Sequence[Any]] = None,
        is_qat: Optional[bool] = False,
    ):
        """Initialize the stage, creating a quantizer when none is supplied.

        Args:
            quantizer: Quantizer to use. When falsy, a ``CoreMLQuantizer`` is
                constructed instead.
            quantization_config: Config passed to the fallback
                ``CoreMLQuantizer``. When falsy, the result of
                ``_get_static_int8_linear_qconfig()`` is used. Ignored when
                ``quantizer`` is supplied.
            calibrate: Forwarded to the base stage.
            calibration_samples: Forwarded to the base stage.
            is_qat: Forwarded to the base stage.
        """
        if not quantizer:
            # Only build the default config when the caller gave neither a
            # quantizer nor a config of their own.
            config = quantization_config or _get_static_int8_linear_qconfig()
            quantizer = CoreMLQuantizer(config)

        super().__init__(
            quantizer=quantizer,
            calibrate=calibrate,
            calibration_samples=calibration_samples,
            is_qat=is_qat,
        )
63+
64+
class Partition(BaseStages.Partition):
    """Partition stage that defaults to a CoreML partitioner."""

    def __init__(
        self,
        partitioner: Optional[Partitioner] = None,
        minimum_deployment_target: Optional[Any] = ct.target.iOS15,
    ):
        """Initialize the stage, creating a partitioner when none is supplied.

        Args:
            partitioner: Partitioner to use. When falsy, one is built with
                ``_create_default_partitioner``.
            minimum_deployment_target: Passed to
                ``_create_default_partitioner`` for the fallback partitioner.
                Ignored when ``partitioner`` is supplied.
        """
        if not partitioner:
            partitioner = _create_default_partitioner(minimum_deployment_target)

        super().__init__(partitioner=partitioner)
2575
2676
@@ -29,9 +79,12 @@ def __init__(
2979 self ,
3080 partitioners : Optional [List [Partitioner ]] = None ,
3181 edge_compile_config : Optional [EdgeCompileConfig ] = None ,
82+ minimum_deployment_target : Optional [Any ] = ct .target .iOS15 ,
3283 ):
3384 super ().__init__ (
34- default_partitioner_cls = CoreMLPartitioner ,
85+ default_partitioner_cls = lambda : _create_default_partitioner (
86+ minimum_deployment_target
87+ ),
3588 partitioners = partitioners ,
3689 edge_compile_config = edge_compile_config ,
3790 )
@@ -43,13 +96,20 @@ def __init__(
4396 module : torch .nn .Module ,
4497 example_inputs : Tuple [torch .Tensor ],
4598 dynamic_shapes : Optional [Tuple [Any ]] = None ,
99+ minimum_deployment_target : Optional [Any ] = ct .target .iOS15 ,
46100 ):
47101 # Specialize for CoreML
48102 stage_classes = (
49103 executorch .backends .test .harness .Tester .default_stage_classes ()
50104 | {
51- StageType .PARTITION : Partition ,
52- StageType .TO_EDGE_TRANSFORM_AND_LOWER : ToEdgeTransformAndLower ,
105+ StageType .QUANTIZE : Quantize ,
106+ StageType .PARTITION : functools .partial (
107+ Partition , minimum_deployment_target = minimum_deployment_target
108+ ),
109+ StageType .TO_EDGE_TRANSFORM_AND_LOWER : functools .partial (
110+ ToEdgeTransformAndLower ,
111+ minimum_deployment_target = minimum_deployment_target ,
112+ ),
53113 }
54114 )
55115
0 commit comments