Add TorchAO wrapper config to allow filter_fn for quantize_ #13264
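
For context: torchao's `quantize_` entry point accepts an optional `filter_fn` callback that receives each module and its fully qualified name and decides whether the quantization config applies to it; this PR adds a wrapper config so that ExecuTorch recipes can pass such a filter through. Below is a minimal sketch of the underlying torchao call only, not of the wrapper this PR introduces; the config class and model are illustrative, and availability of the config class depends on the torchao version:

    import torch
    from torchao.quantization import quantize_, Int8DynamicActivationInt4WeightConfig

    model = torch.nn.Sequential(
        torch.nn.Linear(64, 128),
        torch.nn.ReLU(),
        torch.nn.Linear(128, 10),
    )

    # filter_fn receives (module, fully_qualified_name) and returns True for
    # modules that should be quantized -- here, only nn.Linear layers.
    quantize_(
        model,
        Int8DynamicActivationInt4WeightConfig(group_size=32),
        filter_fn=lambda m, fqn: isinstance(m, torch.nn.Linear),
    )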
@@ -19,8 +19,10 @@
 from executorch.examples.xnnpack import MODEL_NAME_TO_OPTIONS, QuantType
 from executorch.exir.schema import DelegateCall, Program
 from executorch.export import export, ExportRecipe, recipe_registry
+from export.types import StageType
 from torch import nn
 from torch.testing._internal.common_quantization import TestHelperModules
+from torchao.quantization.utils import compute_error


 class TestXnnpackRecipes(unittest.TestCase):
@@ -38,6 +40,29 @@ def check_fully_delegated(self, program: Program) -> None:
         self.assertEqual(len(instructions), 1)
         self.assertIsInstance(instructions[0].instr_args, DelegateCall)

+    # pyre-ignore
+    def _compare_eager_quantized_model_outputs(
+        self, session, example_inputs, atol: float
+    ) -> None:
+        """Utility to compare eager quantized model output with session output after xnnpack lowering"""
+        torch_export_stage_output = session.get_stage_artifacts()[
+            StageType.TORCH_EXPORT
+        ]
+        eager_quantized_model = torch_export_stage_output.data["forward"].module()
+        output = session.run_method("forward", example_inputs[0])[0]
+        expected = eager_quantized_model(*example_inputs[0])
+        Tester._assert_outputs_equal(output, expected, atol=atol)
+
+    def _compare_eager_unquantized_model_outputs(
+        self, session, eager_unquantized_model, example_inputs, sqnr_threshold=20
+    ):
+        """Utility to compare eager unquantized model output with session output using SQNR"""
+        quantized_output = session.run_method("forward", example_inputs[0])[0]
+        original_output = eager_unquantized_model(*example_inputs[0])
+        error = compute_error(original_output, quantized_output)
+        print(f"{self._testMethodName} - SQNR: {error} dB")
+        self.assertTrue(error > sqnr_threshold)
+
     def test_basic_recipe(self) -> None:
         m_eager = TestHelperModules.TwoLinearModule().eval()
         example_inputs = [(torch.randn(9, 8),)]
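
The `compute_error` used in the second helper is imported from `torchao.quantization.utils` and returns a signal-to-quantization-noise ratio (SQNR) in decibels. As a point of reference, here is a minimal sketch of the standard SQNR definition this is assumed to follow (the function name `sqnr_db` is hypothetical, not torchao's):

    import torch

    def sqnr_db(reference: torch.Tensor, candidate: torch.Tensor) -> torch.Tensor:
        # SQNR in dB: signal power over error power, on a log scale.
        # Higher is better; ~20 dB means the error norm is ~10% of the signal norm.
        signal = torch.linalg.norm(reference)
        noise = torch.linalg.norm(reference - candidate)
        return 20 * torch.log10(signal / noise)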
@@ -46,18 +71,13 @@ def test_basic_recipe(self) -> None:
             example_inputs=example_inputs,
             export_recipe=ExportRecipe.get_recipe(XNNPackRecipeType.FP32),
         )
-        self.assertTrue(
-            torch.allclose(
-                session.run_method("forward", example_inputs[0])[0],
-                m_eager(*example_inputs[0]),
-                atol=1e-3,
-            )
-        )
+        self._compare_eager_quantized_model_outputs(session, example_inputs, 1e-3)
         self.check_fully_delegated(session.get_executorch_program())
+        self._compare_eager_unquantized_model_outputs(session, m_eager, example_inputs)

     def test_int8_dynamic_quant_recipe(self) -> None:
         test_cases = [
-            ExportRecipe.get_recipe(XNNPackRecipeType.INT8_DYNAMIC_PER_CHANNEL),
+            ExportRecipe.get_recipe(XNNPackRecipeType.PT2E_INT8_DYNAMIC_PER_CHANNEL),
         ]

         for export_recipe in test_cases:
@@ -70,19 +90,18 @@ def test_int8_dynamic_quant_recipe(self) -> None:
                 example_inputs=example_inputs,
                 export_recipe=export_recipe,
             )
-            self.assertTrue(
-                torch.allclose(
-                    session.run_method("forward", example_inputs[0])[0],
-                    m_eager(*example_inputs[0]),
-                    atol=1e-1,
-                )
-            )
+            self._compare_eager_quantized_model_outputs(
+                session, example_inputs, 1e-1
+            )
             self.check_fully_delegated(session.get_executorch_program())
+            self._compare_eager_unquantized_model_outputs(
+                session, m_eager, example_inputs
+            )

     def test_int8_static_quant_recipe(self) -> None:
         test_cases = [
-            ExportRecipe.get_recipe(XNNPackRecipeType.INT8_STATIC_PER_CHANNEL),
-            ExportRecipe.get_recipe(XNNPackRecipeType.INT8_STATIC_PER_TENSOR),
+            ExportRecipe.get_recipe(XNNPackRecipeType.PT2E_INT8_STATIC_PER_CHANNEL),
+            ExportRecipe.get_recipe(XNNPackRecipeType.PT2E_INT8_STATIC_PER_TENSOR),
         ]

         for export_recipe in test_cases:

Reviewer comment on the `atol=1e-1` line above: "atol? Why is this so high for two linears?"

Author reply: "Yeah, I think 1e-2 is working on my Mac; I'll check whether Linux passes on CI. Nevertheless, I'm updating the tolerance tests, similar to CoreML (let me know if there is any objection), to use SQNR when comparing the eager model's output against the lowered model's, but to keep absolute-tolerance checks when comparing the post-quantization model against the lowered model."
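To make the exchange concrete, here is an illustrative (purely synthetic) comparison of the two metrics; the numbers are not from the PR:

    import torch

    torch.manual_seed(0)
    ref = torch.randn(9, 8)                    # stand-in for the eager model output
    noisy = ref + 0.01 * torch.randn(9, 8)     # ~1% simulated quantization error

    # Absolute-tolerance view: passes easily even at a loose atol of 1e-1.
    print(torch.allclose(ref, noisy, atol=1e-1))  # True

    # SQNR view: ~40 dB here, comfortably above the 20 dB default threshold.
    sqnr = 20 * torch.log10(torch.linalg.norm(ref) / torch.linalg.norm(ref - noisy))
    print(f"SQNR: {sqnr:.1f} dB")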
@@ -95,14 +114,13 @@ def test_int8_static_quant_recipe(self) -> None:
                 example_inputs=example_inputs,
                 export_recipe=export_recipe,
             )
-            self.assertTrue(
-                torch.allclose(
-                    session.run_method("forward", example_inputs[0])[0],
-                    m_eager(*example_inputs[0]),
-                    atol=1e-1,
-                )
-            )
+            self._compare_eager_quantized_model_outputs(
+                session, example_inputs, 1e-2
+            )
             self.check_fully_delegated(session.get_executorch_program())
+            self._compare_eager_unquantized_model_outputs(
+                session, m_eager, example_inputs
+            )

     def test_8a4w_recipe(self) -> None:
         class SimpleLinearModel(nn.Module):
@@ -116,10 +134,10 @@ def forward(self, x) -> torch.Tensor:

         test_cases = [
             ExportRecipe.get_recipe(
-                XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_CHANNEL,
+                XNNPackRecipeType.TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_CHANNEL,
             ),
             ExportRecipe.get_recipe(
-                XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR,
+                XNNPackRecipeType.TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR,
                 group_size=32,
             ),
         ]
@@ -133,23 +151,22 @@ def forward(self, x) -> torch.Tensor:
                 example_inputs=example_inputs,
                 export_recipe=export_recipe,
             )
-            self.assertTrue(
-                torch.allclose(
-                    session.run_method("forward", example_inputs[0])[0],
-                    model(*example_inputs[0]),
-                    atol=1e-2,
-                )
-            )
             self.check_fully_delegated(session.get_executorch_program())
+            self._compare_eager_quantized_model_outputs(
+                session, example_inputs, 1e-3
+            )
+            self._compare_eager_unquantized_model_outputs(
+                session, model, example_inputs, sqnr_threshold=15
+            )

     def _get_recipe_for_quant_type(self, quant_type: QuantType) -> XNNPackRecipeType:
         # Map QuantType to corresponding recipe name.
         if quant_type == QuantType.STATIC_PER_CHANNEL:
-            return XNNPackRecipeType.INT8_STATIC_PER_CHANNEL
+            return XNNPackRecipeType.PT2E_INT8_STATIC_PER_CHANNEL
         elif quant_type == QuantType.DYNAMIC_PER_CHANNEL:
-            return XNNPackRecipeType.INT8_DYNAMIC_PER_CHANNEL
+            return XNNPackRecipeType.PT2E_INT8_DYNAMIC_PER_CHANNEL
         elif quant_type == QuantType.STATIC_PER_TENSOR:
-            return XNNPackRecipeType.INT8_STATIC_PER_TENSOR
+            return XNNPackRecipeType.PT2E_INT8_STATIC_PER_TENSOR
         elif quant_type == QuantType.NONE:
             return XNNPackRecipeType.FP32
         else:
@@ -224,12 +241,13 @@ def test_validate_recipe_kwargs_int4_tensor_with_valid_group_size(

         # Should not raise any exception
         recipe_w_default_group = provider.create_recipe(
-            XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR
+            XNNPackRecipeType.TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR
         )
         self.assertIsNotNone(recipe_w_default_group)

         recipe = provider.create_recipe(
-            XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR, group_size=64
+            XNNPackRecipeType.TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR,
+            group_size=64,
         )
         self.assertIsNotNone(recipe)
@@ -240,7 +258,7 @@ def test_validate_recipe_kwargs_int4_tensor_with_invalid_group_size(

         with self.assertRaises(ValueError) as cm:
             provider.create_recipe(
-                XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR,
+                XNNPackRecipeType.TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR,
                 group_size="32",  # String instead of int
             )