@@ -9,15 +9,22 @@
 import unittest

 import torch
+import torch._export as export
 import torchvision
 from executorch import exir
 from executorch.exir import EdgeCompileConfig
 from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
 from executorch.exir.passes.spec_prop_pass import SpecPropPass
 from torch.ao.ns.fx.utils import compute_sqnr
 from torch.ao.quantization import get_default_qconfig, QConfigMapping  # @manual
+from torch.ao.quantization.backend_config import get_executorch_backend_config
+from torch.ao.quantization.qconfig import default_per_channel_symmetric_qnnpack_qconfig
 from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
-from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
+from torch.ao.quantization.quantize_pt2e import (
+    _convert_to_reference_decomposed_fx,
+    convert_pt2e,
+    prepare_pt2e,
+)

 from torch.ao.quantization.quantizer.xnnpack_quantizer import (
     get_symmetric_quantization_config,
@@ -37,7 +44,6 @@ class TestQuantization(unittest.TestCase):
     APIs for now, but we plan to open source them in the future
     """

-    @skipIfNoQNNPACK
     def test_resnet(self) -> None:
         import copy

@@ -47,13 +53,7 @@ def test_resnet(self) -> None:
         m = torchvision.models.resnet18().eval()  # pyre-ignore[16]
         m_copy = copy.deepcopy(m)
         # program capture
-        exported_program = exir.capture(m, example_inputs)
-        # TODO: probably need to support exported_program.to_aten()
-        m = exported_program.to_edge(
-            exir.EdgeCompileConfig(
-                _check_ir_validity=False,
-            ),
-        ).graph_module
+        m = export.capture_pre_autograd_graph(m, copy.deepcopy(example_inputs))

         quantizer = XNNPACKQuantizer()
         operator_config = get_symmetric_quantization_config(is_per_channel=True)
@@ -64,22 +64,21 @@ def test_resnet(self) -> None:
         )
         after_prepare_result = m(*example_inputs)[0]
         m = convert_pt2e(m)
-        after_quant_result = m(*example_inputs)[0]
+
         # TODO: conv, conv_relu, linear delegation
         # quantized ops to implement: add_relu
         compile_config = EdgeCompileConfig(
             passes=[QuantFusionPass(), SpecPropPass()],
             _check_ir_validity=False,
         )
         m = exir.capture(m, example_inputs).to_edge(config=compile_config)
+        after_quant_result = m(*example_inputs)[0]
         FileCheck().check(
             "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor"
-        ).check(
-            "executorch_exir_dialects_edge__ops_quantized_decomposed_add_relu_default"
         ).check(
             "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor"
         ).run(
-            m.code
+            m.exported_program.graph_module.code
         )
         # after_quant_fusion_result = m(*example_inputs)[0]

@@ -92,11 +91,16 @@ def test_resnet(self) -> None:
         # self.assertEqual(compute_sqnr(after_quant_fusion_result, after_to_executorch), torch.tensor(float("inf")))

         # comparing with existing fx graph mode quantization reference flow
-        qconfig = get_default_qconfig("qnnpack")
+        qconfig = default_per_channel_symmetric_qnnpack_qconfig
         qconfig_mapping = QConfigMapping().set_global(qconfig)
-        m_fx = prepare_fx(m_copy, qconfig_mapping, example_inputs)
+        backend_config = get_executorch_backend_config()
+        m_fx = prepare_fx(
+            m_copy, qconfig_mapping, example_inputs, backend_config=backend_config
+        )
         after_prepare_result_fx = m_fx(*example_inputs)
-        m_fx = convert_to_reference_fx(m_fx)  # , backend_config=backend_config)
+        m_fx = _convert_to_reference_decomposed_fx(
+            m_fx, backend_config=backend_config
+        )
         after_quant_result_fx = m_fx(*example_inputs)

         # the result matches exactly after prepare