diff --git a/backends/xnnpack/recipes/xnnpack_recipe_provider.py b/backends/xnnpack/recipes/xnnpack_recipe_provider.py
index 8fba58c12c3..436eb2db158 100644
--- a/backends/xnnpack/recipes/xnnpack_recipe_provider.py
+++ b/backends/xnnpack/recipes/xnnpack_recipe_provider.py
@@ -25,6 +25,7 @@
     get_xnnpack_executorch_backend_config,
 )
 from executorch.export import (
+    AOQuantizationConfig,
     BackendRecipeProvider,
     ExportRecipe,
     LoweringRecipe,
@@ -57,31 +58,37 @@ def create_recipe(
         if recipe_type == XNNPackRecipeType.FP32:
             return self._build_fp32_recipe(recipe_type)
-        elif recipe_type == XNNPackRecipeType.INT8_DYNAMIC_PER_CHANNEL:
+        elif recipe_type == XNNPackRecipeType.PT2E_INT8_DYNAMIC_PER_CHANNEL:
             return self._build_quantized_recipe(
                 recipe_type, is_per_channel=True, is_dynamic=True
             )
-        elif recipe_type == XNNPackRecipeType.INT8_STATIC_PER_CHANNEL:
+        elif recipe_type == XNNPackRecipeType.PT2E_INT8_STATIC_PER_CHANNEL:
             return self._build_quantized_recipe(
                 recipe_type, is_per_channel=True, is_dynamic=False
             )
-        elif recipe_type == XNNPackRecipeType.INT8_STATIC_PER_TENSOR:
+        elif recipe_type == XNNPackRecipeType.PT2E_INT8_STATIC_PER_TENSOR:
             return self._build_quantized_recipe(
                 recipe_type, is_per_channel=False, is_dynamic=False
             )
-        elif recipe_type == XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_CHANNEL:
-            return self._build_int8da_intx_weight_recipe(
+        elif (
+            recipe_type
+            == XNNPackRecipeType.TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_CHANNEL
+        ):
+            return self._build_torchao_quantized_recipe(
                 recipe_type=recipe_type,
                 is_per_channel=True,
                 weight_dtype=torch.int4,
             )
-        elif recipe_type == XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR:
+        elif (
+            recipe_type
+            == XNNPackRecipeType.TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR
+        ):
             group_size = kwargs.get("group_size", 32)
-            return self._build_int8da_intx_weight_recipe(
+            return self._build_torchao_quantized_recipe(
                 recipe_type=recipe_type,
                 is_per_channel=False,
                 weight_dtype=torch.int4,
@@ -132,7 +139,7 @@ def _build_quantized_recipe(
             executorch_backend_config=get_xnnpack_executorch_backend_config(),
         )

-    def _build_int8da_intx_weight_recipe(
+    def _build_torchao_quantized_recipe(
         self,
         recipe_type: RecipeType,
         is_per_channel: bool = True,
@@ -141,17 +148,21 @@ def _build_int8da_intx_weight_recipe(
     ) -> ExportRecipe:
         if is_per_channel:
             weight_granularity = PerAxis(axis=0)
+            assert weight_dtype == torch.int4 or weight_dtype == torch.int8
         else:
             weight_granularity = PerGroup(group_size=group_size)
+            assert weight_dtype == torch.int4

-        config = Int8DynamicActivationIntxWeightConfig(
-            weight_dtype=weight_dtype,
-            weight_granularity=weight_granularity,
+        config = AOQuantizationConfig(
+            Int8DynamicActivationIntxWeightConfig(
+                weight_dtype=weight_dtype,
+                weight_granularity=weight_granularity,
+            )
         )

         quant_recipe = QuantizationRecipe(
             quantizers=None,
-            ao_base_config=[config],
+            ao_quantization_configs=[config],
         )

         return ExportRecipe(
@@ -162,7 +173,10 @@ def _build_int8da_intx_weight_recipe(
         )

     def _validate_recipe_kwargs(self, recipe_type: RecipeType, **kwargs: Any) -> None:
-        if recipe_type == XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR:
+        if (
+            recipe_type
+            == XNNPackRecipeType.TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR
+        ):
             expected_keys = {"group_size"}
             unexpected = set(kwargs.keys()) - expected_keys
             if unexpected:
diff --git a/backends/xnnpack/recipes/xnnpack_recipe_types.py b/backends/xnnpack/recipes/xnnpack_recipe_types.py
index 5675c3a5ffa..61117b94502 100644
--- a/backends/xnnpack/recipes/xnnpack_recipe_types.py
+++ b/backends/xnnpack/recipes/xnnpack_recipe_types.py
@@ -13,19 +13,22 @@ class XNNPackRecipeType(RecipeType):
     """XNNPACK-specific recipe types"""

     FP32 = "fp32"
+
+    ## PT2E-based quantization recipes
     # INT8 Dynamic Quantization
-    INT8_DYNAMIC_PER_CHANNEL = "int8_dynamic_per_channel"
+    PT2E_INT8_DYNAMIC_PER_CHANNEL = "pt2e_int8_dynamic_per_channel"
+    # INT8 Static Quantization, needs calibration dataset
+    PT2E_INT8_STATIC_PER_CHANNEL = "pt2e_int8_static_per_channel"
+    PT2E_INT8_STATIC_PER_TENSOR = "pt2e_int8_static_per_tensor"
+
+    ## TorchAO-based quantization recipes
     # INT8 Dynamic Activations INT4 Weight Quantization, Axis = 0
-    INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_CHANNEL = "int8da_int4w_per_channel"
+    TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_CHANNEL = (
+        "torchao_int8da_int4w_per_channel"
+    )
     # INT8 Dynamic Activations INT4 Weight Quantization, default group_size = 32
     # can be overriden by group_size kwarg
-    INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR = "int8da_int4w_per_tensor"
-    # INT8 Static Activations INT4 Weight Quantization
-    INT8_STATIC_ACT_INT4_WEIGHT_PER_CHANNEL = "int8a_int4w_per_channel"
-    INT8_STATIC_ACT_INT4_WEIGHT_PER_TENSOR = "int8a_int44w_per_tensor"
-    # INT8 Static Quantization, needs calibration dataset
-    INT8_STATIC_PER_CHANNEL = "int8_static_per_channel"
-    INT8_STATIC_PER_TENSOR = "int8_static_per_tensor"
+    TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR = "torchao_int8da_int4w_per_tensor"

     @classmethod
     def get_backend_name(cls) -> str:
diff --git a/backends/xnnpack/test/recipes/test_xnnpack_recipes.py b/backends/xnnpack/test/recipes/test_xnnpack_recipes.py
index 679743e42d3..565b71eab71 100644
--- a/backends/xnnpack/test/recipes/test_xnnpack_recipes.py
+++ b/backends/xnnpack/test/recipes/test_xnnpack_recipes.py
@@ -18,9 +18,10 @@
 from executorch.examples.models.model_factory import EagerModelFactory
 from executorch.examples.xnnpack import MODEL_NAME_TO_OPTIONS, QuantType
 from executorch.exir.schema import DelegateCall, Program
-from executorch.export import export, ExportRecipe, recipe_registry
+from executorch.export import export, ExportRecipe, recipe_registry, StageType
 from torch import nn
 from torch.testing._internal.common_quantization import TestHelperModules
+from torchao.quantization.utils import compute_error


 class TestXnnpackRecipes(unittest.TestCase):
@@ -38,6 +39,29 @@ def check_fully_delegated(self, program: Program) -> None:
         self.assertEqual(len(instructions), 1)
         self.assertIsInstance(instructions[0].instr_args, DelegateCall)

+    # pyre-ignore
+    def _compare_eager_quantized_model_outputs(
+        self, session, example_inputs, atol: float
+    ) -> None:
+        """Utility to compare eager quantized model output with session output after xnnpack lowering"""
+        torch_export_stage_output = session.get_stage_artifacts()[
+            StageType.TORCH_EXPORT
+        ]
+        eager_quantized_model = torch_export_stage_output.data["forward"].module()
+        output = session.run_method("forward", example_inputs[0])[0]
+        expected = eager_quantized_model(*example_inputs[0])
+        Tester._assert_outputs_equal(output, expected, atol=atol)
+
+    def _compare_eager_unquantized_model_outputs(
+        self, session, eager_unquantized_model, example_inputs, sqnr_threshold=20
+    ):
+        """Utility to compare eager unquantized model output with session output using SQNR"""
+        quantized_output = session.run_method("forward", example_inputs[0])[0]
+        original_output = eager_unquantized_model(*example_inputs[0])
+        error = compute_error(original_output, quantized_output)
print(f"{self._testMethodName} - SQNR: {error} dB") + self.assertTrue(error > sqnr_threshold) + def test_basic_recipe(self) -> None: m_eager = TestHelperModules.TwoLinearModule().eval() example_inputs = [(torch.randn(9, 8),)] @@ -46,18 +70,13 @@ def test_basic_recipe(self) -> None: example_inputs=example_inputs, export_recipe=ExportRecipe.get_recipe(XNNPackRecipeType.FP32), ) - self.assertTrue( - torch.allclose( - session.run_method("forward", example_inputs[0])[0], - m_eager(*example_inputs[0]), - atol=1e-3, - ) - ) + self._compare_eager_quantized_model_outputs(session, example_inputs, 1e-3) self.check_fully_delegated(session.get_executorch_program()) + self._compare_eager_unquantized_model_outputs(session, m_eager, example_inputs) def test_int8_dynamic_quant_recipe(self) -> None: test_cases = [ - ExportRecipe.get_recipe(XNNPackRecipeType.INT8_DYNAMIC_PER_CHANNEL), + ExportRecipe.get_recipe(XNNPackRecipeType.PT2E_INT8_DYNAMIC_PER_CHANNEL), ] for export_recipe in test_cases: @@ -70,19 +89,18 @@ def test_int8_dynamic_quant_recipe(self) -> None: example_inputs=example_inputs, export_recipe=export_recipe, ) - self.assertTrue( - torch.allclose( - session.run_method("forward", example_inputs[0])[0], - m_eager(*example_inputs[0]), - atol=1e-1, - ) + self._compare_eager_quantized_model_outputs( + session, example_inputs, 1e-1 ) self.check_fully_delegated(session.get_executorch_program()) + self._compare_eager_unquantized_model_outputs( + session, m_eager, example_inputs + ) def test_int8_static_quant_recipe(self) -> None: test_cases = [ - ExportRecipe.get_recipe(XNNPackRecipeType.INT8_STATIC_PER_CHANNEL), - ExportRecipe.get_recipe(XNNPackRecipeType.INT8_STATIC_PER_TENSOR), + ExportRecipe.get_recipe(XNNPackRecipeType.PT2E_INT8_STATIC_PER_CHANNEL), + ExportRecipe.get_recipe(XNNPackRecipeType.PT2E_INT8_STATIC_PER_TENSOR), ] for export_recipe in test_cases: @@ -95,14 +113,13 @@ def test_int8_static_quant_recipe(self) -> None: example_inputs=example_inputs, export_recipe=export_recipe, ) - self.assertTrue( - torch.allclose( - session.run_method("forward", example_inputs[0])[0], - m_eager(*example_inputs[0]), - atol=1e-1, - ) + self._compare_eager_quantized_model_outputs( + session, example_inputs, 1e-2 ) self.check_fully_delegated(session.get_executorch_program()) + self._compare_eager_unquantized_model_outputs( + session, m_eager, example_inputs + ) def test_8a4w_recipe(self) -> None: class SimpleLinearModel(nn.Module): @@ -116,40 +133,36 @@ def forward(self, x) -> torch.Tensor: test_cases = [ ExportRecipe.get_recipe( - XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_CHANNEL, + XNNPackRecipeType.TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_CHANNEL, ), ExportRecipe.get_recipe( - XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR, - group_size=32, + XNNPackRecipeType.TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR, + group_size=8, ), ] for export_recipe in test_cases: with self.subTest(export_recipe=export_recipe): - model = SimpleLinearModel() + model = SimpleLinearModel().eval() example_inputs = [(torch.randn(1, 32),)] session = export( model=model, example_inputs=example_inputs, export_recipe=export_recipe, ) - self.assertTrue( - torch.allclose( - session.run_method("forward", example_inputs[0])[0], - model(*example_inputs[0]), - atol=1e-2, - ) - ) self.check_fully_delegated(session.get_executorch_program()) + self._compare_eager_quantized_model_outputs( + session, example_inputs, 1e-3 + ) def _get_recipe_for_quant_type(self, quant_type: QuantType) -> XNNPackRecipeType: # Map QuantType to 
         # Map QuantType to corresponding recipe name.
         if quant_type == QuantType.STATIC_PER_CHANNEL:
-            return XNNPackRecipeType.INT8_STATIC_PER_CHANNEL
+            return XNNPackRecipeType.PT2E_INT8_STATIC_PER_CHANNEL
         elif quant_type == QuantType.DYNAMIC_PER_CHANNEL:
-            return XNNPackRecipeType.INT8_DYNAMIC_PER_CHANNEL
+            return XNNPackRecipeType.PT2E_INT8_DYNAMIC_PER_CHANNEL
         elif quant_type == QuantType.STATIC_PER_TENSOR:
-            return XNNPackRecipeType.INT8_STATIC_PER_TENSOR
+            return XNNPackRecipeType.PT2E_INT8_STATIC_PER_TENSOR
         elif quant_type == QuantType.NONE:
             return XNNPackRecipeType.FP32
         else:
@@ -224,12 +237,13 @@ def test_validate_recipe_kwargs_int4_tensor_with_valid_group_size(

         # Should not raise any exception
         recipe_w_default_group = provider.create_recipe(
-            XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR
+            XNNPackRecipeType.TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR
         )
         self.assertIsNotNone(recipe_w_default_group)

         recipe = provider.create_recipe(
-            XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR, group_size=64
+            XNNPackRecipeType.TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR,
+            group_size=64,
         )
         self.assertIsNotNone(recipe)

@@ -240,7 +254,7 @@ def test_validate_recipe_kwargs_int4_tensor_with_invalid_group_size(

         with self.assertRaises(ValueError) as cm:
             provider.create_recipe(
-                XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR,
+                XNNPackRecipeType.TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR,
                 group_size="32",  # String instead of int
             )

diff --git a/export/__init__.py b/export/__init__.py
index d5f3826ab90..a7b165185de 100644
--- a/export/__init__.py
+++ b/export/__init__.py
@@ -15,12 +15,19 @@
 """

 from .export import export, ExportSession
-from .recipe import ExportRecipe, LoweringRecipe, QuantizationRecipe, RecipeType
+from .recipe import (
+    AOQuantizationConfig,
+    ExportRecipe,
+    LoweringRecipe,
+    QuantizationRecipe,
+    RecipeType,
+)
 from .recipe_provider import BackendRecipeProvider
 from .recipe_registry import recipe_registry
 from .types import StageType

 __all__ = [
+    "AOQuantizationConfig",
     "StageType",
     "ExportRecipe",
     "LoweringRecipe",
diff --git a/export/recipe.py b/export/recipe.py
index 8f7251cd419..086d57f3e38 100644
--- a/export/recipe.py
+++ b/export/recipe.py
@@ -6,7 +6,9 @@
 from abc import ABCMeta, abstractmethod
 from dataclasses import dataclass
 from enum import Enum, EnumMeta
-from typing import List, Optional, Sequence
+from typing import Callable, List, Optional, Sequence
+
+import torch

 from executorch.exir._warnings import experimental

@@ -64,6 +66,20 @@ class Mode(str, Enum):
     RELEASE = "release"


+@dataclass
+class AOQuantizationConfig:
+    """
+    Configuration for torchao quantization with optional filter function.
+
+    Attributes:
+        ao_base_config: The AOBaseConfig for quantization
+        filter_fn: Optional filter function to selectively apply quantization
+    """
+
+    ao_base_config: AOBaseConfig
+    filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None
+
+
 @dataclass
 class QuantizationRecipe:
     """
@@ -73,11 +89,12 @@ class QuantizationRecipe:

     Attributes:
         quantizers: Optional list of quantizers for model quantization
-        ao_base_config: Optional list of AO base configurations
+        ao_quantization_configs: Optional list of AOQuantizationConfig objects that pair
+            AOBaseConfig with optional filter functions
     """

     quantizers: Optional[List[Quantizer]] = None
-    ao_base_config: Optional[List[AOBaseConfig]] = None
+    ao_quantization_configs: Optional[List[AOQuantizationConfig]] = None

     def get_quantizers(self) -> Optional[List[Quantizer]]:
         """
diff --git a/export/stages.py b/export/stages.py
index f4de59a9b7a..2b3f8a42440 100644
--- a/export/stages.py
+++ b/export/stages.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+import copy
 import logging
 from abc import ABC, abstractmethod
 from typing import Any, Callable, Dict, List, Optional, Sequence
@@ -20,7 +21,10 @@
 from torch._export.pass_base import PassType
 from torchao.quantization import quantize_
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
-from torchao.quantization.pt2e.quantizer import ComposableQuantizer
+from torchao.quantization.pt2e.quantizer import (
+    ComposableQuantizer,
+    Quantizer as TorchAOPT2EQuantizer,
+)
 from torchao.utils import unwrap_tensor_subclass


@@ -289,7 +293,7 @@ def run(self, artifact: PipelineArtifact) -> None:
         """
         if (
             not self._quantization_recipe
-            or not self._quantization_recipe.ao_base_config
+            or not self._quantization_recipe.ao_quantization_configs
         ):
             logging.info(
                 "Quantization recipe is invalid to run SourceTransform, returning original artifact"
@@ -300,15 +304,14 @@ def run(self, artifact: PipelineArtifact) -> None:
         assert isinstance(artifact.data, dict)

         # Store the original models
-        self._transformed_models = artifact.data
+        self._transformed_models = copy.deepcopy(artifact.data)

         # Apply torchao quantize_ to each model
-        for method_name, model in artifact.data.items():
+        for _, model in artifact.data.items():
             # pyre-ignore
-            for config in self._quantization_recipe.ao_base_config:
-                quantize_(model, config)
+            for ao_config in self._quantization_recipe.ao_quantization_configs:
+                quantize_(model, ao_config.ao_base_config, ao_config.filter_fn)
                 unwrap_tensor_subclass(model)
-            self._transformed_models[method_name] = model

         self._artifact = artifact.copy_with_new_data(self._transformed_models)

@@ -333,6 +336,36 @@ def valid_predecessor_stages(self) -> List["StageType"]:
     def can_start_pipeline(self) -> bool:
         return True

+    def _get_quantizer_for_prepare_pt2e(self, quantizers: List[Any]):
+        torch_ao_quantizers = []
+        torchao_pt2e_quantizers = []
+
+        for quantizer in quantizers:
+            if isinstance(quantizer, TorchAOPT2EQuantizer):
+                torchao_pt2e_quantizers.append(quantizer)
+            else:
+                # torch.ao quantizer support will soon be deprecated, remove this once CoreML moves to torchao quantizer
+                logging.warning(
+                    f"torch.ao quantizer {quantizer} is deprecated, consider moving to torchao quantizer"
+                )
+                torch_ao_quantizers.append(quantizer)
+
+        if torch_ao_quantizers and torchao_pt2e_quantizers:
+            raise ValueError("Mixed quantizer types are not supported")
+        if len(torch_ao_quantizers) > 1:
+            raise ValueError(
+                "Multiple quantizers of torch.ao.quantization.quantizer not supported"
+            )
+
+        if torch_ao_quantizers:
+            # prepare_pt2e has backward compat with torch.ao quantizer
+            return torch_ao_quantizers[0]
+        elif torchao_pt2e_quantizers:
+            # Multiple torchao quantizers - use ComposableQuantizer
+            return ComposableQuantizer(torchao_pt2e_quantizers)
+        else:
+            raise ValueError("No quantizers detected")
+
     def run(self, artifact: PipelineArtifact) -> None:
         if not self._quantization_recipe or not self._quantization_recipe.quantizers:
             logging.info(
@@ -357,11 +390,10 @@ def run(self, artifact: PipelineArtifact) -> None:
             inputs = example_inputs[method_name][0]
             captured_graph = torch.export.export(model, inputs, strict=True).module()

-            composed_quantizer = ComposableQuantizer(
-                # pyre-ignore
+            quantizer = self._get_quantizer_for_prepare_pt2e(
                 self._quantization_recipe.quantizers
             )
-            prepared_model = prepare_pt2e(captured_graph, composed_quantizer)
+            prepared_model = prepare_pt2e(captured_graph, quantizer)

             for calibration_input in example_inputs[method_name]:
                 prepared_model(*calibration_input)
diff --git a/export/tests/test_export_session.py b/export/tests/test_export_session.py
index 30288941d22..fcec1b7a59a 100644
--- a/export/tests/test_export_session.py
+++ b/export/tests/test_export_session.py
@@ -12,7 +12,11 @@
 import torch

 from executorch.export import ExportRecipe, ExportSession
-from executorch.export.recipe import LoweringRecipe, QuantizationRecipe
+from executorch.export.recipe import (
+    AOQuantizationConfig,
+    LoweringRecipe,
+    QuantizationRecipe,
+)
 from executorch.export.stages import PipelineArtifact
 from executorch.export.types import StageType

@@ -20,7 +24,7 @@
 class SimpleTestModel(torch.nn.Module):
     def __init__(self) -> None:
         super().__init__()
-        self.linear = torch.nn.Linear(10, 5)
+        self.linear: torch.nn.Module = torch.nn.Linear(10, 5)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.linear(x)
@@ -449,7 +453,7 @@ def test_pipeline_building_with_all_recipes(self) -> None:
         """Test pipeline building with quantization and lowering recipes."""
         # Create comprehensive recipes
         quant_recipe = QuantizationRecipe(
-            ao_base_config=[Mock()],
+            ao_quantization_configs=[AOQuantizationConfig(Mock())],
             quantizers=[Mock()],
         )
         lowering_recipe = LoweringRecipe(
diff --git a/export/tests/test_export_stages.py b/export/tests/test_export_stages.py
index 4820e508e18..d4629a1aea7 100644
--- a/export/tests/test_export_stages.py
+++ b/export/tests/test_export_stages.py
@@ -11,25 +11,25 @@
 import torch
 from executorch.exir.program import EdgeProgramManager, ExecutorchProgramManager
-from executorch.export import QuantizationRecipe
+from executorch.export import AOQuantizationConfig, QuantizationRecipe, StageType
 from executorch.export.stages import (
     EdgeTransformAndLowerStage,
     ExecutorchStage,
     PipelineArtifact,
     QuantizeStage,
     SourceTransformStage,
-    StageType,
     ToBackendStage,
     ToEdgeStage,
     TorchExportStage,
 )
 from torch.export import ExportedProgram
+from torchao.quantization.pt2e.quantizer import Quantizer as TorchAOPT2EQuantizer


 class SimpleTestModel(torch.nn.Module):
     def __init__(self) -> None:
         super().__init__()
-        self.linear = torch.nn.Linear(10, 5)
+        self.linear: torch.nn.Module = torch.nn.Linear(10, 5)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.linear(x)
@@ -163,7 +163,7 @@ def setUp(self) -> None:

     def test_source_transform_stage_no_quantization(self) -> None:
         mock_recipe = Mock(spec=QuantizationRecipe)
-        mock_recipe.ao_base_config = None
+        mock_recipe.ao_quantization_configs = None
         stage = SourceTransformStage(mock_recipe)

         artifact = PipelineArtifact(data=self.models_dict, context={})
@@ -174,12 +174,18 @@ def test_source_transform_stage_no_quantization(self) -> None:

     @patch("executorch.export.stages.quantize_")
     @patch("executorch.export.stages.unwrap_tensor_subclass")
-    def test_run_with_ao_base_config(
+    def test_run_with_ao_quantization_configs(
         self, mock_unwrap: Mock, mock_quantize: Mock
     ) -> None:
-        mock_config = Mock()
+        from torchao.core.config import AOBaseConfig
+
+        mock_config = Mock(spec=AOBaseConfig)
+        mock_filter_fn = Mock()
+        mock_ao_config: AOQuantizationConfig = AOQuantizationConfig(
+            ao_base_config=mock_config, filter_fn=mock_filter_fn
+        )
         mock_recipe = Mock(spec=QuantizationRecipe)
-        mock_recipe.ao_base_config = [mock_config]
+        mock_recipe.ao_quantization_configs = [mock_ao_config]

         stage = SourceTransformStage(mock_recipe)

@@ -188,7 +194,7 @@ def test_run_with_ao_base_config(
         stage.run(artifact)

         # Verify quantize_ was called with the model and config
-        mock_quantize.assert_called_once_with(self.model, mock_config)
+        mock_quantize.assert_called_once_with(self.model, mock_config, mock_filter_fn)

         # Verify unwrap_tensor_subclass was called with the model
         mock_unwrap.assert_called_once_with(self.model)
@@ -201,6 +207,21 @@ def setUp(self) -> None:
         self.example_inputs = [(torch.randn(2, 10),)]
         self.context = {"example_inputs": {"forward": self.example_inputs}}

+    @staticmethod
+    def create_dummy_quantizer() -> TorchAOPT2EQuantizer:
+
+        class DummyQuantizer(TorchAOPT2EQuantizer):
+            def __init__(self):
+                pass
+
+            def annotate(self, model):
+                return model
+
+            def validate(self, model):
+                pass
+
+        return DummyQuantizer()
+
     def test_run_no_quantizers(self) -> None:
         """Test execution with no quantizers."""
         mock_recipe = Mock(spec=QuantizationRecipe)
@@ -224,7 +245,7 @@ def test_run_with_quantizers(
         mock_convert_pt2e: Mock,
     ) -> None:
         """Test execution with quantizers"""
-        mock_quantizer = Mock()
+        mock_quantizer = self.create_dummy_quantizer()
         mock_recipe = Mock(spec=QuantizationRecipe)
         mock_recipe.quantizers = [mock_quantizer]
         stage = QuantizeStage(mock_recipe)
@@ -285,6 +306,35 @@ def test_run_empty_example_inputs(self) -> None:
             "Example inputs for method forward not found or empty", str(cm.exception)
         )

+    @patch("executorch.export.stages.ComposableQuantizer")
+    def test_get_quantizer_for_prepare_pt2e(
+        self, mock_composable_quantizer: Mock
+    ) -> None:
+        """Test _get_quantizer_for_prepare_pt2e method with different quantizer scenarios."""
+        mock_recipe = Mock(spec=QuantizationRecipe)
+        stage = QuantizeStage(mock_recipe)
+
+        # Test empty quantizers list - should raise ValueError
+        with self.assertRaises(ValueError) as cm:
+            stage._get_quantizer_for_prepare_pt2e([])
+        self.assertIn("No quantizers detected", str(cm.exception))
+
+        # Test ComposableQuantizer path with multiple torchao quantizers
+        # Create instances of dummy quantizers using the reusable method
+        quantizer1 = self.create_dummy_quantizer()
+        quantizer2 = self.create_dummy_quantizer()
+
+        # Set up ComposableQuantizer mock
+        mock_composed_quantizer = Mock()
+        mock_composable_quantizer.return_value = mock_composed_quantizer
+
+        # Call the method with multiple torchao quantizers
+        result = stage._get_quantizer_for_prepare_pt2e([quantizer1, quantizer2])
+
+        # Verify ComposableQuantizer was called with the quantizers
+        mock_composable_quantizer.assert_called_once_with([quantizer1, quantizer2])
+        self.assertEqual(result, mock_composed_quantizer)
+

 class TestToEdgeStage(unittest.TestCase):
     def setUp(self) -> None:
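
Note (not part of the diff): a minimal usage sketch of the renamed XNNPACK recipe and the new AOQuantizationConfig wrapper introduced above. The toy model and the lambda filter function are hypothetical; the executorch.export and torchao names follow the files touched in this diff, and backend recipe registration is assumed to happen on import, as exercised by the tests above.

    import torch

    from executorch.backends.xnnpack.recipes.xnnpack_recipe_types import XNNPackRecipeType
    from executorch.export import (
        AOQuantizationConfig,
        ExportRecipe,
        QuantizationRecipe,
        export,
    )
    from torchao.quantization import Int8DynamicActivationIntxWeightConfig
    from torchao.quantization.granularity import PerAxis


    class TinyLinear(torch.nn.Module):
        """Hypothetical toy model, used only for this sketch."""

        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(32, 2)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.linear(x)


    model = TinyLinear().eval()
    example_inputs = [(torch.randn(1, 32),)]

    # Renamed recipe: was INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_CHANNEL before this change.
    session = export(
        model=model,
        example_inputs=example_inputs,
        export_recipe=ExportRecipe.get_recipe(
            XNNPackRecipeType.TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_CHANNEL
        ),
    )
    print(session.run_method("forward", example_inputs[0])[0])

    # AOQuantizationConfig pairs an AOBaseConfig with an optional filter_fn, so
    # quantize_() is applied only to the modules the filter selects.
    ao_config = AOQuantizationConfig(
        ao_base_config=Int8DynamicActivationIntxWeightConfig(
            weight_dtype=torch.int4,
            weight_granularity=PerAxis(axis=0),
        ),
        filter_fn=lambda module, fqn: isinstance(module, torch.nn.Linear),
    )
    quant_recipe = QuantizationRecipe(
        quantizers=None,
        ao_quantization_configs=[ao_config],
    )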