
# pyre-strict

+import logging
import unittest

import torch
from executorch.examples.xnnpack import MODEL_NAME_TO_OPTIONS, QuantType
from executorch.exir.schema import DelegateCall, Program
from executorch.export import export, ExportRecipe, recipe_registry
+from export.types import StageType
from torch import nn
from torch.testing._internal.common_quantization import TestHelperModules
+from torchao.quantization.utils import compute_error


class TestXnnpackRecipes(unittest.TestCase):
@@ -38,6 +41,29 @@ def check_fully_delegated(self, program: Program) -> None:
        self.assertEqual(len(instructions), 1)
        self.assertIsInstance(instructions[0].instr_args, DelegateCall)

+    # pyre-ignore
+    def _compare_eager_quantized_model_outputs(
+        self, session, example_inputs, atol: float
+    ) -> None:
+        """Utility to compare eager quantized model output with session output after xnnpack lowering"""
+        torch_export_stage_output = session.get_stage_artifacts()[
+            StageType.TORCH_EXPORT
+        ]
+        eager_quantized_model = torch_export_stage_output.data["forward"].module()
+        output = session.run_method("forward", example_inputs[0])[0]
+        expected = eager_quantized_model(*example_inputs[0])
+        Tester._assert_outputs_equal(output, expected, atol=atol)
+
+    def _compare_eager_unquantized_model_outputs(
+        self, session, eager_unquantized_model, example_inputs, sqnr_threshold=20
+    ):
+        """Utility to compare eager unquantized model output with session output using SQNR"""
+        quantized_output = session.run_method("forward", example_inputs[0])[0]
+        original_output = eager_unquantized_model(*example_inputs[0])
+        error = compute_error(original_output, quantized_output)
+        print(f"{self._testMethodName} - SQNR: {error} dB")
+        self.assertTrue(error > sqnr_threshold)
+
    def test_basic_recipe(self) -> None:
        m_eager = TestHelperModules.TwoLinearModule().eval()
        example_inputs = [(torch.randn(9, 8),)]
@@ -46,18 +72,13 @@ def test_basic_recipe(self) -> None:
            example_inputs=example_inputs,
            export_recipe=ExportRecipe.get_recipe(XNNPackRecipeType.FP32),
        )
-        self.assertTrue(
-            torch.allclose(
-                session.run_method("forward", example_inputs[0])[0],
-                m_eager(*example_inputs[0]),
-                atol=1e-3,
-            )
-        )
+        self._compare_eager_quantized_model_outputs(session, example_inputs, 1e-3)
        self.check_fully_delegated(session.get_executorch_program())
+        self._compare_eager_unquantized_model_outputs(session, m_eager, example_inputs)

    def test_int8_dynamic_quant_recipe(self) -> None:
        test_cases = [
-            ExportRecipe.get_recipe(XNNPackRecipeType.INT8_DYNAMIC_PER_CHANNEL),
+            ExportRecipe.get_recipe(XNNPackRecipeType.PT2E_INT8_DYNAMIC_PER_CHANNEL),
        ]

        for export_recipe in test_cases:
@@ -70,19 +91,18 @@ def test_int8_dynamic_quant_recipe(self) -> None:
                    example_inputs=example_inputs,
                    export_recipe=export_recipe,
                )
-                self.assertTrue(
-                    torch.allclose(
-                        session.run_method("forward", example_inputs[0])[0],
-                        m_eager(*example_inputs[0]),
-                        atol=1e-1,
-                    )
+                self._compare_eager_quantized_model_outputs(
+                    session, example_inputs, 1e-1
                )
                self.check_fully_delegated(session.get_executorch_program())
+                self._compare_eager_unquantized_model_outputs(
+                    session, m_eager, example_inputs
+                )

    def test_int8_static_quant_recipe(self) -> None:
        test_cases = [
-            ExportRecipe.get_recipe(XNNPackRecipeType.INT8_STATIC_PER_CHANNEL),
-            ExportRecipe.get_recipe(XNNPackRecipeType.INT8_STATIC_PER_TENSOR),
+            ExportRecipe.get_recipe(XNNPackRecipeType.PT2E_INT8_STATIC_PER_CHANNEL),
+            ExportRecipe.get_recipe(XNNPackRecipeType.PT2E_INT8_STATIC_PER_TENSOR),
        ]

        for export_recipe in test_cases:
@@ -95,14 +115,13 @@ def test_int8_static_quant_recipe(self) -> None:
                    example_inputs=example_inputs,
                    export_recipe=export_recipe,
                )
-                self.assertTrue(
-                    torch.allclose(
-                        session.run_method("forward", example_inputs[0])[0],
-                        m_eager(*example_inputs[0]),
-                        atol=1e-1,
-                    )
+                self._compare_eager_quantized_model_outputs(
+                    session, example_inputs, 1e-2
                )
                self.check_fully_delegated(session.get_executorch_program())
+                self._compare_eager_unquantized_model_outputs(
+                    session, m_eager, example_inputs
+                )

    def test_8a4w_recipe(self) -> None:
        class SimpleLinearModel(nn.Module):
@@ -116,10 +135,10 @@ def forward(self, x) -> torch.Tensor:

        test_cases = [
            ExportRecipe.get_recipe(
-                XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_CHANNEL,
+                XNNPackRecipeType.TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_CHANNEL,
            ),
            ExportRecipe.get_recipe(
-                XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR,
+                XNNPackRecipeType.TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR,
                group_size=32,
            ),
        ]
@@ -133,23 +152,22 @@ def forward(self, x) -> torch.Tensor:
                    example_inputs=example_inputs,
                    export_recipe=export_recipe,
                )
-                self.assertTrue(
-                    torch.allclose(
-                        session.run_method("forward", example_inputs[0])[0],
-                        model(*example_inputs[0]),
-                        atol=1e-2,
-                    )
-                )
                self.check_fully_delegated(session.get_executorch_program())
+                self._compare_eager_quantized_model_outputs(
+                    session, example_inputs, 1e-3
+                )
+                self._compare_eager_unquantized_model_outputs(
+                    session, model, example_inputs, sqnr_threshold=15
+                )

    def _get_recipe_for_quant_type(self, quant_type: QuantType) -> XNNPackRecipeType:
        # Map QuantType to corresponding recipe name.
        if quant_type == QuantType.STATIC_PER_CHANNEL:
-            return XNNPackRecipeType.INT8_STATIC_PER_CHANNEL
+            return XNNPackRecipeType.PT2E_INT8_STATIC_PER_CHANNEL
        elif quant_type == QuantType.DYNAMIC_PER_CHANNEL:
-            return XNNPackRecipeType.INT8_DYNAMIC_PER_CHANNEL
+            return XNNPackRecipeType.PT2E_INT8_DYNAMIC_PER_CHANNEL
        elif quant_type == QuantType.STATIC_PER_TENSOR:
-            return XNNPackRecipeType.INT8_STATIC_PER_TENSOR
+            return XNNPackRecipeType.PT2E_INT8_STATIC_PER_TENSOR
        elif quant_type == QuantType.NONE:
            return XNNPackRecipeType.FP32
        else:
@@ -224,12 +242,13 @@ def test_validate_recipe_kwargs_int4_tensor_with_valid_group_size(

        # Should not raise any exception
        recipe_w_default_group = provider.create_recipe(
-            XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR
+            XNNPackRecipeType.TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR
        )
        self.assertIsNotNone(recipe_w_default_group)

        recipe = provider.create_recipe(
-            XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR, group_size=64
+            XNNPackRecipeType.TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR,
+            group_size=64,
        )
        self.assertIsNotNone(recipe)

@@ -240,7 +259,7 @@ def test_validate_recipe_kwargs_int4_tensor_with_invalid_group_size(

        with self.assertRaises(ValueError) as cm:
            provider.create_recipe(
-                XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR,
+                XNNPackRecipeType.TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR,
                group_size="32",  # String instead of int
            )

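Note (not part of the diff): the new `_compare_eager_unquantized_model_outputs` helper gates on an SQNR threshold rather than an absolute tolerance, which is the usual way to bound quantization error against the float model. Below is a minimal sketch of that check, assuming torchao's `compute_error` returns the signal-to-quantization-noise ratio in dB; the local `sqnr` function is illustrative, not the torchao implementation.

import torch

def sqnr(reference: torch.Tensor, candidate: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantization-noise ratio in dB; higher means the quantized
    # output tracks the float reference more closely.
    signal = torch.linalg.norm(reference)
    noise = torch.linalg.norm(reference - candidate)
    return 20 * torch.log10(signal / noise)

reference = torch.randn(9, 8)
candidate = reference + 0.01 * torch.randn_like(reference)  # simulated quantization error
print(f"SQNR: {sqnr(reference, candidate):.1f} dB")  # comfortably above the 20 dB default threshold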