5
5
# This source code is licensed under the BSD-style license found in the
6
6
# LICENSE file in the root directory of this source tree.
7
7
8
- import logging
9
8
import unittest
10
9
11
10
from typing import Tuple
12
11
13
12
import torch
13
+ from executorch .backends .arm .quantizer .arm_quantizer import (
14
+ ArmQuantizer ,
15
+ get_symmetric_quantization_config ,
16
+ )
14
17
from executorch .backends .arm .test import common
15
18
from executorch .backends .arm .test .tester .arm_tester import ArmTester
19
+ from executorch .backends .xnnpack .test .tester .tester import Quantize
16
20
from executorch .exir .backend .backend_details import CompileSpec
17
21
from parameterized import parameterized
18
22
19
- logger = logging .getLogger (__name__ )
20
- logger .setLevel (logging .INFO )
21
23
22
24
test_data_suite = [
23
25
# (test_name, test_data, [kernel_size, stride, padding])
@@ -69,13 +71,14 @@ def _test_avgpool2d_tosa_MI_pipeline(
69
71
def _test_avgpool2d_tosa_BI_pipeline (
70
72
self , module : torch .nn .Module , test_data : Tuple [torch .tensor ]
71
73
):
74
+ quantizer = ArmQuantizer ().set_io (get_symmetric_quantization_config ())
72
75
(
73
76
ArmTester (
74
77
module ,
75
78
example_inputs = test_data ,
76
79
compile_spec = common .get_tosa_compile_spec (permute_memory_to_nhwc = True ),
77
80
)
78
- .quantize ()
81
+ .quantize (Quantize ( quantizer , get_symmetric_quantization_config ()) )
79
82
.export ()
80
83
.check_count ({"torch.ops.aten.avg_pool2d.default" : 1 })
81
84
.check (["torch.ops.quantized_decomposed" ])
@@ -93,13 +96,14 @@ def _test_avgpool2d_tosa_ethos_BI_pipeline(
93
96
compile_spec : CompileSpec ,
94
97
test_data : Tuple [torch .tensor ],
95
98
):
99
+ quantizer = ArmQuantizer ().set_io (get_symmetric_quantization_config ())
96
100
(
97
101
ArmTester (
98
102
module ,
99
103
example_inputs = test_data ,
100
104
compile_spec = compile_spec ,
101
105
)
102
- .quantize ()
106
+ .quantize (Quantize ( quantizer , get_symmetric_quantization_config ()) )
103
107
.export ()
104
108
.check_count ({"torch.ops.aten.avg_pool2d.default" : 1 })
105
109
.check (["torch.ops.quantized_decomposed" ])
@@ -121,10 +125,7 @@ def test_avgpool2d_tosa_MI(
121
125
self .AvgPool2d (* model_params ), (test_data ,)
122
126
)
123
127
124
- # Expected to fail since ArmQuantizer cannot quantize a AvgPool2D layer
125
- # TODO(MLETORCH-93)
126
128
@parameterized .expand (test_data_suite )
127
- @unittest .expectedFailure
128
129
def test_avgpool2d_tosa_BI (
129
130
self ,
130
131
test_name : str ,
@@ -135,10 +136,7 @@ def test_avgpool2d_tosa_BI(
135
136
self .AvgPool2d (* model_params ), (test_data ,)
136
137
)
137
138
138
- # Expected to fail since ArmQuantizer cannot quantize a AvgPool2D layer
139
- # TODO(MLETORCH-93)
140
139
@parameterized .expand (test_data_suite )
141
- @unittest .expectedFailure
142
140
def test_avgpool2d_tosa_u55_BI (
143
141
self ,
144
142
test_name : str ,
@@ -152,7 +150,6 @@ def test_avgpool2d_tosa_u55_BI(
152
150
)
153
151
154
152
@parameterized .expand (test_data_suite )
155
- @unittest .expectedFailure
156
153
def test_avgpool2d_tosa_u85_BI (
157
154
self ,
158
155
test_name : str ,
0 commit comments