from executorch.examples.xnnpack import MODEL_NAME_TO_OPTIONS, QuantType
from executorch.exir.schema import DelegateCall, Program
from executorch.export import export, ExportRecipe, recipe_registry
+ from export.types import StageType
from torch import nn
from torch.testing._internal.common_quantization import TestHelperModules
+ from torchao.quantization.utils import compute_error


class TestXnnpackRecipes(unittest.TestCase):
@@ -38,6 +40,29 @@ def check_fully_delegated(self, program: Program) -> None:
        self.assertEqual(len(instructions), 1)
        self.assertIsInstance(instructions[0].instr_args, DelegateCall)

+     # pyre-ignore
+     def _compare_eager_quantized_model_outputs(
+         self, session, example_inputs, atol: float
+     ) -> None:
+         """Utility to compare eager quantized model output with session output after xnnpack lowering"""
+         torch_export_stage_output = session.get_stage_artifacts()[
+             StageType.TORCH_EXPORT
+         ]
+         eager_quantized_model = torch_export_stage_output.data["forward"].module()
+         output = session.run_method("forward", example_inputs[0])[0]
+         expected = eager_quantized_model(*example_inputs[0])
+         Tester._assert_outputs_equal(output, expected, atol=atol)
+
+     def _compare_eager_unquantized_model_outputs(
+         self, session, eager_unquantized_model, example_inputs, sqnr_threshold=20
+     ):
+         """Utility to compare eager unquantized model output with session output using SQNR"""
+         quantized_output = session.run_method("forward", example_inputs[0])[0]
+         original_output = eager_unquantized_model(*example_inputs[0])
+         error = compute_error(original_output, quantized_output)
+         print(f"{self._testMethodName} - SQNR: {error} dB")
+         self.assertTrue(error > sqnr_threshold)
+
    def test_basic_recipe(self) -> None:
        m_eager = TestHelperModules.TwoLinearModule().eval()
        example_inputs = [(torch.randn(9, 8),)]
@@ -46,18 +71,13 @@ def test_basic_recipe(self) -> None:
            example_inputs=example_inputs,
            export_recipe=ExportRecipe.get_recipe(XNNPackRecipeType.FP32),
        )
-         self.assertTrue(
-             torch.allclose(
-                 session.run_method("forward", example_inputs[0])[0],
-                 m_eager(*example_inputs[0]),
-                 atol=1e-3,
-             )
-         )
+         self._compare_eager_quantized_model_outputs(session, example_inputs, 1e-3)
        self.check_fully_delegated(session.get_executorch_program())
+         self._compare_eager_unquantized_model_outputs(session, m_eager, example_inputs)

    def test_int8_dynamic_quant_recipe(self) -> None:
        test_cases = [
-             ExportRecipe.get_recipe(XNNPackRecipeType.INT8_DYNAMIC_PER_CHANNEL),
+             ExportRecipe.get_recipe(XNNPackRecipeType.PT2E_INT8_DYNAMIC_PER_CHANNEL),
        ]

        for export_recipe in test_cases:
@@ -70,19 +90,18 @@ def test_int8_dynamic_quant_recipe(self) -> None:
                    example_inputs=example_inputs,
                    export_recipe=export_recipe,
                )
-                 self.assertTrue(
-                     torch.allclose(
-                         session.run_method("forward", example_inputs[0])[0],
-                         m_eager(*example_inputs[0]),
-                         atol=1e-1,
-                     )
+                 self._compare_eager_quantized_model_outputs(
+                     session, example_inputs, 1e-1
                )
                self.check_fully_delegated(session.get_executorch_program())
+                 self._compare_eager_unquantized_model_outputs(
+                     session, m_eager, example_inputs
+                 )

    def test_int8_static_quant_recipe(self) -> None:
        test_cases = [
-             ExportRecipe.get_recipe(XNNPackRecipeType.INT8_STATIC_PER_CHANNEL),
-             ExportRecipe.get_recipe(XNNPackRecipeType.INT8_STATIC_PER_TENSOR),
+             ExportRecipe.get_recipe(XNNPackRecipeType.PT2E_INT8_STATIC_PER_CHANNEL),
+             ExportRecipe.get_recipe(XNNPackRecipeType.PT2E_INT8_STATIC_PER_TENSOR),
        ]

        for export_recipe in test_cases:
@@ -95,14 +114,13 @@ def test_int8_static_quant_recipe(self) -> None:
                    example_inputs=example_inputs,
                    export_recipe=export_recipe,
                )
-                 self.assertTrue(
-                     torch.allclose(
-                         session.run_method("forward", example_inputs[0])[0],
-                         m_eager(*example_inputs[0]),
-                         atol=1e-1,
-                     )
+                 self._compare_eager_quantized_model_outputs(
+                     session, example_inputs, 1e-2
                )
                self.check_fully_delegated(session.get_executorch_program())
+                 self._compare_eager_unquantized_model_outputs(
+                     session, m_eager, example_inputs
+                 )

    def test_8a4w_recipe(self) -> None:
        class SimpleLinearModel(nn.Module):
@@ -116,10 +134,10 @@ def forward(self, x) -> torch.Tensor:

        test_cases = [
            ExportRecipe.get_recipe(
-                 XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_CHANNEL,
+                 XNNPackRecipeType.TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_CHANNEL,
            ),
            ExportRecipe.get_recipe(
-                 XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR,
+                 XNNPackRecipeType.TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR,
                group_size=32,
            ),
        ]
@@ -133,23 +151,22 @@ def forward(self, x) -> torch.Tensor:
                    example_inputs=example_inputs,
                    export_recipe=export_recipe,
                )
-                 self.assertTrue(
-                     torch.allclose(
-                         session.run_method("forward", example_inputs[0])[0],
-                         model(*example_inputs[0]),
-                         atol=1e-2,
-                     )
-                 )
                self.check_fully_delegated(session.get_executorch_program())
+                 self._compare_eager_quantized_model_outputs(
+                     session, example_inputs, 1e-3
+                 )
+                 self._compare_eager_unquantized_model_outputs(
+                     session, model, example_inputs, sqnr_threshold=15
+                 )

    def _get_recipe_for_quant_type(self, quant_type: QuantType) -> XNNPackRecipeType:
        # Map QuantType to corresponding recipe name.
        if quant_type == QuantType.STATIC_PER_CHANNEL:
-             return XNNPackRecipeType.INT8_STATIC_PER_CHANNEL
+             return XNNPackRecipeType.PT2E_INT8_STATIC_PER_CHANNEL
        elif quant_type == QuantType.DYNAMIC_PER_CHANNEL:
-             return XNNPackRecipeType.INT8_DYNAMIC_PER_CHANNEL
+             return XNNPackRecipeType.PT2E_INT8_DYNAMIC_PER_CHANNEL
        elif quant_type == QuantType.STATIC_PER_TENSOR:
-             return XNNPackRecipeType.INT8_STATIC_PER_TENSOR
+             return XNNPackRecipeType.PT2E_INT8_STATIC_PER_TENSOR
        elif quant_type == QuantType.NONE:
            return XNNPackRecipeType.FP32
        else:
@@ -224,12 +241,13 @@ def test_validate_recipe_kwargs_int4_tensor_with_valid_group_size(

        # Should not raise any exception
        recipe_w_default_group = provider.create_recipe(
-             XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR
+             XNNPackRecipeType.TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR
        )
        self.assertIsNotNone(recipe_w_default_group)

        recipe = provider.create_recipe(
-             XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR, group_size=64
+             XNNPackRecipeType.TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR,
+             group_size=64,
        )
        self.assertIsNotNone(recipe)
@@ -240,7 +258,7 @@ def test_validate_recipe_kwargs_int4_tensor_with_invalid_group_size(

        with self.assertRaises(ValueError) as cm:
            provider.create_recipe(
-                 XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR,
+                 XNNPackRecipeType.TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR,
                group_size="32",  # String instead of int
            )

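For context, a minimal sketch (not part of the diff) of the check performed by the new _compare_eager_unquantized_model_outputs helper, assuming torchao's compute_error returns the signal-to-quantization-noise ratio in decibels, i.e. 20 * log10(||reference|| / ||reference - candidate||); the tensors and threshold below are illustrative only:

import torch
from torchao.quantization.utils import compute_error

reference = torch.randn(9, 8)                     # eager, unquantized model output
candidate = reference + 0.01 * torch.randn(9, 8)  # stand-in for the lowered, quantized output

# Higher SQNR means the quantized output tracks the reference more closely.
sqnr_db = float(compute_error(reference, candidate))
assert sqnr_db > 20, f"SQNR too low: {sqnr_db:.1f} dB"  # 20 dB is the helper's default threshold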