19
19
from executorch .examples .xnnpack import MODEL_NAME_TO_OPTIONS , QuantType
20
20
from executorch .exir .schema import DelegateCall , Program
21
21
from executorch .export import export , ExportRecipe , recipe_registry
22
+ from export .types import StageType
22
23
from torch import nn
23
24
from torch .testing ._internal .common_quantization import TestHelperModules
25
+ from torchao .quantization .utils import compute_error
24
26
25
27
26
28
class TestXnnpackRecipes (unittest .TestCase ):
@@ -38,6 +40,28 @@ def check_fully_delegated(self, program: Program) -> None:
38
40
self .assertEqual (len (instructions ), 1 )
39
41
self .assertIsInstance (instructions [0 ].instr_args , DelegateCall )
40
42
43
+ # pyre-ignore
44
+ def _compare_eager_quantized_model_outputs (
45
+ self , session , example_inputs , atol : float
46
+ ) -> None :
47
+ """Utility to compare eager quantized model output with session output after xnnpack lowering"""
48
+ torch_export_stage_output = session .get_stage_artifacts ()[
49
+ StageType .TORCH_EXPORT
50
+ ]
51
+ eager_quantized_model = torch_export_stage_output .data ["forward" ].module ()
52
+ output = session .run_method ("forward" , example_inputs [0 ])[0 ]
53
+ expected = eager_quantized_model (* example_inputs [0 ])
54
+ Tester ._assert_outputs_equal (output , expected , atol = atol )
55
+
56
+ def _compare_eager_unquantized_model_outputs (
57
+ self , session , eager_unquantized_model , example_inputs , sqnr_threshold = 20
58
+ ):
59
+ """Utility to compare eager unquantized model output with session output using SQNR"""
60
+ quantized_output = session .run_method ("forward" , example_inputs [0 ])[0 ]
61
+ original_output = eager_unquantized_model (* example_inputs [0 ])
62
+ error = compute_error (original_output , quantized_output )
63
+ self .assertTrue (error > sqnr_threshold )
64
+
41
65
def test_basic_recipe (self ) -> None :
42
66
m_eager = TestHelperModules .TwoLinearModule ().eval ()
43
67
example_inputs = [(torch .randn (9 , 8 ),)]
@@ -46,18 +70,13 @@ def test_basic_recipe(self) -> None:
46
70
example_inputs = example_inputs ,
47
71
export_recipe = ExportRecipe .get_recipe (XNNPackRecipeType .FP32 ),
48
72
)
49
- self .assertTrue (
50
- torch .allclose (
51
- session .run_method ("forward" , example_inputs [0 ])[0 ],
52
- m_eager (* example_inputs [0 ]),
53
- atol = 1e-3 ,
54
- )
55
- )
73
+ self ._compare_eager_quantized_model_outputs (session , example_inputs , 1e-3 )
56
74
self .check_fully_delegated (session .get_executorch_program ())
75
+ self ._compare_eager_unquantized_model_outputs (session , m_eager , example_inputs )
57
76
58
77
def test_int8_dynamic_quant_recipe (self ) -> None :
59
78
test_cases = [
60
- ExportRecipe .get_recipe (XNNPackRecipeType .INT8_DYNAMIC_PER_CHANNEL ),
79
+ ExportRecipe .get_recipe (XNNPackRecipeType .PT2E_INT8_DYNAMIC_PER_CHANNEL ),
61
80
]
62
81
63
82
for export_recipe in test_cases :
@@ -70,19 +89,18 @@ def test_int8_dynamic_quant_recipe(self) -> None:
70
89
example_inputs = example_inputs ,
71
90
export_recipe = export_recipe ,
72
91
)
73
- self .assertTrue (
74
- torch .allclose (
75
- session .run_method ("forward" , example_inputs [0 ])[0 ],
76
- m_eager (* example_inputs [0 ]),
77
- atol = 1e-1 ,
78
- )
92
+ self ._compare_eager_quantized_model_outputs (
93
+ session , example_inputs , 1e-2
79
94
)
80
95
self .check_fully_delegated (session .get_executorch_program ())
96
+ self ._compare_eager_unquantized_model_outputs (
97
+ session , m_eager , example_inputs
98
+ )
81
99
82
100
def test_int8_static_quant_recipe (self ) -> None :
83
101
test_cases = [
84
- ExportRecipe .get_recipe (XNNPackRecipeType .INT8_STATIC_PER_CHANNEL ),
85
- ExportRecipe .get_recipe (XNNPackRecipeType .INT8_STATIC_PER_TENSOR ),
102
+ ExportRecipe .get_recipe (XNNPackRecipeType .PT2E_INT8_STATIC_PER_CHANNEL ),
103
+ ExportRecipe .get_recipe (XNNPackRecipeType .PT2E_INT8_STATIC_PER_TENSOR ),
86
104
]
87
105
88
106
for export_recipe in test_cases :
@@ -95,14 +113,13 @@ def test_int8_static_quant_recipe(self) -> None:
95
113
example_inputs = example_inputs ,
96
114
export_recipe = export_recipe ,
97
115
)
98
- self .assertTrue (
99
- torch .allclose (
100
- session .run_method ("forward" , example_inputs [0 ])[0 ],
101
- m_eager (* example_inputs [0 ]),
102
- atol = 1e-1 ,
103
- )
116
+ self ._compare_eager_quantized_model_outputs (
117
+ session , example_inputs , 1e-3
104
118
)
105
119
self .check_fully_delegated (session .get_executorch_program ())
120
+ self ._compare_eager_unquantized_model_outputs (
121
+ session , m_eager , example_inputs
122
+ )
106
123
107
124
def test_8a4w_recipe (self ) -> None :
108
125
class SimpleLinearModel (nn .Module ):
@@ -116,10 +133,10 @@ def forward(self, x) -> torch.Tensor:
116
133
117
134
test_cases = [
118
135
ExportRecipe .get_recipe (
119
- XNNPackRecipeType .INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_CHANNEL ,
136
+ XNNPackRecipeType .TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_CHANNEL ,
120
137
),
121
138
ExportRecipe .get_recipe (
122
- XNNPackRecipeType .INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR ,
139
+ XNNPackRecipeType .TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR ,
123
140
group_size = 32 ,
124
141
),
125
142
]
@@ -133,23 +150,22 @@ def forward(self, x) -> torch.Tensor:
133
150
example_inputs = example_inputs ,
134
151
export_recipe = export_recipe ,
135
152
)
136
- self .assertTrue (
137
- torch .allclose (
138
- session .run_method ("forward" , example_inputs [0 ])[0 ],
139
- model (* example_inputs [0 ]),
140
- atol = 1e-2 ,
141
- )
142
- )
143
153
self .check_fully_delegated (session .get_executorch_program ())
154
+ self ._compare_eager_quantized_model_outputs (
155
+ session , example_inputs , 1e-3
156
+ )
157
+ self ._compare_eager_unquantized_model_outputs (
158
+ session , model , example_inputs
159
+ )
144
160
145
161
def _get_recipe_for_quant_type (self , quant_type : QuantType ) -> XNNPackRecipeType :
146
162
# Map QuantType to corresponding recipe name.
147
163
if quant_type == QuantType .STATIC_PER_CHANNEL :
148
- return XNNPackRecipeType .INT8_STATIC_PER_CHANNEL
164
+ return XNNPackRecipeType .PT2E_INT8_STATIC_PER_CHANNEL
149
165
elif quant_type == QuantType .DYNAMIC_PER_CHANNEL :
150
- return XNNPackRecipeType .INT8_DYNAMIC_PER_CHANNEL
166
+ return XNNPackRecipeType .PT2E_INT8_DYNAMIC_PER_CHANNEL
151
167
elif quant_type == QuantType .STATIC_PER_TENSOR :
152
- return XNNPackRecipeType .INT8_STATIC_PER_TENSOR
168
+ return XNNPackRecipeType .PT2E_INT8_STATIC_PER_TENSOR
153
169
elif quant_type == QuantType .NONE :
154
170
return XNNPackRecipeType .FP32
155
171
else :
@@ -224,12 +240,13 @@ def test_validate_recipe_kwargs_int4_tensor_with_valid_group_size(
224
240
225
241
# Should not raise any exception
226
242
recipe_w_default_group = provider .create_recipe (
227
- XNNPackRecipeType .INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR
243
+ XNNPackRecipeType .TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR
228
244
)
229
245
self .assertIsNotNone (recipe_w_default_group )
230
246
231
247
recipe = provider .create_recipe (
232
- XNNPackRecipeType .INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR , group_size = 64
248
+ XNNPackRecipeType .TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR ,
249
+ group_size = 64 ,
233
250
)
234
251
self .assertIsNotNone (recipe )
235
252
@@ -240,7 +257,7 @@ def test_validate_recipe_kwargs_int4_tensor_with_invalid_group_size(
240
257
241
258
with self .assertRaises (ValueError ) as cm :
242
259
provider .create_recipe (
243
- XNNPackRecipeType .INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR ,
260
+ XNNPackRecipeType .TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR ,
244
261
group_size = "32" , # String instead of int
245
262
)
246
263
0 commit comments