19
19
from executorch .examples .xnnpack import MODEL_NAME_TO_OPTIONS , QuantType
20
20
from executorch .exir .schema import DelegateCall , Program
21
21
from executorch .export import export , ExportRecipe , recipe_registry
22
+ from export .types import StageType
22
23
from torch import nn
23
24
from torch .testing ._internal .common_quantization import TestHelperModules
25
+ from torchao .quantization .utils import compute_error
24
26
25
27
26
28
class TestXnnpackRecipes (unittest .TestCase ):
@@ -38,6 +40,29 @@ def check_fully_delegated(self, program: Program) -> None:
38
40
self .assertEqual (len (instructions ), 1 )
39
41
self .assertIsInstance (instructions [0 ].instr_args , DelegateCall )
40
42
43
+ # pyre-ignore
44
+ def _compare_eager_quantized_model_outputs (
45
+ self , session , example_inputs , atol : float
46
+ ) -> None :
47
+ """Utility to compare eager quantized model output with session output after xnnpack lowering"""
48
+ torch_export_stage_output = session .get_stage_artifacts ()[
49
+ StageType .TORCH_EXPORT
50
+ ]
51
+ eager_quantized_model = torch_export_stage_output .data ["forward" ].module ()
52
+ output = session .run_method ("forward" , example_inputs [0 ])[0 ]
53
+ expected = eager_quantized_model (* example_inputs [0 ])
54
+ Tester ._assert_outputs_equal (output , expected , atol = atol )
55
+
56
+ def _compare_eager_unquantized_model_outputs (
57
+ self , session , eager_unquantized_model , example_inputs , sqnr_threshold = 20
58
+ ):
59
+ """Utility to compare eager unquantized model output with session output using SQNR"""
60
+ quantized_output = session .run_method ("forward" , example_inputs [0 ])[0 ]
61
+ original_output = eager_unquantized_model (* example_inputs [0 ])
62
+ error = compute_error (original_output , quantized_output )
63
+ print (f"{ self ._testMethodName } - SQNR: { error } dB" )
64
+ self .assertTrue (error > sqnr_threshold )
65
+
41
66
def test_basic_recipe (self ) -> None :
42
67
m_eager = TestHelperModules .TwoLinearModule ().eval ()
43
68
example_inputs = [(torch .randn (9 , 8 ),)]
@@ -46,18 +71,13 @@ def test_basic_recipe(self) -> None:
46
71
example_inputs = example_inputs ,
47
72
export_recipe = ExportRecipe .get_recipe (XNNPackRecipeType .FP32 ),
48
73
)
49
- self .assertTrue (
50
- torch .allclose (
51
- session .run_method ("forward" , example_inputs [0 ])[0 ],
52
- m_eager (* example_inputs [0 ]),
53
- atol = 1e-3 ,
54
- )
55
- )
74
+ self ._compare_eager_quantized_model_outputs (session , example_inputs , 1e-3 )
56
75
self .check_fully_delegated (session .get_executorch_program ())
76
+ self ._compare_eager_unquantized_model_outputs (session , m_eager , example_inputs )
57
77
58
78
def test_int8_dynamic_quant_recipe (self ) -> None :
59
79
test_cases = [
60
- ExportRecipe .get_recipe (XNNPackRecipeType .INT8_DYNAMIC_PER_CHANNEL ),
80
+ ExportRecipe .get_recipe (XNNPackRecipeType .PT2E_INT8_DYNAMIC_PER_CHANNEL ),
61
81
]
62
82
63
83
for export_recipe in test_cases :
@@ -70,19 +90,18 @@ def test_int8_dynamic_quant_recipe(self) -> None:
70
90
example_inputs = example_inputs ,
71
91
export_recipe = export_recipe ,
72
92
)
73
- self .assertTrue (
74
- torch .allclose (
75
- session .run_method ("forward" , example_inputs [0 ])[0 ],
76
- m_eager (* example_inputs [0 ]),
77
- atol = 1e-1 ,
78
- )
93
+ self ._compare_eager_quantized_model_outputs (
94
+ session , example_inputs , 1e-1
79
95
)
80
96
self .check_fully_delegated (session .get_executorch_program ())
97
+ self ._compare_eager_unquantized_model_outputs (
98
+ session , m_eager , example_inputs
99
+ )
81
100
82
101
def test_int8_static_quant_recipe (self ) -> None :
83
102
test_cases = [
84
- ExportRecipe .get_recipe (XNNPackRecipeType .INT8_STATIC_PER_CHANNEL ),
85
- ExportRecipe .get_recipe (XNNPackRecipeType .INT8_STATIC_PER_TENSOR ),
103
+ ExportRecipe .get_recipe (XNNPackRecipeType .PT2E_INT8_STATIC_PER_CHANNEL ),
104
+ ExportRecipe .get_recipe (XNNPackRecipeType .PT2E_INT8_STATIC_PER_TENSOR ),
86
105
]
87
106
88
107
for export_recipe in test_cases :
@@ -95,14 +114,13 @@ def test_int8_static_quant_recipe(self) -> None:
95
114
example_inputs = example_inputs ,
96
115
export_recipe = export_recipe ,
97
116
)
98
- self .assertTrue (
99
- torch .allclose (
100
- session .run_method ("forward" , example_inputs [0 ])[0 ],
101
- m_eager (* example_inputs [0 ]),
102
- atol = 1e-1 ,
103
- )
117
+ self ._compare_eager_quantized_model_outputs (
118
+ session , example_inputs , 1e-2
104
119
)
105
120
self .check_fully_delegated (session .get_executorch_program ())
121
+ self ._compare_eager_unquantized_model_outputs (
122
+ session , m_eager , example_inputs
123
+ )
106
124
107
125
def test_8a4w_recipe (self ) -> None :
108
126
class SimpleLinearModel (nn .Module ):
@@ -116,10 +134,10 @@ def forward(self, x) -> torch.Tensor:
116
134
117
135
test_cases = [
118
136
ExportRecipe .get_recipe (
119
- XNNPackRecipeType .INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_CHANNEL ,
137
+ XNNPackRecipeType .TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_CHANNEL ,
120
138
),
121
139
ExportRecipe .get_recipe (
122
- XNNPackRecipeType .INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR ,
140
+ XNNPackRecipeType .TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR ,
123
141
group_size = 32 ,
124
142
),
125
143
]
@@ -133,23 +151,22 @@ def forward(self, x) -> torch.Tensor:
133
151
example_inputs = example_inputs ,
134
152
export_recipe = export_recipe ,
135
153
)
136
- self .assertTrue (
137
- torch .allclose (
138
- session .run_method ("forward" , example_inputs [0 ])[0 ],
139
- model (* example_inputs [0 ]),
140
- atol = 1e-2 ,
141
- )
142
- )
143
154
self .check_fully_delegated (session .get_executorch_program ())
155
+ self ._compare_eager_quantized_model_outputs (
156
+ session , example_inputs , 1e-3
157
+ )
158
+ self ._compare_eager_unquantized_model_outputs (
159
+ session , model , example_inputs , sqnr_threshold = 15
160
+ )
144
161
145
162
def _get_recipe_for_quant_type (self , quant_type : QuantType ) -> XNNPackRecipeType :
146
163
# Map QuantType to corresponding recipe name.
147
164
if quant_type == QuantType .STATIC_PER_CHANNEL :
148
- return XNNPackRecipeType .INT8_STATIC_PER_CHANNEL
165
+ return XNNPackRecipeType .PT2E_INT8_STATIC_PER_CHANNEL
149
166
elif quant_type == QuantType .DYNAMIC_PER_CHANNEL :
150
- return XNNPackRecipeType .INT8_DYNAMIC_PER_CHANNEL
167
+ return XNNPackRecipeType .PT2E_INT8_DYNAMIC_PER_CHANNEL
151
168
elif quant_type == QuantType .STATIC_PER_TENSOR :
152
- return XNNPackRecipeType .INT8_STATIC_PER_TENSOR
169
+ return XNNPackRecipeType .PT2E_INT8_STATIC_PER_TENSOR
153
170
elif quant_type == QuantType .NONE :
154
171
return XNNPackRecipeType .FP32
155
172
else :
@@ -224,12 +241,13 @@ def test_validate_recipe_kwargs_int4_tensor_with_valid_group_size(
224
241
225
242
# Should not raise any exception
226
243
recipe_w_default_group = provider .create_recipe (
227
- XNNPackRecipeType .INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR
244
+ XNNPackRecipeType .TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR
228
245
)
229
246
self .assertIsNotNone (recipe_w_default_group )
230
247
231
248
recipe = provider .create_recipe (
232
- XNNPackRecipeType .INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR , group_size = 64
249
+ XNNPackRecipeType .TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR ,
250
+ group_size = 64 ,
233
251
)
234
252
self .assertIsNotNone (recipe )
235
253
@@ -240,7 +258,7 @@ def test_validate_recipe_kwargs_int4_tensor_with_invalid_group_size(
240
258
241
259
with self .assertRaises (ValueError ) as cm :
242
260
provider .create_recipe (
243
- XNNPackRecipeType .INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR ,
261
+ XNNPackRecipeType .TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR ,
244
262
group_size = "32" , # String instead of int
245
263
)
246
264
0 commit comments