Add TorchAO wrapper config to allow filter_fn for quantize_ #13264

Status: Open. Wants to merge 4 commits into base: main.
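This change replaces `QuantizationRecipe.ao_base_config` with `ao_quantization_configs`, a list of `AOQuantizationConfig` wrappers that pair a torchao `AOBaseConfig` with an optional `filter_fn`; `SourceTransformStage` forwards both to `quantize_`. A minimal sketch of the intended usage, assuming typical torchao import paths; the filter function, weight settings, and module names are illustrative assumptions, not part of this PR:

```python
import torch
from torchao.quantization import Int8DynamicActivationIntxWeightConfig, quantize_
from torchao.quantization.granularity import PerGroup

from executorch.export import AOQuantizationConfig, QuantizationRecipe


# Hypothetical filter: quantize only the Linear layers of the feed-forward blocks.
def ffn_linears_only(module: torch.nn.Module, fqn: str) -> bool:
    return isinstance(module, torch.nn.Linear) and "ffn" in fqn


config = AOQuantizationConfig(
    ao_base_config=Int8DynamicActivationIntxWeightConfig(
        weight_dtype=torch.int4,
        weight_granularity=PerGroup(group_size=32),
    ),
    filter_fn=ffn_linears_only,
)

# The recipe now carries wrapper objects instead of bare AOBaseConfigs.
quant_recipe = QuantizationRecipe(quantizers=None, ao_quantization_configs=[config])

# SourceTransformStage applies each entry roughly as:
#   quantize_(model, config.ao_base_config, config.filter_fn)
```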
11 changes: 7 additions & 4 deletions backends/xnnpack/recipes/xnnpack_recipe_provider.py
@@ -25,6 +25,7 @@
get_xnnpack_executorch_backend_config,
)
from executorch.export import (
AOQuantizationConfig,
BackendRecipeProvider,
ExportRecipe,
LoweringRecipe,
@@ -144,14 +145,16 @@ def _build_int8da_intx_weight_recipe(
else:
weight_granularity = PerGroup(group_size=group_size)

config = Int8DynamicActivationIntxWeightConfig(
weight_dtype=weight_dtype,
weight_granularity=weight_granularity,
config = AOQuantizationConfig(
Int8DynamicActivationIntxWeightConfig(
weight_dtype=weight_dtype,
weight_granularity=weight_granularity,
)
)

quant_recipe = QuantizationRecipe(
quantizers=None,
ao_base_config=[config],
ao_quantization_configs=[config],
)

return ExportRecipe(
48 changes: 22 additions & 26 deletions backends/xnnpack/test/recipes/test_xnnpack_recipes.py
@@ -19,6 +19,7 @@
from executorch.examples.xnnpack import MODEL_NAME_TO_OPTIONS, QuantType
from executorch.exir.schema import DelegateCall, Program
from executorch.export import export, ExportRecipe, recipe_registry
from export.types import StageType
from torch import nn
from torch.testing._internal.common_quantization import TestHelperModules

@@ -38,6 +39,19 @@ def check_fully_delegated(self, program: Program) -> None:
self.assertEqual(len(instructions), 1)
self.assertIsInstance(instructions[0].instr_args, DelegateCall)

# pyre-ignore
def _compare_eager_quantized_model_outputs(
self, session, example_inputs, atol: float
) -> None:
"""Utility to compare eager quantized model output with session output after XNNPACK lowering"""
source_transform_output = session.get_stage_artifacts()[
StageType.SOURCE_TRANSFORM
]
eager_quantized_model = source_transform_output.data["forward"]
output = session.run_method("forward", example_inputs[0])[0]
expected = eager_quantized_model(*example_inputs[0])
self.assertTrue(torch.allclose(output, expected, atol=atol))

def test_basic_recipe(self) -> None:
m_eager = TestHelperModules.TwoLinearModule().eval()
example_inputs = [(torch.randn(9, 8),)]
@@ -46,13 +60,7 @@ def test_basic_recipe(self) -> None:
example_inputs=example_inputs,
export_recipe=ExportRecipe.get_recipe(XNNPackRecipeType.FP32),
)
self.assertTrue(
torch.allclose(
session.run_method("forward", example_inputs[0])[0],
m_eager(*example_inputs[0]),
atol=1e-3,
)
)
self._compare_eager_quantized_model_outputs(session, example_inputs, 1e-3)
self.check_fully_delegated(session.get_executorch_program())

def test_int8_dynamic_quant_recipe(self) -> None:
@@ -70,12 +78,8 @@ def test_int8_dynamic_quant_recipe(self) -> None:
example_inputs=example_inputs,
export_recipe=export_recipe,
)
self.assertTrue(
torch.allclose(
session.run_method("forward", example_inputs[0])[0],
m_eager(*example_inputs[0]),
atol=1e-1,
)
self._compare_eager_quantized_model_outputs(
session, example_inputs, 1e-1
)
self.check_fully_delegated(session.get_executorch_program())

@@ -95,12 +99,8 @@ def test_int8_static_quant_recipe(self) -> None:
example_inputs=example_inputs,
export_recipe=export_recipe,
)
self.assertTrue(
torch.allclose(
session.run_method("forward", example_inputs[0])[0],
m_eager(*example_inputs[0]),
atol=1e-1,
)
self._compare_eager_quantized_model_outputs(
session, example_inputs, 1e-1
)
self.check_fully_delegated(session.get_executorch_program())

@@ -133,14 +133,10 @@ def forward(self, x) -> torch.Tensor:
example_inputs=example_inputs,
export_recipe=export_recipe,
)
self.assertTrue(
torch.allclose(
session.run_method("forward", example_inputs[0])[0],
model(*example_inputs[0]),
atol=1e-2,
)
)
self.check_fully_delegated(session.get_executorch_program())
self._compare_eager_quantized_model_outputs(
session, example_inputs, 1e-2
)

def _get_recipe_for_quant_type(self, quant_type: QuantType) -> XNNPackRecipeType:
# Map QuantType to corresponding recipe name.
9 changes: 8 additions & 1 deletion export/__init__.py
@@ -15,12 +15,19 @@
"""

from .export import export, ExportSession
from .recipe import ExportRecipe, LoweringRecipe, QuantizationRecipe, RecipeType
from .recipe import (
AOQuantizationConfig,
ExportRecipe,
LoweringRecipe,
QuantizationRecipe,
RecipeType,
)
from .recipe_provider import BackendRecipeProvider
from .recipe_registry import recipe_registry
from .types import StageType

__all__ = [
"AOQuantizationConfig",
"StageType",
"ExportRecipe",
"LoweringRecipe",
23 changes: 20 additions & 3 deletions export/recipe.py
@@ -6,7 +6,9 @@
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from enum import Enum, EnumMeta
from typing import List, Optional, Sequence
from typing import Callable, List, Optional, Sequence

import torch

from executorch.exir._warnings import experimental

@@ -64,6 +66,20 @@ class Mode(str, Enum):
RELEASE = "release"


@dataclass
class AOQuantizationConfig:
"""
Configuration for torchao quantization with optional filter function.

Attributes:
ao_base_config: The AOBaseConfig for quantization
filter_fn: Optional filter function to selectively apply quantization
"""

ao_base_config: AOBaseConfig
filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None


@dataclass
class QuantizationRecipe:
"""
@@ -73,11 +89,12 @@ class QuantizationRecipe:

Attributes:
quantizers: Optional list of quantizers for model quantization
ao_base_config: Optional list of AO base configurations
ao_quantization_configs: Optional list of AOQuantizationConfig objects that pair
AOBaseConfig with optional filter functions
"""

quantizers: Optional[List[Quantizer]] = None
ao_base_config: Optional[List[AOBaseConfig]] = None
ao_quantization_configs: Optional[List[AOQuantizationConfig]] = None

def get_quantizers(self) -> Optional[List[Quantizer]]:
"""
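For reference on the `filter_fn` contract introduced above: `quantize_` calls it with each submodule and its fully qualified name and quantizes only the modules for which it returns `True`; when left as `None`, torchao falls back to its default filter, which typically targets `nn.Linear` modules. A hypothetical example that excludes an output head by name (the name is illustrative, not from this PR):

```python
import torch


def skip_output_head(module: torch.nn.Module, fqn: str) -> bool:
    # Quantize every Linear except a module whose qualified name ends in
    # "lm_head" (an illustrative, hypothetical name).
    return isinstance(module, torch.nn.Linear) and not fqn.endswith("lm_head")
```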
50 changes: 40 additions & 10 deletions export/stages.py
@@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
import logging
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Sequence
@@ -20,7 +21,6 @@
from torch._export.pass_base import PassType
from torchao.quantization import quantize_
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
from torchao.quantization.pt2e.quantizer import ComposableQuantizer
from torchao.utils import unwrap_tensor_subclass


@@ -287,7 +287,7 @@ def run(self, artifact: PipelineArtifact) -> None:
"""
if (
not self._quantization_recipe
or not self._quantization_recipe.ao_base_config
or not self._quantization_recipe.ao_quantization_configs
):
logging.info(
"Quantization recipe is invalid to run SourceTransform, returning original artifact"
Expand All @@ -298,15 +298,14 @@ def run(self, artifact: PipelineArtifact) -> None:
assert isinstance(artifact.data, dict)

# Store the original models
self._transformed_models = artifact.data
self._transformed_models = copy.deepcopy(artifact.data)

# Apply torchao quantize_ to each model
for method_name, model in artifact.data.items():
for _, model in artifact.data.items():
# pyre-ignore
for config in self._quantization_recipe.ao_base_config:
quantize_(model, config)
for ao_config in self._quantization_recipe.ao_quantization_configs:
quantize_(model, ao_config.ao_base_config, ao_config.filter_fn)
unwrap_tensor_subclass(model)
self._transformed_models[method_name] = model

self._artifact = artifact.copy_with_new_data(self._transformed_models)

@@ -331,6 +330,38 @@ def valid_predecessor_stages(self) -> List["StageType"]:
def can_start_pipeline(self) -> bool:
return True

def _get_quantizer_for_prepare_pt2e(self, quantizers: List[Any]):
torch_ao_quantizers = []
torchao_pt2e_quantizers = []

for quantizer in quantizers:
from torchao.quantization.pt2e.quantizer import (
Quantizer as TorchAOPT2EQuantizer,
)

if isinstance(quantizer, TorchAOPT2EQuantizer):
torchao_pt2e_quantizers.append(quantizer)
else:
torch_ao_quantizers.append(quantizer)

if torch_ao_quantizers and torchao_pt2e_quantizers:
raise ValueError("Mixed quantizer types are not supported")
if len(torch_ao_quantizers) > 1:
raise ValueError(
"Multiple quantizers of torch.ao.quantization.quantizer not supported"
)

if torch_ao_quantizers:
# prepare_pt2e has backward compat with torch.ao quantizer
return torch_ao_quantizers[0]
elif torchao_pt2e_quantizers:
# Multiple torchao quantizers - use ComposableQuantizer
from torchao.quantization.pt2e.quantizer import ComposableQuantizer

return ComposableQuantizer(torchao_pt2e_quantizers)
else:
raise ValueError("No quantizers detected")

def run(self, artifact: PipelineArtifact) -> None:
if not self._quantization_recipe or not self._quantization_recipe.quantizers:
logging.info(
@@ -355,11 +386,10 @@ def run(self, artifact: PipelineArtifact) -> None:
inputs = example_inputs[method_name][0]
captured_graph = torch.export.export(model, inputs, strict=True).module()

composed_quantizer = ComposableQuantizer(
# pyre-ignore
quantizer = self._get_quantizer_for_prepare_pt2e(
self._quantization_recipe.quantizers
)
prepared_model = prepare_pt2e(captured_graph, composed_quantizer)
prepared_model = prepare_pt2e(captured_graph, quantizer)

for calibration_input in example_inputs[method_name]:
prepared_model(*calibration_input)
10 changes: 7 additions & 3 deletions export/tests/test_export_session.py
@@ -12,15 +12,19 @@

import torch
from executorch.export import ExportRecipe, ExportSession
from executorch.export.recipe import LoweringRecipe, QuantizationRecipe
from executorch.export.recipe import (
AOQuantizationConfig,
LoweringRecipe,
QuantizationRecipe,
)
from executorch.export.stages import PipelineArtifact
from executorch.export.types import StageType


class SimpleTestModel(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(10, 5)
self.linear: torch.nn.Module = torch.nn.Linear(10, 5)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.linear(x)
@@ -448,7 +452,7 @@ def test_pipeline_building_with_all_recipes(self) -> None:
"""Test pipeline building with quantization and lowering recipes."""
# Create comprehensive recipes
quant_recipe = QuantizationRecipe(
ao_base_config=[Mock()],
ao_quantization_configs=[AOQuantizationConfig(Mock())],
quantizers=[Mock()],
)
lowering_recipe = LoweringRecipe(
21 changes: 14 additions & 7 deletions export/tests/test_export_stages.py
@@ -11,7 +11,7 @@

import torch
from executorch.exir.program import EdgeProgramManager, ExecutorchProgramManager
from executorch.export import QuantizationRecipe
from executorch.export import AOQuantizationConfig, QuantizationRecipe
from executorch.export.stages import (
EdgeTransformAndLowerStage,
ExecutorchStage,
@@ -29,7 +29,7 @@
class SimpleTestModel(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(10, 5)
self.linear: torch.nn.Module = torch.nn.Linear(10, 5)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.linear(x)
@@ -163,7 +163,7 @@ def setUp(self) -> None:

def test_source_transform_stage_no_quantization(self) -> None:
mock_recipe = Mock(spec=QuantizationRecipe)
mock_recipe.ao_base_config = None
mock_recipe.ao_quantization_configs = None
stage = SourceTransformStage(mock_recipe)
artifact = PipelineArtifact(data=self.models_dict, context={})

@@ -174,12 +174,19 @@ def test_source_transform_stage_no_quantization(self) -> None:

@patch("executorch.export.stages.quantize_")
@patch("executorch.export.stages.unwrap_tensor_subclass")
def test_run_with_ao_base_config(
def test_run_with_ao_quantization_configs(
self, mock_unwrap: Mock, mock_quantize: Mock
) -> None:
mock_config = Mock()
from torchao.core.config import AOBaseConfig

mock_config = Mock(spec=AOBaseConfig)
mock_filter_fn = Mock()
# pyre-ignore[28]: Unexpected keyword argument error is a false positive for dataclass
mock_ao_config: AOQuantizationConfig = AOQuantizationConfig(
ao_base_config=mock_config, filter_fn=mock_filter_fn
)
mock_recipe = Mock(spec=QuantizationRecipe)
mock_recipe.ao_base_config = [mock_config]
mock_recipe.ao_quantization_configs = [mock_ao_config]

stage = SourceTransformStage(mock_recipe)

@@ -188,7 +195,7 @@ def test_run_with_ao_base_config(
stage.run(artifact)

# Verify quantize_ was called with the model and config
mock_quantize.assert_called_once_with(self.model, mock_config)
mock_quantize.assert_called_once_with(self.model, mock_config, mock_filter_fn)

# Verify unwrap_tensor_subclass was called with the model
mock_unwrap.assert_called_once_with(self.model)