44# This source code is licensed under the BSD-style license found in the
55# LICENSE file in the root directory of this source tree.
66
7+ import functools
78from typing import Any , List , Optional , Sequence , Tuple
89
910import coremltools as ct
1011import executorch
1112import executorch .backends .test .harness .stages as BaseStages
12- import functools
1313import torch
1414
1515from executorch .backends .apple .coreml .compiler import CoreMLBackend
2121from executorch .exir .backend .partitioner import Partitioner
2222
2323
24- def _get_static_int8_qconfig ():
24+ def _get_static_int8_linear_qconfig ():
2525 return ct .optimize .torch .quantization .LinearQuantizerConfig (
2626 global_config = ct .optimize .torch .quantization .ModuleLinearQuantizerConfig (
2727 quantization_scheme = "symmetric" ,
@@ -42,22 +42,23 @@ def __init__(
4242 is_qat : Optional [bool ] = False ,
4343 ):
4444 super ().__init__ (
45- quantizer = quantizer or CoreMLQuantizer (quantization_config or _get_static_int8_qconfig ()),
45+ quantizer = quantizer
46+ or CoreMLQuantizer (quantization_config or _get_static_int8_linear_qconfig ()),
4647 calibrate = calibrate ,
4748 calibration_samples = calibration_samples ,
4849 is_qat = is_qat ,
4950 )
5051
5152
52-
5353class Partition (BaseStages .Partition ):
5454 def __init__ (
55- self ,
55+ self ,
5656 partitioner : Optional [Partitioner ] = None ,
5757 minimum_deployment_target : Optional [Any ] = ct .target .iOS15 ,
5858 ):
5959 super ().__init__ (
60- partitioner = partitioner or CoreMLPartitioner (
60+ partitioner = partitioner
61+ or CoreMLPartitioner (
6162 compile_specs = CoreMLBackend .generate_compile_specs (
6263 minimum_deployment_target = minimum_deployment_target
6364 )
@@ -74,9 +75,9 @@ def __init__(
7475 ):
7576 super ().__init__ (
7677 default_partitioner_cls = lambda : CoreMLPartitioner (
77- compile_specs = CoreMLBackend .generate_compile_specs (
78+ compile_specs = CoreMLBackend .generate_compile_specs (
7879 minimum_deployment_target = minimum_deployment_target
79- )
80+ )
8081 ),
8182 partitioners = partitioners ,
8283 edge_compile_config = edge_compile_config ,
@@ -96,8 +97,13 @@ def __init__(
9697 executorch .backends .test .harness .Tester .default_stage_classes ()
9798 | {
9899 StageType .QUANTIZE : Quantize ,
99- StageType .PARTITION : functools .partial (Partition , minimum_deployment_target = minimum_deployment_target ),
100- StageType .TO_EDGE_TRANSFORM_AND_LOWER : functools .partial (ToEdgeTransformAndLower , minimum_deployment_target = minimum_deployment_target ),
100+ StageType .PARTITION : functools .partial (
101+ Partition , minimum_deployment_target = minimum_deployment_target
102+ ),
103+ StageType .TO_EDGE_TRANSFORM_AND_LOWER : functools .partial (
104+ ToEdgeTransformAndLower ,
105+ minimum_deployment_target = minimum_deployment_target ,
106+ ),
101107 }
102108 )
103109
0 commit comments