18
18
from executorch .examples .models .model_factory import EagerModelFactory
19
19
from executorch .examples .xnnpack import MODEL_NAME_TO_OPTIONS , QuantType
20
20
from executorch .exir .schema import DelegateCall , Program
21
- from executorch .export import export , ExportRecipe , recipe_registry
21
+ from executorch .export import export , ExportRecipe , recipe_registry , StageType
22
22
from torch import nn
23
23
from torch .testing ._internal .common_quantization import TestHelperModules
24
+ from torchao .quantization .utils import compute_error
24
25
25
26
26
27
class TestXnnpackRecipes (unittest .TestCase ):
@@ -38,6 +39,29 @@ def check_fully_delegated(self, program: Program) -> None:
38
39
self .assertEqual (len (instructions ), 1 )
39
40
self .assertIsInstance (instructions [0 ].instr_args , DelegateCall )
40
41
42
# pyre-ignore
def _compare_eager_quantized_model_outputs(
    self, session, example_inputs, atol: float
) -> None:
    """Check the session's "forward" output against the exported eager module.

    The reference module is rebuilt from the ``StageType.TORCH_EXPORT``
    artifact stored on the session (the quantized eager model, per the
    helper's name), so both sides of the comparison come from the same
    export run.

    Args:
        session: Export session; must provide ``get_stage_artifacts()`` and
            ``run_method()``.
        example_inputs: Sequence of input tuples; only the first tuple is used.
        atol: Absolute tolerance forwarded to ``Tester._assert_outputs_equal``.
    """
    # Recover the callable module for the "forward" method from the
    # torch-export stage artifact.
    stage_artifacts = session.get_stage_artifacts()
    forward_program = stage_artifacts[StageType.TORCH_EXPORT].data["forward"]
    reference_module = forward_program.module()

    # Run the session first, then the reference, on the same input tuple.
    sample = example_inputs[0]
    lowered_result = session.run_method("forward", sample)[0]
    reference_result = reference_module(*sample)

    Tester._assert_outputs_equal(lowered_result, reference_result, atol=atol)
54
+
55
def _compare_eager_unquantized_model_outputs(
    self, session, eager_unquantized_model, example_inputs, sqnr_threshold: float = 20
) -> None:
    """Compare the original eager model's output with the session output.

    The two outputs are passed to ``compute_error`` and the result, which
    this test reports as SQNR in dB, must be strictly greater than
    ``sqnr_threshold``.

    Args:
        session: Export session; must provide ``run_method()``.
        eager_unquantized_model: Callable invoked with the unpacked first
            input tuple to produce the reference output.
        example_inputs: Sequence of input tuples; only the first tuple is used.
        sqnr_threshold: Lower bound (exclusive) for the computed error value.

    Raises:
        AssertionError: If the computed value does not exceed the threshold.
    """
    sample = example_inputs[0]
    quantized_output = session.run_method("forward", sample)[0]
    original_output = eager_unquantized_model(*sample)
    error = compute_error(original_output, quantized_output)
    print(f"{self._testMethodName} - SQNR: {error} dB")
    # assertGreater reports both operands on failure, unlike
    # assertTrue(a > b), which only says "False is not true".
    self.assertGreater(
        error,
        sqnr_threshold,
        f"SQNR {error} dB is not above the required threshold {sqnr_threshold} dB",
    )
64
+
41
65
def test_basic_recipe (self ) -> None :
42
66
m_eager = TestHelperModules .TwoLinearModule ().eval ()
43
67
example_inputs = [(torch .randn (9 , 8 ),)]
@@ -46,18 +70,13 @@ def test_basic_recipe(self) -> None:
46
70
example_inputs = example_inputs ,
47
71
export_recipe = ExportRecipe .get_recipe (XNNPackRecipeType .FP32 ),
48
72
)
49
- self .assertTrue (
50
- torch .allclose (
51
- session .run_method ("forward" , example_inputs [0 ])[0 ],
52
- m_eager (* example_inputs [0 ]),
53
- atol = 1e-3 ,
54
- )
55
- )
73
+ self ._compare_eager_quantized_model_outputs (session , example_inputs , 1e-3 )
56
74
self .check_fully_delegated (session .get_executorch_program ())
75
+ self ._compare_eager_unquantized_model_outputs (session , m_eager , example_inputs )
57
76
58
77
def test_int8_dynamic_quant_recipe (self ) -> None :
59
78
test_cases = [
60
- ExportRecipe .get_recipe (XNNPackRecipeType .INT8_DYNAMIC_PER_CHANNEL ),
79
+ ExportRecipe .get_recipe (XNNPackRecipeType .PT2E_INT8_DYNAMIC_PER_CHANNEL ),
61
80
]
62
81
63
82
for export_recipe in test_cases :
@@ -70,19 +89,18 @@ def test_int8_dynamic_quant_recipe(self) -> None:
70
89
example_inputs = example_inputs ,
71
90
export_recipe = export_recipe ,
72
91
)
73
- self .assertTrue (
74
- torch .allclose (
75
- session .run_method ("forward" , example_inputs [0 ])[0 ],
76
- m_eager (* example_inputs [0 ]),
77
- atol = 1e-1 ,
78
- )
92
+ self ._compare_eager_quantized_model_outputs (
93
+ session , example_inputs , 1e-1
79
94
)
80
95
self .check_fully_delegated (session .get_executorch_program ())
96
+ self ._compare_eager_unquantized_model_outputs (
97
+ session , m_eager , example_inputs
98
+ )
81
99
82
100
def test_int8_static_quant_recipe (self ) -> None :
83
101
test_cases = [
84
- ExportRecipe .get_recipe (XNNPackRecipeType .INT8_STATIC_PER_CHANNEL ),
85
- ExportRecipe .get_recipe (XNNPackRecipeType .INT8_STATIC_PER_TENSOR ),
102
+ ExportRecipe .get_recipe (XNNPackRecipeType .PT2E_INT8_STATIC_PER_CHANNEL ),
103
+ ExportRecipe .get_recipe (XNNPackRecipeType .PT2E_INT8_STATIC_PER_TENSOR ),
86
104
]
87
105
88
106
for export_recipe in test_cases :
@@ -95,14 +113,13 @@ def test_int8_static_quant_recipe(self) -> None:
95
113
example_inputs = example_inputs ,
96
114
export_recipe = export_recipe ,
97
115
)
98
- self .assertTrue (
99
- torch .allclose (
100
- session .run_method ("forward" , example_inputs [0 ])[0 ],
101
- m_eager (* example_inputs [0 ]),
102
- atol = 1e-1 ,
103
- )
116
+ self ._compare_eager_quantized_model_outputs (
117
+ session , example_inputs , 1e-2
104
118
)
105
119
self .check_fully_delegated (session .get_executorch_program ())
120
+ self ._compare_eager_unquantized_model_outputs (
121
+ session , m_eager , example_inputs
122
+ )
106
123
107
124
def test_8a4w_recipe (self ) -> None :
108
125
class SimpleLinearModel (nn .Module ):
@@ -116,40 +133,36 @@ def forward(self, x) -> torch.Tensor:
116
133
117
134
test_cases = [
118
135
ExportRecipe .get_recipe (
119
- XNNPackRecipeType .INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_CHANNEL ,
136
+ XNNPackRecipeType .TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_CHANNEL ,
120
137
),
121
138
ExportRecipe .get_recipe (
122
- XNNPackRecipeType .INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR ,
123
- group_size = 32 ,
139
+ XNNPackRecipeType .TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR ,
140
+ group_size = 8 ,
124
141
),
125
142
]
126
143
127
144
for export_recipe in test_cases :
128
145
with self .subTest (export_recipe = export_recipe ):
129
- model = SimpleLinearModel ()
146
+ model = SimpleLinearModel (). eval ()
130
147
example_inputs = [(torch .randn (1 , 32 ),)]
131
148
session = export (
132
149
model = model ,
133
150
example_inputs = example_inputs ,
134
151
export_recipe = export_recipe ,
135
152
)
136
- self .assertTrue (
137
- torch .allclose (
138
- session .run_method ("forward" , example_inputs [0 ])[0 ],
139
- model (* example_inputs [0 ]),
140
- atol = 1e-2 ,
141
- )
142
- )
143
153
self .check_fully_delegated (session .get_executorch_program ())
154
+ self ._compare_eager_quantized_model_outputs (
155
+ session , example_inputs , 1e-3
156
+ )
144
157
145
158
def _get_recipe_for_quant_type (self , quant_type : QuantType ) -> XNNPackRecipeType :
146
159
# Map QuantType to corresponding recipe name.
147
160
if quant_type == QuantType .STATIC_PER_CHANNEL :
148
- return XNNPackRecipeType .INT8_STATIC_PER_CHANNEL
161
+ return XNNPackRecipeType .PT2E_INT8_STATIC_PER_CHANNEL
149
162
elif quant_type == QuantType .DYNAMIC_PER_CHANNEL :
150
- return XNNPackRecipeType .INT8_DYNAMIC_PER_CHANNEL
163
+ return XNNPackRecipeType .PT2E_INT8_DYNAMIC_PER_CHANNEL
151
164
elif quant_type == QuantType .STATIC_PER_TENSOR :
152
- return XNNPackRecipeType .INT8_STATIC_PER_TENSOR
165
+ return XNNPackRecipeType .PT2E_INT8_STATIC_PER_TENSOR
153
166
elif quant_type == QuantType .NONE :
154
167
return XNNPackRecipeType .FP32
155
168
else :
@@ -224,12 +237,13 @@ def test_validate_recipe_kwargs_int4_tensor_with_valid_group_size(
224
237
225
238
# Should not raise any exception
226
239
recipe_w_default_group = provider .create_recipe (
227
- XNNPackRecipeType .INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR
240
+ XNNPackRecipeType .TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR
228
241
)
229
242
self .assertIsNotNone (recipe_w_default_group )
230
243
231
244
recipe = provider .create_recipe (
232
- XNNPackRecipeType .INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR , group_size = 64
245
+ XNNPackRecipeType .TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR ,
246
+ group_size = 64 ,
233
247
)
234
248
self .assertIsNotNone (recipe )
235
249
@@ -240,7 +254,7 @@ def test_validate_recipe_kwargs_int4_tensor_with_invalid_group_size(
240
254
241
255
with self .assertRaises (ValueError ) as cm :
242
256
provider .create_recipe (
243
- XNNPackRecipeType .INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR ,
257
+ XNNPackRecipeType .TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR ,
244
258
group_size = "32" , # String instead of int
245
259
)
246
260
0 commit comments