
Commit 084517e

Add TorchAO wrapper config to allow filter_fn for quantize_

ghstack-source-id: 8cd3819
ghstack-comment-id: 3172341537
Pull-Request: #13264

Parent: a84b3c9

7 files changed, 118 insertions(+), 54 deletions(-)

backends/xnnpack/recipes/xnnpack_recipe_provider.py

Lines changed: 7 additions & 4 deletions

@@ -25,6 +25,7 @@
     get_xnnpack_executorch_backend_config,
 )
 from executorch.export import (
+    AOQuantizationConfig,
     BackendRecipeProvider,
     ExportRecipe,
     LoweringRecipe,
@@ -144,14 +145,16 @@ def _build_int8da_intx_weight_recipe(
         else:
             weight_granularity = PerGroup(group_size=group_size)

-        config = Int8DynamicActivationIntxWeightConfig(
-            weight_dtype=weight_dtype,
-            weight_granularity=weight_granularity,
+        config = AOQuantizationConfig(
+            Int8DynamicActivationIntxWeightConfig(
+                weight_dtype=weight_dtype,
+                weight_granularity=weight_granularity,
+            )
         )

         quant_recipe = QuantizationRecipe(
             quantizers=None,
-            ao_base_config=[config],
+            ao_quantization_configs=[config],
         )

         return ExportRecipe(
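For orientation, a minimal sketch of how a recipe built by this provider is consumed. The export entry point, ExportRecipe.get_recipe, and XNNPackRecipeType.FP32 all appear in the tests below; the model, input shapes, and the XNNPackRecipeType import path are assumptions for illustration:

import torch
from executorch.backends.xnnpack.recipes.xnnpack_recipe_types import (  # assumed path
    XNNPackRecipeType,
)
from executorch.export import export, ExportRecipe

model = torch.nn.Linear(8, 4).eval()  # illustrative model
example_inputs = [(torch.randn(2, 8),)]

# get_recipe dispatches to the registered backend recipe provider, which
# builds the ExportRecipe shown in the diff above.
session = export(
    model=model,
    example_inputs=example_inputs,
    export_recipe=ExportRecipe.get_recipe(XNNPackRecipeType.FP32),
)
output = session.run_method("forward", example_inputs[0])[0]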

backends/xnnpack/test/recipes/test_xnnpack_recipes.py

Lines changed: 22 additions & 26 deletions

@@ -19,6 +19,7 @@
 from executorch.examples.xnnpack import MODEL_NAME_TO_OPTIONS, QuantType
 from executorch.exir.schema import DelegateCall, Program
 from executorch.export import export, ExportRecipe, recipe_registry
+from export.types import StageType
 from torch import nn
 from torch.testing._internal.common_quantization import TestHelperModules

@@ -38,6 +39,19 @@ def check_fully_delegated(self, program: Program) -> None:
         self.assertEqual(len(instructions), 1)
         self.assertIsInstance(instructions[0].instr_args, DelegateCall)

+    # pyre-ignore
+    def _compare_eager_quantized_model_outputs(
+        self, session, example_inputs, atol: float
+    ) -> None:
+        """Utility to compare eager quantized model output with session output after xnnpack lowering"""
+        source_transform_output = session.get_stage_artifacts()[
+            StageType.SOURCE_TRANSFORM
+        ]
+        eager_quantized_model = source_transform_output.data["forward"]
+        output = session.run_method("forward", example_inputs[0])[0]
+        expected = eager_quantized_model(*example_inputs[0])
+        self.assertTrue(torch.allclose(output, expected, atol=atol))
+
     def test_basic_recipe(self) -> None:
         m_eager = TestHelperModules.TwoLinearModule().eval()
         example_inputs = [(torch.randn(9, 8),)]
@@ -46,13 +60,7 @@ def test_basic_recipe(self) -> None:
             example_inputs=example_inputs,
             export_recipe=ExportRecipe.get_recipe(XNNPackRecipeType.FP32),
         )
-        self.assertTrue(
-            torch.allclose(
-                session.run_method("forward", example_inputs[0])[0],
-                m_eager(*example_inputs[0]),
-                atol=1e-3,
-            )
-        )
+        self._compare_eager_quantized_model_outputs(session, example_inputs, 1e-3)
         self.check_fully_delegated(session.get_executorch_program())

     def test_int8_dynamic_quant_recipe(self) -> None:
@@ -70,12 +78,8 @@ def test_int8_dynamic_quant_recipe(self) -> None:
             example_inputs=example_inputs,
             export_recipe=export_recipe,
         )
-        self.assertTrue(
-            torch.allclose(
-                session.run_method("forward", example_inputs[0])[0],
-                m_eager(*example_inputs[0]),
-                atol=1e-1,
-            )
+        self._compare_eager_quantized_model_outputs(
+            session, example_inputs, 1e-2
         )
         self.check_fully_delegated(session.get_executorch_program())

@@ -95,12 +99,8 @@ def test_int8_static_quant_recipe(self) -> None:
             example_inputs=example_inputs,
             export_recipe=export_recipe,
         )
-        self.assertTrue(
-            torch.allclose(
-                session.run_method("forward", example_inputs[0])[0],
-                m_eager(*example_inputs[0]),
-                atol=1e-1,
-            )
+        self._compare_eager_quantized_model_outputs(
+            session, example_inputs, 1e-1
         )
         self.check_fully_delegated(session.get_executorch_program())

@@ -133,14 +133,10 @@ def forward(self, x) -> torch.Tensor:
             example_inputs=example_inputs,
             export_recipe=export_recipe,
         )
-        self.assertTrue(
-            torch.allclose(
-                session.run_method("forward", example_inputs[0])[0],
-                model(*example_inputs[0]),
-                atol=1e-2,
-            )
-        )
         self.check_fully_delegated(session.get_executorch_program())
+        self._compare_eager_quantized_model_outputs(
+            session, example_inputs, 1e-2
+        )

     def _get_recipe_for_quant_type(self, quant_type: QuantType) -> XNNPackRecipeType:
         # Map QuantType to corresponding recipe name.

export/__init__.py

Lines changed: 8 additions & 1 deletion

@@ -15,12 +15,19 @@
 """

 from .export import export, ExportSession
-from .recipe import ExportRecipe, LoweringRecipe, QuantizationRecipe, RecipeType
+from .recipe import (
+    AOQuantizationConfig,
+    ExportRecipe,
+    LoweringRecipe,
+    QuantizationRecipe,
+    RecipeType,
+)
 from .recipe_provider import BackendRecipeProvider
 from .recipe_registry import recipe_registry
 from .types import StageType

 __all__ = [
+    "AOQuantizationConfig",
     "StageType",
     "ExportRecipe",
     "LoweringRecipe",

export/recipe.py

Lines changed: 20 additions & 3 deletions

@@ -6,7 +6,9 @@
 from abc import ABCMeta, abstractmethod
 from dataclasses import dataclass
 from enum import Enum, EnumMeta
-from typing import List, Optional, Sequence
+from typing import Callable, List, Optional, Sequence
+
+import torch

 from executorch.exir._warnings import experimental

@@ -64,6 +66,20 @@ class Mode(str, Enum):
     RELEASE = "release"


+@dataclass
+class AOQuantizationConfig:
+    """
+    Configuration for torchao quantization with optional filter function.
+
+    Attributes:
+        ao_base_config: The AOBaseConfig for quantization
+        filter_fn: Optional filter function to selectively apply quantization
+    """
+
+    ao_base_config: AOBaseConfig
+    filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None
+
+
 @dataclass
 class QuantizationRecipe:
     """
@@ -73,11 +89,12 @@ class QuantizationRecipe:

     Attributes:
         quantizers: Optional list of quantizers for model quantization
-        ao_base_config: Optional list of AO base configurations
+        ao_quantization_configs: Optional list of AOQuantizationConfig objects that pair
+            AOBaseConfig with optional filter functions
     """

     quantizers: Optional[List[Quantizer]] = None
-    ao_base_config: Optional[List[AOBaseConfig]] = None
+    ao_quantization_configs: Optional[List[AOQuantizationConfig]] = None

     def get_quantizers(self) -> Optional[List[Quantizer]]:
         """

export/stages.py

Lines changed: 40 additions & 10 deletions

@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+import copy
 import logging
 from abc import ABC, abstractmethod
 from typing import Any, Callable, Dict, List, Optional, Sequence
@@ -20,7 +21,6 @@
 from torch._export.pass_base import PassType
 from torchao.quantization import quantize_
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
-from torchao.quantization.pt2e.quantizer import ComposableQuantizer
 from torchao.utils import unwrap_tensor_subclass

@@ -287,7 +287,7 @@ def run(self, artifact: PipelineArtifact) -> None:
         """
         if (
             not self._quantization_recipe
-            or not self._quantization_recipe.ao_base_config
+            or not self._quantization_recipe.ao_quantization_configs
         ):
             logging.info(
                 "Quantization recipe is invalid to run SourceTransform, returning original artifact"
@@ -298,15 +298,14 @@ def run(self, artifact: PipelineArtifact) -> None:
         assert isinstance(artifact.data, dict)

         # Store the original models
-        self._transformed_models = artifact.data
+        self._transformed_models = copy.deepcopy(artifact.data)

         # Apply torchao quantize_ to each model
-        for method_name, model in artifact.data.items():
+        for _, model in artifact.data.items():
             # pyre-ignore
-            for config in self._quantization_recipe.ao_base_config:
-                quantize_(model, config)
+            for ao_config in self._quantization_recipe.ao_quantization_configs:
+                quantize_(model, ao_config.ao_base_config, ao_config.filter_fn)
             unwrap_tensor_subclass(model)
-            self._transformed_models[method_name] = model

         self._artifact = artifact.copy_with_new_data(self._transformed_models)

@@ -331,6 +330,38 @@ def valid_predecessor_stages(self) -> List["StageType"]:
     def can_start_pipeline(self) -> bool:
         return True

+    def _get_quantizer_for_prepare_pt2e(self, quantizers: List[Any]):
+        torch_ao_quantizers = []
+        torchao_pt2e_quantizers = []
+
+        for quantizer in quantizers:
+            from torchao.quantization.pt2e.quantizer import (
+                Quantizer as TorchAOPT2EQuantizer,
+            )
+
+            if isinstance(quantizer, TorchAOPT2EQuantizer):
+                torchao_pt2e_quantizers.append(quantizer)
+            else:
+                torch_ao_quantizers.append(quantizer)
+
+        if torch_ao_quantizers and torchao_pt2e_quantizers:
+            raise ValueError("Mixed quantizer types are not supported")
+        if len(torch_ao_quantizers) > 1:
+            raise ValueError(
+                "Multiple quantizers of torch.ao.quantization.quantizer not supported"
+            )
+
+        if torch_ao_quantizers:
+            # prepare_pt2e has backward compat with torch.ao quantizer
+            return torch_ao_quantizers[0]
+        elif torchao_pt2e_quantizers:
+            # One or more torchao PT2E quantizers - wrap in ComposableQuantizer
+            from torchao.quantization.pt2e.quantizer import ComposableQuantizer
+
+            return ComposableQuantizer(torchao_pt2e_quantizers)
+        else:
+            raise ValueError("No quantizers detected")
+
     def run(self, artifact: PipelineArtifact) -> None:
         if not self._quantization_recipe or not self._quantization_recipe.quantizers:
             logging.info(
@@ -355,11 +386,10 @@ def run(self, artifact: PipelineArtifact) -> None:
             inputs = example_inputs[method_name][0]
             captured_graph = torch.export.export(model, inputs, strict=True).module()

-            composed_quantizer = ComposableQuantizer(
-                # pyre-ignore
+            quantizer = self._get_quantizer_for_prepare_pt2e(
                 self._quantization_recipe.quantizers
             )
-            prepared_model = prepare_pt2e(captured_graph, composed_quantizer)
+            prepared_model = prepare_pt2e(captured_graph, quantizer)

             for calibration_input in example_inputs[method_name]:
                 prepared_model(*calibration_input)
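The new _get_quantizer_for_prepare_pt2e keeps prepare_pt2e working with both legacy torch.ao quantizers and torchao PT2E quantizers. A hedged sketch of the path taken for multiple torchao quantizers (the XNNPACKQuantizer import path is an assumption; any torchao PT2E quantizers would do):

from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (  # assumed path
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)
from torchao.quantization.pt2e.quantizer import ComposableQuantizer

static_q = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
dynamic_q = XNNPACKQuantizer().set_global(
    get_symmetric_quantization_config(is_dynamic=True)
)

# Both are torchao PT2E quantizers, so the helper wraps them in a
# ComposableQuantizer before calling prepare_pt2e. Mixing either of them
# with a legacy torch.ao quantizer raises ValueError, as does passing
# more than one legacy quantizer.
composed = ComposableQuantizer([static_q, dynamic_q])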

export/tests/test_export_session.py

Lines changed: 7 additions & 3 deletions

@@ -12,15 +12,19 @@

 import torch
 from executorch.export import ExportRecipe, ExportSession
-from executorch.export.recipe import LoweringRecipe, QuantizationRecipe
+from executorch.export.recipe import (
+    AOQuantizationConfig,
+    LoweringRecipe,
+    QuantizationRecipe,
+)
 from executorch.export.stages import PipelineArtifact
 from executorch.export.types import StageType


 class SimpleTestModel(torch.nn.Module):
     def __init__(self) -> None:
         super().__init__()
-        self.linear = torch.nn.Linear(10, 5)
+        self.linear: torch.nn.Module = torch.nn.Linear(10, 5)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.linear(x)
@@ -448,7 +452,7 @@ def test_pipeline_building_with_all_recipes(self) -> None:
         """Test pipeline building with quantization and lowering recipes."""
         # Create comprehensive recipes
         quant_recipe = QuantizationRecipe(
-            ao_base_config=[Mock()],
+            ao_quantization_configs=[AOQuantizationConfig(Mock())],
             quantizers=[Mock()],
         )
         lowering_recipe = LoweringRecipe(

export/tests/test_export_stages.py

Lines changed: 14 additions & 7 deletions

@@ -11,7 +11,7 @@

 import torch
 from executorch.exir.program import EdgeProgramManager, ExecutorchProgramManager
-from executorch.export import QuantizationRecipe
+from executorch.export import AOQuantizationConfig, QuantizationRecipe
 from executorch.export.stages import (
     EdgeTransformAndLowerStage,
     ExecutorchStage,
@@ -29,7 +29,7 @@
 class SimpleTestModel(torch.nn.Module):
     def __init__(self) -> None:
         super().__init__()
-        self.linear = torch.nn.Linear(10, 5)
+        self.linear: torch.nn.Module = torch.nn.Linear(10, 5)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.linear(x)
@@ -163,7 +163,7 @@ def setUp(self) -> None:

     def test_source_transform_stage_no_quantization(self) -> None:
         mock_recipe = Mock(spec=QuantizationRecipe)
-        mock_recipe.ao_base_config = None
+        mock_recipe.ao_quantization_configs = None
         stage = SourceTransformStage(mock_recipe)
         artifact = PipelineArtifact(data=self.models_dict, context={})

@@ -174,12 +174,19 @@ def test_source_transform_stage_no_quantization(self) -> None:

     @patch("executorch.export.stages.quantize_")
     @patch("executorch.export.stages.unwrap_tensor_subclass")
-    def test_run_with_ao_base_config(
+    def test_run_with_ao_quantization_configs(
         self, mock_unwrap: Mock, mock_quantize: Mock
     ) -> None:
-        mock_config = Mock()
+        from torchao.core.config import AOBaseConfig
+
+        mock_config = Mock(spec=AOBaseConfig)
+        mock_filter_fn = Mock()
+        # pyre-ignore[28]: Unexpected keyword argument error is a false positive for dataclass
+        mock_ao_config: AOQuantizationConfig = AOQuantizationConfig(
+            ao_base_config=mock_config, filter_fn=mock_filter_fn
+        )
         mock_recipe = Mock(spec=QuantizationRecipe)
-        mock_recipe.ao_base_config = [mock_config]
+        mock_recipe.ao_quantization_configs = [mock_ao_config]

         stage = SourceTransformStage(mock_recipe)

@@ -188,7 +195,7 @@ def test_run_with_ao_base_config(
         stage.run(artifact)

         # Verify quantize_ was called with the model and config
-        mock_quantize.assert_called_once_with(self.model, mock_config)
+        mock_quantize.assert_called_once_with(self.model, mock_config, mock_filter_fn)

         # Verify unwrap_tensor_subclass was called with the model
         mock_unwrap.assert_called_once_with(self.model)
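The assertion above pins down the new call shape: quantize_ now receives the filter function as a third positional argument. Outside the mocks, that call is roughly as follows (the model, config, and filter are illustrative; quantize_ and unwrap_tensor_subclass are the torchao functions imported in export/stages.py):

import torch
from torchao.quantization import Int8DynamicActivationIntxWeightConfig, quantize_
from torchao.quantization.granularity import PerGroup
from torchao.utils import unwrap_tensor_subclass

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.Linear(16, 4))

# quantize_ mutates the model in place; the filter decides per
# (module, fqn) pair whether that submodule is quantized. Here only the
# first Linear (fqn "0") is quantized (illustrative filter).
quantize_(
    model,
    Int8DynamicActivationIntxWeightConfig(
        weight_dtype=torch.int4,
        weight_granularity=PerGroup(group_size=16),
    ),
    lambda m, fqn: isinstance(m, torch.nn.Linear) and fqn == "0",
)
unwrap_tensor_subclass(model)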
