Skip to content

Commit cf4f3b9

Browse files
Re-enable model tests with recipes for xnnpack backend (#13519)
1 parent 6a875f9 commit cf4f3b9

File tree

2 files changed

+100
-31
lines changed

2 files changed

+100
-31
lines changed

backends/xnnpack/test/TARGETS

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,10 @@ runtime.python_test(
100100
srcs = glob([
101101
"recipes/*.py",
102102
]),
103+
env = {
104+
"HTTP_PROXY": "http://fwdproxy:8080",
105+
"HTTPS_PROXY": "http://fwdproxy:8080",
106+
},
103107
deps = [
104108
"//executorch/backends/xnnpack:xnnpack_delegate",
105109
"//executorch/export:lib",

backends/xnnpack/test/recipes/test_xnnpack_recipes.py

Lines changed: 96 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@
66

77
# pyre-strict
88

9+
import logging
10+
import os
911
import unittest
12+
from typing import List, Optional, Tuple
1013

1114
import torch
1215
from executorch.backends.xnnpack.recipes.xnnpack_recipe_provider import (
@@ -18,8 +21,15 @@
1821
from executorch.examples.models.model_factory import EagerModelFactory
1922
from executorch.examples.xnnpack import MODEL_NAME_TO_OPTIONS, QuantType
2023
from executorch.exir.schema import DelegateCall, Program
21-
from executorch.export import export, ExportRecipe, recipe_registry, StageType
22-
from torch import nn
24+
from executorch.export import (
25+
export,
26+
ExportRecipe,
27+
ExportSession,
28+
recipe_registry,
29+
StageType,
30+
)
31+
from torch import nn, Tensor
32+
from torch.testing import FileCheck
2333
from torch.testing._internal.common_quantization import TestHelperModules
2434
from torchao.quantization.utils import compute_error
2535

@@ -39,9 +49,12 @@ def check_fully_delegated(self, program: Program) -> None:
3949
self.assertEqual(len(instructions), 1)
4050
self.assertIsInstance(instructions[0].instr_args, DelegateCall)
4151

42-
# pyre-ignore
4352
def _compare_eager_quantized_model_outputs(
44-
self, session, example_inputs, atol: float
53+
self,
54+
# pyre-ignore[11]
55+
session: ExportSession,
56+
example_inputs: List[Tuple[Tensor]],
57+
atol: float,
4558
) -> None:
4659
"""Utility to compare eager quantized model output with session output after xnnpack lowering"""
4760
torch_export_stage_output = session.get_stage_artifacts()[
@@ -53,8 +66,12 @@ def _compare_eager_quantized_model_outputs(
5366
Tester._assert_outputs_equal(output, expected, atol=atol)
5467

5568
def _compare_eager_unquantized_model_outputs(
56-
self, session, eager_unquantized_model, example_inputs, sqnr_threshold=20
57-
):
69+
self,
70+
session: ExportSession,
71+
eager_unquantized_model: nn.Module,
72+
example_inputs: List[Tuple[Tensor]],
73+
sqnr_threshold: int = 20,
74+
) -> None:
5875
"""Utility to compare eager unquantized model output with session output using SQNR"""
5976
quantized_output = session.run_method("forward", example_inputs[0])[0]
6077
original_output = eager_unquantized_model(*example_inputs[0])
@@ -163,12 +180,15 @@ def _get_recipe_for_quant_type(self, quant_type: QuantType) -> XNNPackRecipeType
163180
return XNNPackRecipeType.PT2E_INT8_DYNAMIC_PER_CHANNEL
164181
elif quant_type == QuantType.STATIC_PER_TENSOR:
165182
return XNNPackRecipeType.PT2E_INT8_STATIC_PER_TENSOR
166-
elif quant_type == QuantType.NONE:
167-
return XNNPackRecipeType.FP32
168-
else:
169-
raise ValueError(f"Unsupported QuantType: {quant_type}")
183+
return XNNPackRecipeType.FP32
170184

171-
def _test_model_with_factory(self, model_name: str) -> None:
185+
def _test_model_with_factory(
186+
self,
187+
model_name: str,
188+
tolerance: Optional[float] = None,
189+
sqnr_threshold: Optional[float] = None,
190+
) -> None:
191+
logging.info(f"Testing model {model_name}")
172192
if model_name not in MODEL_NAME_TO_MODEL:
173193
self.skipTest(f"Model {model_name} not found in MODEL_NAME_TO_MODEL")
174194
return
@@ -195,31 +215,76 @@ def _test_model_with_factory(self, model_name: str) -> None:
195215
dynamic_shapes=dynamic_shapes,
196216
)
197217

198-
# Verify outputs match
199-
Tester._assert_outputs_equal(
200-
session.run_method("forward", example_inputs)[0],
201-
model(*example_inputs),
202-
atol=1e-3,
218+
all_artifacts = session.get_stage_artifacts()
219+
quantized_model = all_artifacts[StageType.QUANTIZE].data["forward"]
220+
221+
edge_program_manager = all_artifacts[StageType.TO_EDGE_TRANSFORM_AND_LOWER].data
222+
lowered_module = edge_program_manager.exported_program().module()
223+
224+
# Check if model got lowered to xnnpack backend
225+
FileCheck().check("torch.ops.higher_order.executorch_call_delegate").run(
226+
lowered_module.code
203227
)
204228

205-
@unittest.skip("T187799178: Debugging Numerical Issues with Calibration")
229+
if tolerance is not None:
230+
quantized_output = quantized_model(*example_inputs)
231+
lowered_output = lowered_module(*example_inputs)
232+
if model_name == "dl3":
233+
quantized_output = quantized_output["out"]
234+
lowered_output = lowered_output["out"]
235+
236+
# lowering error
237+
try:
238+
Tester._assert_outputs_equal(
239+
lowered_output, quantized_output, atol=tolerance, rtol=tolerance
240+
)
241+
except AssertionError as e:
242+
raise AssertionError(
243+
f"Model '{model_name}' lowering error check failed with tolerance {tolerance}"
244+
) from e
245+
logging.info(
246+
f"{self._testMethodName} - {model_name} - lowering error passed"
247+
)
248+
249+
# verify sqnr between eager model and quantized model
250+
if sqnr_threshold is not None:
251+
original_output = model(*example_inputs)
252+
quantized_output = quantized_model(*example_inputs)
253+
# lowered_output = lowered_module(*example_inputs)
254+
if model_name == "dl3":
255+
original_output = original_output["out"]
256+
quantized_output = quantized_output["out"]
257+
error = compute_error(original_output, quantized_output)
258+
logging.info(f"{self._testMethodName} - {model_name} - SQNR: {error} dB")
259+
self.assertTrue(
260+
error > sqnr_threshold, f"Model '{model_name}' SQNR check failed"
261+
)
262+
206263
def test_all_models_with_recipes(self) -> None:
207264
models_to_test = [
208-
"linear",
209-
"add",
210-
"add_mul",
211-
"ic3",
212-
"mv2",
213-
"mv3",
214-
"resnet18",
215-
"resnet50",
216-
"vit",
217-
"w2l",
218-
"llama2",
265+
# Tuple format: (model_name, error tolerance, minimum sqnr)
266+
("linear", 1e-3, 20),
267+
("add", 1e-3, 20),
268+
("add_mul", 1e-3, 20),
269+
("dl3", 1e-3, 20),
270+
("ic3", None, None),
271+
("ic4", 1e-3, 20),
272+
("mv2", 1e-3, None),
273+
("mv3", 1e-3, None),
274+
("resnet18", 1e-3, 20),
275+
("resnet50", 1e-3, 20),
276+
("vit", 1e-1, 10),
277+
("w2l", 1e-3, 20),
219278
]
220-
for model_name in models_to_test:
221-
with self.subTest(model=model_name):
222-
self._test_model_with_factory(model_name)
279+
try:
280+
for model_name, tolerance, sqnr in models_to_test:
281+
with self.subTest(model=model_name):
282+
with torch.no_grad():
283+
self._test_model_with_factory(model_name, tolerance, sqnr)
284+
finally:
285+
# Clean up dog.jpg file if it exists
286+
if os.path.exists("dog.jpg"):
287+
os.remove("dog.jpg")
223288

224289
def test_validate_recipe_kwargs_fp32(self) -> None:
225290
provider = XNNPACKRecipeProvider()

0 commit comments

Comments
 (0)