diff --git a/backends/xnnpack/recipes/xnnpack_recipe_provider.py b/backends/xnnpack/recipes/xnnpack_recipe_provider.py
index 8fba58c12c3..5511019fa2b 100644
--- a/backends/xnnpack/recipes/xnnpack_recipe_provider.py
+++ b/backends/xnnpack/recipes/xnnpack_recipe_provider.py
@@ -25,6 +25,7 @@
     get_xnnpack_executorch_backend_config,
 )
 from executorch.export import (
+    AOQuantizationConfig,
     BackendRecipeProvider,
     ExportRecipe,
     LoweringRecipe,
@@ -144,14 +145,16 @@ def _build_int8da_intx_weight_recipe(
         else:
             weight_granularity = PerGroup(group_size=group_size)

-        config = Int8DynamicActivationIntxWeightConfig(
-            weight_dtype=weight_dtype,
-            weight_granularity=weight_granularity,
+        config = AOQuantizationConfig(
+            Int8DynamicActivationIntxWeightConfig(
+                weight_dtype=weight_dtype,
+                weight_granularity=weight_granularity,
+            )
         )

         quant_recipe = QuantizationRecipe(
             quantizers=None,
-            ao_base_config=[config],
+            ao_quantization_configs=[config],
         )

         return ExportRecipe(
diff --git a/backends/xnnpack/test/recipes/test_xnnpack_recipes.py b/backends/xnnpack/test/recipes/test_xnnpack_recipes.py
index 679743e42d3..95c30801afe 100644
--- a/backends/xnnpack/test/recipes/test_xnnpack_recipes.py
+++ b/backends/xnnpack/test/recipes/test_xnnpack_recipes.py
@@ -19,6 +19,7 @@
 from executorch.examples.xnnpack import MODEL_NAME_TO_OPTIONS, QuantType
 from executorch.exir.schema import DelegateCall, Program
 from executorch.export import export, ExportRecipe, recipe_registry
+from export.types import StageType
 from torch import nn
 from torch.testing._internal.common_quantization import TestHelperModules
@@ -38,6 +39,19 @@ def check_fully_delegated(self, program: Program) -> None:
         self.assertEqual(len(instructions), 1)
         self.assertIsInstance(instructions[0].instr_args, DelegateCall)

+    # pyre-ignore
+    def _compare_eager_quantized_model_outputs(
+        self, session, example_inputs, atol: float
+    ) -> None:
+        """Utility to compare eager quantized model output with session output after xnnpack lowering"""
+        source_transform_output = session.get_stage_artifacts()[
+            StageType.SOURCE_TRANSFORM
+        ]
+        eager_quantized_model = source_transform_output.data["forward"]
+        output = session.run_method("forward", example_inputs[0])[0]
+        expected = eager_quantized_model(*example_inputs[0])
+        self.assertTrue(torch.allclose(output, expected, atol=atol))
+
     def test_basic_recipe(self) -> None:
         m_eager = TestHelperModules.TwoLinearModule().eval()
         example_inputs = [(torch.randn(9, 8),)]
@@ -46,13 +60,7 @@ def test_basic_recipe(self) -> None:
             example_inputs=example_inputs,
             export_recipe=ExportRecipe.get_recipe(XNNPackRecipeType.FP32),
         )
-        self.assertTrue(
-            torch.allclose(
-                session.run_method("forward", example_inputs[0])[0],
-                m_eager(*example_inputs[0]),
-                atol=1e-3,
-            )
-        )
+        self._compare_eager_quantized_model_outputs(session, example_inputs, 1e-3)
         self.check_fully_delegated(session.get_executorch_program())

     def test_int8_dynamic_quant_recipe(self) -> None:
@@ -70,12 +78,8 @@ def test_int8_dynamic_quant_recipe(self) -> None:
                 example_inputs=example_inputs,
                 export_recipe=export_recipe,
             )
-            self.assertTrue(
-                torch.allclose(
-                    session.run_method("forward", example_inputs[0])[0],
-                    m_eager(*example_inputs[0]),
-                    atol=1e-1,
-                )
+            self._compare_eager_quantized_model_outputs(
+                session, example_inputs, 1e-1
             )
             self.check_fully_delegated(session.get_executorch_program())

@@ -95,12 +99,8 @@ def test_int8_static_quant_recipe(self) -> None:
                 example_inputs=example_inputs,
                 export_recipe=export_recipe,
             )
-            self.assertTrue(
-                torch.allclose(
session.run_method("forward", example_inputs[0])[0], - m_eager(*example_inputs[0]), - atol=1e-1, - ) + self._compare_eager_quantized_model_outputs( + session, example_inputs, 1e-1 ) self.check_fully_delegated(session.get_executorch_program()) @@ -133,14 +133,10 @@ def forward(self, x) -> torch.Tensor: example_inputs=example_inputs, export_recipe=export_recipe, ) - self.assertTrue( - torch.allclose( - session.run_method("forward", example_inputs[0])[0], - model(*example_inputs[0]), - atol=1e-2, - ) - ) self.check_fully_delegated(session.get_executorch_program()) + self._compare_eager_quantized_model_outputs( + session, example_inputs, 1e-2 + ) def _get_recipe_for_quant_type(self, quant_type: QuantType) -> XNNPackRecipeType: # Map QuantType to corresponding recipe name. diff --git a/export/__init__.py b/export/__init__.py index d5f3826ab90..a7b165185de 100644 --- a/export/__init__.py +++ b/export/__init__.py @@ -15,12 +15,19 @@ """ from .export import export, ExportSession -from .recipe import ExportRecipe, LoweringRecipe, QuantizationRecipe, RecipeType +from .recipe import ( + AOQuantizationConfig, + ExportRecipe, + LoweringRecipe, + QuantizationRecipe, + RecipeType, +) from .recipe_provider import BackendRecipeProvider from .recipe_registry import recipe_registry from .types import StageType __all__ = [ + "AOQuantizationConfig", "StageType", "ExportRecipe", "LoweringRecipe", diff --git a/export/recipe.py b/export/recipe.py index 8f7251cd419..086d57f3e38 100644 --- a/export/recipe.py +++ b/export/recipe.py @@ -6,7 +6,9 @@ from abc import ABCMeta, abstractmethod from dataclasses import dataclass from enum import Enum, EnumMeta -from typing import List, Optional, Sequence +from typing import Callable, List, Optional, Sequence + +import torch from executorch.exir._warnings import experimental @@ -64,6 +66,20 @@ class Mode(str, Enum): RELEASE = "release" +@dataclass +class AOQuantizationConfig: + """ + Configuration for torchao quantization with optional filter function. + + Attributes: + ao_base_config: The AOBaseConfig for quantization + filter_fn: Optional filter function to selectively apply quantization + """ + + ao_base_config: AOBaseConfig + filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None + + @dataclass class QuantizationRecipe: """ @@ -73,11 +89,12 @@ class QuantizationRecipe: Attributes: quantizers: Optional list of quantizers for model quantization - ao_base_config: Optional list of AO base configurations + ao_quantization_configs: Optional list of AOQuantizationConfig objects that pair + AOBaseConfig with optional filter functions """ quantizers: Optional[List[Quantizer]] = None - ao_base_config: Optional[List[AOBaseConfig]] = None + ao_quantization_configs: Optional[List[AOQuantizationConfig]] = None def get_quantizers(self) -> Optional[List[Quantizer]]: """ diff --git a/export/stages.py b/export/stages.py index dd22155e929..6c29f3f1c96 100644 --- a/export/stages.py +++ b/export/stages.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+import copy
 import logging
 from abc import ABC, abstractmethod
 from typing import Any, Callable, Dict, List, Optional, Sequence
@@ -20,7 +21,6 @@
 from torch._export.pass_base import PassType
 from torchao.quantization import quantize_
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
-from torchao.quantization.pt2e.quantizer import ComposableQuantizer
 from torchao.utils import unwrap_tensor_subclass


@@ -287,7 +287,7 @@ def run(self, artifact: PipelineArtifact) -> None:
         """
         if (
             not self._quantization_recipe
-            or not self._quantization_recipe.ao_base_config
+            or not self._quantization_recipe.ao_quantization_configs
         ):
             logging.info(
                 "Quantization recipe is invalid to run SourceTransform, returning original artifact"
@@ -298,15 +298,14 @@ def run(self, artifact: PipelineArtifact) -> None:
         assert isinstance(artifact.data, dict)

         # Store the original models
-        self._transformed_models = artifact.data
+        self._transformed_models = copy.deepcopy(artifact.data)

         # Apply torchao quantize_ to each model
-        for method_name, model in artifact.data.items():
+        for _, model in artifact.data.items():
             # pyre-ignore
-            for config in self._quantization_recipe.ao_base_config:
-                quantize_(model, config)
+            for ao_config in self._quantization_recipe.ao_quantization_configs:
+                quantize_(model, ao_config.ao_base_config, ao_config.filter_fn)
             unwrap_tensor_subclass(model)
-            self._transformed_models[method_name] = model

         self._artifact = artifact.copy_with_new_data(self._transformed_models)

@@ -331,6 +330,38 @@ def valid_predecessor_stages(self) -> List["StageType"]:
     def can_start_pipeline(self) -> bool:
         return True

+    def _get_quantizer_for_prepare_pt2e(self, quantizers: List[Any]):
+        torch_ao_quantizers = []
+        torchao_pt2e_quantizers = []
+
+        for quantizer in quantizers:
+            from torchao.quantization.pt2e.quantizer import (
+                Quantizer as TorchAOPT2EQuantizer,
+            )
+
+            if isinstance(quantizer, TorchAOPT2EQuantizer):
+                torchao_pt2e_quantizers.append(quantizer)
+            else:
+                torch_ao_quantizers.append(quantizer)
+
+        if torch_ao_quantizers and torchao_pt2e_quantizers:
+            raise ValueError("Mixed quantizer types are not supported")
+        if len(torch_ao_quantizers) > 1:
+            raise ValueError(
+                "Multiple quantizers of torch.ao.quantization.quantizer not supported"
+            )
+
+        if torch_ao_quantizers:
+            # prepare_pt2e has backward compat with torch.ao quantizer
+            return torch_ao_quantizers[0]
+        elif torchao_pt2e_quantizers:
+            # Multiple torchao quantizers - use ComposableQuantizer
+            from torchao.quantization.pt2e.quantizer import ComposableQuantizer
+
+            return ComposableQuantizer(torchao_pt2e_quantizers)
+        else:
+            raise ValueError("No quantizers detected")
+
     def run(self, artifact: PipelineArtifact) -> None:
         if not self._quantization_recipe or not self._quantization_recipe.quantizers:
             logging.info(
@@ -355,11 +386,10 @@ def run(self, artifact: PipelineArtifact) -> None:
             inputs = example_inputs[method_name][0]
             captured_graph = torch.export.export(model, inputs, strict=True).module()

-            composed_quantizer = ComposableQuantizer(
-                # pyre-ignore
+            quantizer = self._get_quantizer_for_prepare_pt2e(
                 self._quantization_recipe.quantizers
             )
-            prepared_model = prepare_pt2e(captured_graph, composed_quantizer)
+            prepared_model = prepare_pt2e(captured_graph, quantizer)

             for calibration_input in example_inputs[method_name]:
                 prepared_model(*calibration_input)
diff --git a/export/tests/test_export_session.py b/export/tests/test_export_session.py
index 92aeebb7304..8fab708f8d5 100644
--- a/export/tests/test_export_session.py
+++ b/export/tests/test_export_session.py
@@ -12,7 +12,11 @@
 import torch

 from executorch.export import ExportRecipe, ExportSession
-from executorch.export.recipe import LoweringRecipe, QuantizationRecipe
+from executorch.export.recipe import (
+    AOQuantizationConfig,
+    LoweringRecipe,
+    QuantizationRecipe,
+)
 from executorch.export.stages import PipelineArtifact
 from executorch.export.types import StageType

@@ -20,7 +24,7 @@ class SimpleTestModel(torch.nn.Module):
     def __init__(self) -> None:
         super().__init__()
-        self.linear = torch.nn.Linear(10, 5)
+        self.linear: torch.nn.Module = torch.nn.Linear(10, 5)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.linear(x)
@@ -448,7 +452,7 @@ def test_pipeline_building_with_all_recipes(self) -> None:
         """Test pipeline building with quantization and lowering recipes."""
         # Create comprehensive recipes
         quant_recipe = QuantizationRecipe(
-            ao_base_config=[Mock()],
+            ao_quantization_configs=[AOQuantizationConfig(Mock())],
             quantizers=[Mock()],
         )
         lowering_recipe = LoweringRecipe(
diff --git a/export/tests/test_export_stages.py b/export/tests/test_export_stages.py
index 2b3e533723a..0915de24499 100644
--- a/export/tests/test_export_stages.py
+++ b/export/tests/test_export_stages.py
@@ -11,7 +11,7 @@
 import torch
 from executorch.exir.program import EdgeProgramManager, ExecutorchProgramManager
-from executorch.export import QuantizationRecipe
+from executorch.export import AOQuantizationConfig, QuantizationRecipe
 from executorch.export.stages import (
     EdgeTransformAndLowerStage,
     ExecutorchStage,
@@ -29,7 +29,7 @@ class SimpleTestModel(torch.nn.Module):
     def __init__(self) -> None:
         super().__init__()
-        self.linear = torch.nn.Linear(10, 5)
+        self.linear: torch.nn.Module = torch.nn.Linear(10, 5)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.linear(x)
@@ -163,7 +163,7 @@ def setUp(self) -> None:

     def test_source_transform_stage_no_quantization(self) -> None:
         mock_recipe = Mock(spec=QuantizationRecipe)
-        mock_recipe.ao_base_config = None
+        mock_recipe.ao_quantization_configs = None
         stage = SourceTransformStage(mock_recipe)

         artifact = PipelineArtifact(data=self.models_dict, context={})
@@ -174,12 +174,19 @@
     @patch("executorch.export.stages.quantize_")
     @patch("executorch.export.stages.unwrap_tensor_subclass")
-    def test_run_with_ao_base_config(
+    def test_run_with_ao_quantization_configs(
         self, mock_unwrap: Mock, mock_quantize: Mock
     ) -> None:
-        mock_config = Mock()
+        from torchao.core.config import AOBaseConfig
+
+        mock_config = Mock(spec=AOBaseConfig)
+        mock_filter_fn = Mock()
+        # pyre-ignore[28]: Unexpected keyword argument error is a false positive for dataclass
+        mock_ao_config: AOQuantizationConfig = AOQuantizationConfig(
+            ao_base_config=mock_config, filter_fn=mock_filter_fn
+        )
         mock_recipe = Mock(spec=QuantizationRecipe)
-        mock_recipe.ao_base_config = [mock_config]
+        mock_recipe.ao_quantization_configs = [mock_ao_config]

         stage = SourceTransformStage(mock_recipe)
@@ -188,7 +195,7 @@ def test_run_with_ao_base_config(
         stage.run(artifact)

         # Verify quantize_ was called with the model and config
-        mock_quantize.assert_called_once_with(self.model, mock_config)
+        mock_quantize.assert_called_once_with(self.model, mock_config, mock_filter_fn)

         # Verify unwrap_tensor_subclass was called with the model
         mock_unwrap.assert_called_once_with(self.model)
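For reference, a minimal usage sketch of the AOQuantizationConfig / ao_quantization_configs API introduced above. The filter function, the specific torchao config arguments, and the torchao import paths below are illustrative assumptions, not taken from this diff; only AOQuantizationConfig, QuantizationRecipe, and their fields come from the change itself.

import torch

from executorch.export import AOQuantizationConfig, QuantizationRecipe
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import Int8DynamicActivationIntxWeightConfig


# Hypothetical filter: only quantize nn.Linear submodules. The callable is
# forwarded unchanged to torchao's quantize_(model, config, filter_fn).
def linear_only(module: torch.nn.Module, fqn: str) -> bool:
    return isinstance(module, torch.nn.Linear)


# Pair an AOBaseConfig with the optional filter function (dtype and group size
# are illustrative and depend on the installed torchao version).
ao_config = AOQuantizationConfig(
    ao_base_config=Int8DynamicActivationIntxWeightConfig(
        weight_dtype=torch.int4,
        weight_granularity=PerGroup(group_size=32),
    ),
    filter_fn=linear_only,
)

# QuantizationRecipe now carries ao_quantization_configs instead of
# ao_base_config; SourceTransformStage applies each entry with quantize_.
quant_recipe = QuantizationRecipe(
    quantizers=None,
    ao_quantization_configs=[ao_config],
)
# quant_recipe would then be attached to an ExportRecipe, as the XNNPACK
# recipe provider does in this diff.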