Add TorchAO wrapper config to allow filter_fn for quantize_ #13264
@@ -19,6 +19,7 @@
 from executorch.examples.xnnpack import MODEL_NAME_TO_OPTIONS, QuantType
 from executorch.exir.schema import DelegateCall, Program
 from executorch.export import export, ExportRecipe, recipe_registry
+from export.types import StageType
 from torch import nn
 from torch.testing._internal.common_quantization import TestHelperModules
@@ -38,6 +39,19 @@ def check_fully_delegated(self, program: Program) -> None:
         self.assertEqual(len(instructions), 1)
         self.assertIsInstance(instructions[0].instr_args, DelegateCall)

+    # pyre-ignore
+    def _compare_eager_quantized_model_outputs(
+        self, session, example_inputs, atol: float
+    ) -> None:
+        """Utility to compare eager quantized model output with session output after lowering"""
+        source_transform_output = session.get_stage_artifacts()[
+            StageType.SOURCE_TRANSFORM
+        ]
+        eager_quantized_model = source_transform_output.data["forward"]
+        output = session.run_method("forward", example_inputs[0])[0]
+        expected = eager_quantized_model(*example_inputs[0])
+        self.assertTrue(torch.allclose(output, expected, atol=atol))
Reviewer comment: You might want to print more stats if this fails - see https://github.com/pytorch/executorch/blob/main/backends/test/harness/tester.py#L337
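A minimal sketch of the kind of stats printout the reviewer suggests, assuming a hand-rolled helper (the helper name and chosen statistics are illustrative, not part of this PR or of the linked tester):

```python
import torch


def format_output_diff_stats(actual: torch.Tensor, expected: torch.Tensor) -> str:
    """Summarize how far two output tensors diverge, for a failing allclose check."""
    error = (actual - expected).abs()
    return (
        f"max abs error={error.max().item():.6f}, "
        f"mean abs error={error.mean().item():.6f}, "
        f"mismatched elements={(error > 1e-3).sum().item()}/{error.numel()}"
    )


# Usage inside the comparison helper, so CI logs carry the stats on failure:
# self.assertTrue(
#     torch.allclose(output, expected, atol=atol),
#     msg=format_output_diff_stats(output, expected),
# )
```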

     def test_basic_recipe(self) -> None:
         m_eager = TestHelperModules.TwoLinearModule().eval()
         example_inputs = [(torch.randn(9, 8),)]
@@ -46,13 +60,7 @@ def test_basic_recipe(self) -> None:
             example_inputs=example_inputs,
             export_recipe=ExportRecipe.get_recipe(XNNPackRecipeType.FP32),
         )
-        self.assertTrue(
-            torch.allclose(
-                session.run_method("forward", example_inputs[0])[0],
-                m_eager(*example_inputs[0]),
-                atol=1e-3,
-            )
-        )
+        self._compare_eager_quantized_model_outputs(session, example_inputs, 1e-3)
         self.check_fully_delegated(session.get_executorch_program())

     def test_int8_dynamic_quant_recipe(self) -> None:
@@ -70,12 +78,8 @@ def test_int8_dynamic_quant_recipe(self) -> None:
             example_inputs=example_inputs,
             export_recipe=export_recipe,
         )
-        self.assertTrue(
-            torch.allclose(
-                session.run_method("forward", example_inputs[0])[0],
-                m_eager(*example_inputs[0]),
-                atol=1e-1,
-            )
+        self._compare_eager_quantized_model_outputs(
+            session, example_inputs, 1e-1
Reviewer comment: atol? Why is this so high for two linears?

Author reply: Yeah, I think 1e-2 works on my Mac; I'll check whether Linux passes on CI. Regardless, I'm updating the tolerance tests, similar to CoreML (let me know if there is any objection), to use SQNR for comparing the eager model output with the lowered model output, and to keep tolerance checks for comparing the post-quantization model with the lowered model.
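A rough sketch of such an SQNR comparison, assuming a hand-rolled helper and a placeholder 30 dB threshold (the threshold and helper actually used in the final revision are not shown here):

```python
import torch


def sqnr_db(reference: torch.Tensor, candidate: torch.Tensor) -> float:
    """Signal-to-quantization-noise ratio in decibels; higher means closer outputs."""
    noise = (reference - candidate).norm().clamp_min(1e-12)
    return (20 * torch.log10(reference.norm() / noise)).item()


# Inside the test, an SQNR check against the eager (unquantized) model could
# replace the tight atol-based allclose:
# expected = m_eager(*example_inputs[0])
# actual = session.run_method("forward", example_inputs[0])[0]
# self.assertGreaterEqual(sqnr_db(expected, actual), 30.0)
```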
         )
         self.check_fully_delegated(session.get_executorch_program())

@@ -95,12 +99,8 @@ def test_int8_static_quant_recipe(self) -> None:
             example_inputs=example_inputs,
             export_recipe=export_recipe,
         )
-        self.assertTrue(
-            torch.allclose(
-                session.run_method("forward", example_inputs[0])[0],
-                m_eager(*example_inputs[0]),
-                atol=1e-1,
-            )
+        self._compare_eager_quantized_model_outputs(
+            session, example_inputs, 1e-1
         )
         self.check_fully_delegated(session.get_executorch_program())

@@ -133,14 +133,10 @@ def forward(self, x) -> torch.Tensor:
             example_inputs=example_inputs,
             export_recipe=export_recipe,
         )
-        self.assertTrue(
-            torch.allclose(
-                session.run_method("forward", example_inputs[0])[0],
-                model(*example_inputs[0]),
-                atol=1e-2,
-            )
-        )
         self.check_fully_delegated(session.get_executorch_program())
+        self._compare_eager_quantized_model_outputs(
+            session, example_inputs, 1e-2
+        )

     def _get_recipe_for_quant_type(self, quant_type: QuantType) -> XNNPackRecipeType:
         # Map QuantType to corresponding recipe name.
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+import copy
 import logging
 from abc import ABC, abstractmethod
 from typing import Any, Callable, Dict, List, Optional, Sequence

@@ -20,7 +21,6 @@
 from torch._export.pass_base import PassType
 from torchao.quantization import quantize_
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
-from torchao.quantization.pt2e.quantizer import ComposableQuantizer
 from torchao.utils import unwrap_tensor_subclass

@@ -287,7 +287,7 @@ def run(self, artifact: PipelineArtifact) -> None:
         """
         if (
             not self._quantization_recipe
-            or not self._quantization_recipe.ao_base_config
+            or not self._quantization_recipe.ao_quantization_configs
         ):
             logging.info(
                 "Quantization recipe is invalid to run SourceTransform, returning original artifact"
@@ -298,15 +298,14 @@ def run(self, artifact: PipelineArtifact) -> None:
         assert isinstance(artifact.data, dict)

         # Store the original models
-        self._transformed_models = artifact.data
+        self._transformed_models = copy.deepcopy(artifact.data)

         # Apply torchao quantize_ to each model
-        for method_name, model in artifact.data.items():
+        for _, model in artifact.data.items():
             # pyre-ignore
-            for config in self._quantization_recipe.ao_base_config:
-                quantize_(model, config)
+            for ao_config in self._quantization_recipe.ao_quantization_configs:
+                quantize_(model, ao_config.ao_base_config, ao_config.filter_fn)
                 unwrap_tensor_subclass(model)
-            self._transformed_models[method_name] = model

         self._artifact = artifact.copy_with_new_data(self._transformed_models)
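For context, a hedged sketch of what the new per-config filter_fn enables: quantize_ only touches the submodules the callback selects. The wrapper config's class name and import path are not shown in this hunk (only its ao_base_config and filter_fn fields are), and Int8DynamicActivationInt8WeightConfig is assumed to be available in the installed torchao.

```python
import torch
from torchao.quantization import Int8DynamicActivationInt8WeightConfig, quantize_

model = torch.nn.Sequential(
    torch.nn.Linear(8, 16),
    torch.nn.ReLU(),
    torch.nn.Linear(16, 4),
)


def only_linear(module: torch.nn.Module, fqn: str) -> bool:
    # Restrict quantization to nn.Linear layers; everything else is left untouched.
    return isinstance(module, torch.nn.Linear)


# Mirrors what the stage above now does for each configured entry:
#   quantize_(model, ao_config.ao_base_config, ao_config.filter_fn)
quantize_(model, Int8DynamicActivationInt8WeightConfig(), filter_fn=only_linear)
```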
@@ -331,6 +330,38 @@ def valid_predecessor_stages(self) -> List["StageType"]:
     def can_start_pipeline(self) -> bool:
         return True

+    def _get_quantizer_for_prepare_pt2e(self, quantizers: List[Any]):
+        torch_ao_quantizers = []
+        torchao_pt2e_quantizers = []
+
+        for quantizer in quantizers:
+            from torchao.quantization.pt2e.quantizer import (
+                Quantizer as TorchAOPT2EQuantizer,
+            )
+
+            if isinstance(quantizer, TorchAOPT2EQuantizer):
+                torchao_pt2e_quantizers.append(quantizer)
+            else:
+                torch_ao_quantizers.append(quantizer)
+
+        if torch_ao_quantizers and torchao_pt2e_quantizers:
+            raise ValueError("Mixed quantizer types are not supported")
+        if len(torch_ao_quantizers) > 1:
+            raise ValueError(
+                "Multiple quantizers of torch.ao.quantization.quantizer not supported"
Reviewer comment: Doesn't torchao already detect this and give an error if mixing? I thought I added that.

Author reply: Maybe the torchao version is different?
+            )
+
+        if torch_ao_quantizers:
+            # prepare_pt2e has backward compat with torch.ao quantizer
+            return torch_ao_quantizers[0]
+        elif torchao_pt2e_quantizers:
+            # Multiple torchao quantizers - use ComposableQuantizer
+            from torchao.quantization.pt2e.quantizer import ComposableQuantizer
+
+            return ComposableQuantizer(torchao_pt2e_quantizers)
+        else:
+            raise ValueError("No quantizers detected")

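As a toy illustration of the dispatch rules above, a list made up entirely of torchao pt2e quantizers is wrapped in a single ComposableQuantizer before prepare_pt2e; the Quantizer subclass below is a placeholder, not a quantizer shipped with ExecuTorch:

```python
from torchao.quantization.pt2e.quantizer import ComposableQuantizer, Quantizer


class _ToyQuantizer(Quantizer):
    """Placeholder quantizer; a real one would attach quantization annotations."""

    def annotate(self, model):
        return model

    def validate(self, model) -> None:
        pass


# Two torchao pt2e quantizers -> composed into one quantizer for prepare_pt2e.
composed = ComposableQuantizer([_ToyQuantizer(), _ToyQuantizer()])

# A single torch.ao.quantization quantizer would instead be passed through as-is;
# mixing the two families, passing more than one torch.ao quantizer, or passing
# none at all raises ValueError in the stage above.
```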
     def run(self, artifact: PipelineArtifact) -> None:
         if not self._quantization_recipe or not self._quantization_recipe.quantizers:
             logging.info(

@@ -355,11 +386,10 @@ def run(self, artifact: PipelineArtifact) -> None:
             inputs = example_inputs[method_name][0]
             captured_graph = torch.export.export(model, inputs, strict=True).module()

-            composed_quantizer = ComposableQuantizer(
-                # pyre-ignore
+            quantizer = self._get_quantizer_for_prepare_pt2e(
                 self._quantization_recipe.quantizers
             )
-            prepared_model = prepare_pt2e(captured_graph, composed_quantizer)
+            prepared_model = prepare_pt2e(captured_graph, quantizer)

             for calibration_input in example_inputs[method_name]:
                 prepared_model(*calibration_input)