6
6
7
7
# pyre-strict
8
8
9
+ import logging
10
+ import os
9
11
import unittest
12
+ from typing import List , Optional , Tuple
10
13
11
14
import torch
12
15
from executorch .backends .xnnpack .recipes .xnnpack_recipe_provider import (
18
21
from executorch .examples .models .model_factory import EagerModelFactory
19
22
from executorch .examples .xnnpack import MODEL_NAME_TO_OPTIONS , QuantType
20
23
from executorch .exir .schema import DelegateCall , Program
21
- from executorch .export import export , ExportRecipe , recipe_registry , StageType
22
- from torch import nn
24
+ from executorch .export import (
25
+ export ,
26
+ ExportRecipe ,
27
+ ExportSession ,
28
+ recipe_registry ,
29
+ StageType ,
30
+ )
31
+ from torch import nn , Tensor
32
+ from torch .testing import FileCheck
23
33
from torch .testing ._internal .common_quantization import TestHelperModules
24
34
from torchao .quantization .utils import compute_error
25
35
@@ -39,9 +49,12 @@ def check_fully_delegated(self, program: Program) -> None:
39
49
self .assertEqual (len (instructions ), 1 )
40
50
self .assertIsInstance (instructions [0 ].instr_args , DelegateCall )
41
51
42
- # pyre-ignore
43
52
def _compare_eager_quantized_model_outputs (
44
- self , session , example_inputs , atol : float
53
+ self ,
54
+ # pyre-ignore[11]
55
+ session : ExportSession ,
56
+ example_inputs : List [Tuple [Tensor ]],
57
+ atol : float ,
45
58
) -> None :
46
59
"""Utility to compare eager quantized model output with session output after xnnpack lowering"""
47
60
torch_export_stage_output = session .get_stage_artifacts ()[
@@ -53,8 +66,12 @@ def _compare_eager_quantized_model_outputs(
53
66
Tester ._assert_outputs_equal (output , expected , atol = atol )
54
67
55
68
def _compare_eager_unquantized_model_outputs (
56
- self , session , eager_unquantized_model , example_inputs , sqnr_threshold = 20
57
- ):
69
+ self ,
70
+ session : ExportSession ,
71
+ eager_unquantized_model : nn .Module ,
72
+ example_inputs : List [Tuple [Tensor ]],
73
+ sqnr_threshold : int = 20 ,
74
+ ) -> None :
58
75
"""Utility to compare eager unquantized model output with session output using SQNR"""
59
76
quantized_output = session .run_method ("forward" , example_inputs [0 ])[0 ]
60
77
original_output = eager_unquantized_model (* example_inputs [0 ])
@@ -163,12 +180,15 @@ def _get_recipe_for_quant_type(self, quant_type: QuantType) -> XNNPackRecipeType
163
180
return XNNPackRecipeType .PT2E_INT8_DYNAMIC_PER_CHANNEL
164
181
elif quant_type == QuantType .STATIC_PER_TENSOR :
165
182
return XNNPackRecipeType .PT2E_INT8_STATIC_PER_TENSOR
166
- elif quant_type == QuantType .NONE :
167
- return XNNPackRecipeType .FP32
168
- else :
169
- raise ValueError (f"Unsupported QuantType: { quant_type } " )
183
+ return XNNPackRecipeType .FP32
170
184
171
- def _test_model_with_factory (self , model_name : str ) -> None :
185
+ def _test_model_with_factory (
186
+ self ,
187
+ model_name : str ,
188
+ tolerance : Optional [float ] = None ,
189
+ sqnr_threshold : Optional [float ] = None ,
190
+ ) -> None :
191
+ logging .info (f"Testing model { model_name } " )
172
192
if model_name not in MODEL_NAME_TO_MODEL :
173
193
self .skipTest (f"Model { model_name } not found in MODEL_NAME_TO_MODEL" )
174
194
return
@@ -195,31 +215,76 @@ def _test_model_with_factory(self, model_name: str) -> None:
195
215
dynamic_shapes = dynamic_shapes ,
196
216
)
197
217
198
- # Verify outputs match
199
- Tester ._assert_outputs_equal (
200
- session .run_method ("forward" , example_inputs )[0 ],
201
- model (* example_inputs ),
202
- atol = 1e-3 ,
218
+ all_artifacts = session .get_stage_artifacts ()
219
+ quantized_model = all_artifacts [StageType .QUANTIZE ].data ["forward" ]
220
+
221
+ edge_program_manager = all_artifacts [StageType .TO_EDGE_TRANSFORM_AND_LOWER ].data
222
+ lowered_module = edge_program_manager .exported_program ().module ()
223
+
224
+ # Check if model got lowered to xnnpack backend
225
+ FileCheck ().check ("torch.ops.higher_order.executorch_call_delegate" ).run (
226
+ lowered_module .code
203
227
)
204
228
205
- @unittest .skip ("T187799178: Debugging Numerical Issues with Calibration" )
229
+ if tolerance is not None :
230
+ quantized_output = quantized_model (* example_inputs )
231
+ lowered_output = lowered_module (* example_inputs )
232
+ if model_name == "dl3" :
233
+ quantized_output = quantized_output ["out" ]
234
+ lowered_output = lowered_output ["out" ]
235
+
236
+ # lowering error
237
+ try :
238
+ Tester ._assert_outputs_equal (
239
+ lowered_output , quantized_output , atol = tolerance , rtol = tolerance
240
+ )
241
+ except AssertionError as e :
242
+ raise AssertionError (
243
+ f"Model '{ model_name } ' lowering error check failed with tolerance { tolerance } "
244
+ ) from e
245
+ logging .info (
246
+ f"{ self ._testMethodName } - { model_name } - lowering error passed"
247
+ )
248
+
249
+ # verify sqnr between eager model and quantized model
250
+ if sqnr_threshold is not None :
251
+ original_output = model (* example_inputs )
252
+ quantized_output = quantized_model (* example_inputs )
253
+ # lowered_output = lowered_module(*example_inputs)
254
+ if model_name == "dl3" :
255
+ original_output = original_output ["out" ]
256
+ quantized_output = quantized_output ["out" ]
257
+ error = compute_error (original_output , quantized_output )
258
+ logging .info (f"{ self ._testMethodName } - { model_name } - SQNR: { error } dB" )
259
+ self .assertTrue (
260
+ error > sqnr_threshold , f"Model '{ model_name } ' SQNR check failed"
261
+ )
262
+
206
263
def test_all_models_with_recipes (self ) -> None :
207
264
models_to_test = [
208
- "linear" ,
209
- "add" ,
210
- "add_mul" ,
211
- "ic3" ,
212
- "mv2" ,
213
- "mv3" ,
214
- "resnet18" ,
215
- "resnet50" ,
216
- "vit" ,
217
- "w2l" ,
218
- "llama2" ,
265
+ # Tuple format: (model_name, error tolerance, minimum sqnr)
266
+ ("linear" , 1e-3 , 20 ),
267
+ ("add" , 1e-3 , 20 ),
268
+ ("add_mul" , 1e-3 , 20 ),
269
+ ("dl3" , 1e-3 , 20 ),
270
+ ("ic3" , None , None ),
271
+ ("ic4" , 1e-3 , 20 ),
272
+ ("mv2" , 1e-3 , None ),
273
+ ("mv3" , 1e-3 , None ),
274
+ ("resnet18" , 1e-3 , 20 ),
275
+ ("resnet50" , 1e-3 , 20 ),
276
+ ("vit" , 1e-1 , 10 ),
277
+ ("w2l" , 1e-3 , 20 ),
219
278
]
220
- for model_name in models_to_test :
221
- with self .subTest (model = model_name ):
222
- self ._test_model_with_factory (model_name )
279
+ try :
280
+ for model_name , tolerance , sqnr in models_to_test :
281
+ with self .subTest (model = model_name ):
282
+ with torch .no_grad ():
283
+ self ._test_model_with_factory (model_name , tolerance , sqnr )
284
+ finally :
285
+ # Clean up dog.jpg file if it exists
286
+ if os .path .exists ("dog.jpg" ):
287
+ os .remove ("dog.jpg" )
223
288
224
289
def test_validate_recipe_kwargs_fp32 (self ) -> None :
225
290
provider = XNNPACKRecipeProvider ()
0 commit comments