66
77# pyre-strict
88
9+ import logging
10+ import os
911import unittest
12+ from typing import List , Optional , Tuple
1013
1114import torch
1215from executorch .backends .xnnpack .recipes .xnnpack_recipe_provider import (
1821from executorch .examples .models .model_factory import EagerModelFactory
1922from executorch .examples .xnnpack import MODEL_NAME_TO_OPTIONS , QuantType
2023from executorch .exir .schema import DelegateCall , Program
21- from executorch .export import export , ExportRecipe , recipe_registry , StageType
22- from torch import nn
24+ from executorch .export import (
25+ export ,
26+ ExportRecipe ,
27+ ExportSession ,
28+ recipe_registry ,
29+ StageType ,
30+ )
31+ from torch import nn , Tensor
32+ from torch .testing import FileCheck
2333from torch .testing ._internal .common_quantization import TestHelperModules
2434from torchao .quantization .utils import compute_error
2535
@@ -39,9 +49,12 @@ def check_fully_delegated(self, program: Program) -> None:
3949 self .assertEqual (len (instructions ), 1 )
4050 self .assertIsInstance (instructions [0 ].instr_args , DelegateCall )
4151
42- # pyre-ignore
4352 def _compare_eager_quantized_model_outputs (
44- self , session , example_inputs , atol : float
53+ self ,
54+ # pyre-ignore[11]
55+ session : ExportSession ,
56+ example_inputs : List [Tuple [Tensor ]],
57+ atol : float ,
4558 ) -> None :
4659 """Utility to compare eager quantized model output with session output after xnnpack lowering"""
4760 torch_export_stage_output = session .get_stage_artifacts ()[
@@ -53,8 +66,12 @@ def _compare_eager_quantized_model_outputs(
5366 Tester ._assert_outputs_equal (output , expected , atol = atol )
5467
5568 def _compare_eager_unquantized_model_outputs (
56- self , session , eager_unquantized_model , example_inputs , sqnr_threshold = 20
57- ):
69+ self ,
70+ session : ExportSession ,
71+ eager_unquantized_model : nn .Module ,
72+ example_inputs : List [Tuple [Tensor ]],
73+ sqnr_threshold : int = 20 ,
74+ ) -> None :
5875 """Utility to compare eager unquantized model output with session output using SQNR"""
5976 quantized_output = session .run_method ("forward" , example_inputs [0 ])[0 ]
6077 original_output = eager_unquantized_model (* example_inputs [0 ])
@@ -163,12 +180,15 @@ def _get_recipe_for_quant_type(self, quant_type: QuantType) -> XNNPackRecipeType
163180 return XNNPackRecipeType .PT2E_INT8_DYNAMIC_PER_CHANNEL
164181 elif quant_type == QuantType .STATIC_PER_TENSOR :
165182 return XNNPackRecipeType .PT2E_INT8_STATIC_PER_TENSOR
166- elif quant_type == QuantType .NONE :
167- return XNNPackRecipeType .FP32
168- else :
169- raise ValueError (f"Unsupported QuantType: { quant_type } " )
183+ return XNNPackRecipeType .FP32
170184
171- def _test_model_with_factory (self , model_name : str ) -> None :
185+ def _test_model_with_factory (
186+ self ,
187+ model_name : str ,
188+ tolerance : Optional [float ] = None ,
189+ sqnr_threshold : Optional [float ] = None ,
190+ ) -> None :
191+ logging .info (f"Testing model { model_name } " )
172192 if model_name not in MODEL_NAME_TO_MODEL :
173193 self .skipTest (f"Model { model_name } not found in MODEL_NAME_TO_MODEL" )
174194 return
@@ -195,31 +215,76 @@ def _test_model_with_factory(self, model_name: str) -> None:
195215 dynamic_shapes = dynamic_shapes ,
196216 )
197217
198- # Verify outputs match
199- Tester ._assert_outputs_equal (
200- session .run_method ("forward" , example_inputs )[0 ],
201- model (* example_inputs ),
202- atol = 1e-3 ,
218+ all_artifacts = session .get_stage_artifacts ()
219+ quantized_model = all_artifacts [StageType .QUANTIZE ].data ["forward" ]
220+
221+ edge_program_manager = all_artifacts [StageType .TO_EDGE_TRANSFORM_AND_LOWER ].data
222+ lowered_module = edge_program_manager .exported_program ().module ()
223+
224+ # Check if model got lowered to xnnpack backend
225+ FileCheck ().check ("torch.ops.higher_order.executorch_call_delegate" ).run (
226+ lowered_module .code
203227 )
204228
205- @unittest .skip ("T187799178: Debugging Numerical Issues with Calibration" )
229+ if tolerance is not None :
230+ quantized_output = quantized_model (* example_inputs )
231+ lowered_output = lowered_module (* example_inputs )
232+ if model_name == "dl3" :
233+ quantized_output = quantized_output ["out" ]
234+ lowered_output = lowered_output ["out" ]
235+
236+ # lowering error
237+ try :
238+ Tester ._assert_outputs_equal (
239+ lowered_output , quantized_output , atol = tolerance , rtol = tolerance
240+ )
241+ except AssertionError as e :
242+ raise AssertionError (
243+ f"Model '{ model_name } ' lowering error check failed with tolerance { tolerance } "
244+ ) from e
245+ logging .info (
246+ f"{ self ._testMethodName } - { model_name } - lowering error passed"
247+ )
248+
249+ # verify sqnr between eager model and quantized model
250+ if sqnr_threshold is not None :
251+ original_output = model (* example_inputs )
252+ quantized_output = quantized_model (* example_inputs )
253+ # lowered_output = lowered_module(*example_inputs)
254+ if model_name == "dl3" :
255+ original_output = original_output ["out" ]
256+ quantized_output = quantized_output ["out" ]
257+ error = compute_error (original_output , quantized_output )
258+ logging .info (f"{ self ._testMethodName } - { model_name } - SQNR: { error } dB" )
259+ self .assertTrue (
260+ error > sqnr_threshold , f"Model '{ model_name } ' SQNR check failed"
261+ )
262+
def test_all_models_with_recipes(self) -> None:
    """Run the recipe-based export check once per example model.

    Every entry of ``models_to_test`` is forwarded positionally to
    ``self._test_model_with_factory`` inside its own ``subTest`` so that a
    failure for one model is reported without stopping the remaining ones.
    Gradient tracking is disabled for the duration of each call.
    """
    models_to_test = [
        # Tuple format: (model_name, error tolerance, minimum sqnr)
        # A None entry is passed through unchanged to the helper.
        ("linear", 1e-3, 20),
        ("add", 1e-3, 20),
        ("add_mul", 1e-3, 20),
        ("dl3", 1e-3, 20),
        ("ic3", None, None),
        ("ic4", 1e-3, 20),
        ("mv2", 1e-3, None),
        ("mv3", 1e-3, None),
        ("resnet18", 1e-3, 20),
        ("resnet50", 1e-3, 20),
        ("vit", 1e-1, 10),
        ("w2l", 1e-3, 20),
    ]
    try:
        for model_name, tolerance, sqnr in models_to_test:
            with self.subTest(model=model_name), torch.no_grad():
                self._test_model_with_factory(model_name, tolerance, sqnr)
    finally:
        # Clean up dog.jpg file if it exists.
        # NOTE(review): presumably left behind in the working directory by
        # one of the example models - confirm which one produces it.
        # Removing directly and tolerating a missing file avoids the
        # check-then-delete race of an os.path.exists() guard.
        try:
            os.remove("dog.jpg")
        except FileNotFoundError:
            pass
223288
224289 def test_validate_recipe_kwargs_fp32 (self ) -> None :
225290 provider = XNNPACKRecipeProvider ()
0 commit comments