Commit cb80c69

Add TorchAO wrapper config to allow filter_fn for quantize_
ghstack-source-id: 0216ca1
ghstack-comment-id: 3172341537
Pull-Request: #13264
1 parent 0c1acb3 commit cb80c69

5 files changed: +88 -22 lines changed

export/__init__.py

Lines changed: 8 additions & 1 deletion
@@ -15,12 +15,19 @@
 """

 from .export import export, ExportSession
-from .recipe import ExportRecipe, LoweringRecipe, QuantizationRecipe, RecipeType
+from .recipe import (
+    AOQuantizationConfig,
+    ExportRecipe,
+    LoweringRecipe,
+    QuantizationRecipe,
+    RecipeType,
+)
 from .recipe_provider import BackendRecipeProvider
 from .recipe_registry import recipe_registry
 from .types import StageType

 __all__ = [
+    "AOQuantizationConfig",
     "StageType",
     "ExportRecipe",
     "LoweringRecipe",

export/recipe.py

Lines changed: 20 additions & 3 deletions
@@ -6,7 +6,9 @@
 from abc import ABCMeta, abstractmethod
 from dataclasses import dataclass
 from enum import Enum, EnumMeta
-from typing import List, Optional, Sequence
+from typing import Callable, List, Optional, Sequence
+
+import torch

 from executorch.exir._warnings import experimental

@@ -64,6 +66,20 @@ class Mode(str, Enum):
     RELEASE = "release"


+@dataclass
+class AOQuantizationConfig:
+    """
+    Configuration for torchao quantization with optional filter function.
+
+    Attributes:
+        ao_base_config: The AOBaseConfig for quantization
+        filter_fn: Optional filter function to selectively apply quantization
+    """
+
+    ao_base_config: AOBaseConfig
+    filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None
+
+
 @dataclass
 class QuantizationRecipe:
     """
@@ -73,11 +89,12 @@ class QuantizationRecipe:

     Attributes:
         quantizers: Optional list of quantizers for model quantization
-        ao_base_config: Optional list of AO base configurations
+        ao_quantization_configs: Optional list of AOQuantizationConfig objects that pair
+            AOBaseConfig with optional filter functions
     """

     quantizers: Optional[List[Quantizer]] = None
-    ao_base_config: Optional[List[AOBaseConfig]] = None
+    ao_quantization_configs: Optional[List[AOQuantizationConfig]] = None

     def get_quantizers(self) -> Optional[List[Quantizer]]:
         """

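For context, a minimal usage sketch (not part of this commit) showing how the new config pairs an AOBaseConfig with a filter function so that quantize_ only touches selected submodules. Int8WeightOnlyConfig is assumed here as a representative torchao config; any AOBaseConfig should work the same way.

import torch

from executorch.export import AOQuantizationConfig, QuantizationRecipe
from torchao.quantization import Int8WeightOnlyConfig  # assumed representative config


def linear_only(module: torch.nn.Module, fqn: str) -> bool:
    # quantize_ calls the filter with each submodule and its fully
    # qualified name; return True only for modules to quantize.
    return isinstance(module, torch.nn.Linear)


recipe = QuantizationRecipe(
    ao_quantization_configs=[
        AOQuantizationConfig(
            ao_base_config=Int8WeightOnlyConfig(),
            filter_fn=linear_only,
        )
    ]
)
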
export/stages.py

Lines changed: 39 additions & 8 deletions
@@ -20,7 +20,6 @@
 from torch._export.pass_base import PassType
 from torchao.quantization import quantize_
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
-from torchao.quantization.pt2e.quantizer import ComposableQuantizer
 from torchao.utils import unwrap_tensor_subclass


@@ -287,7 +286,7 @@ def run(self, artifact: PipelineArtifact) -> None:
         """
         if (
             not self._quantization_recipe
-            or not self._quantization_recipe.ao_base_config
+            or not self._quantization_recipe.ao_quantization_configs
         ):
             logging.info(
                 "Quantization recipe is invalid to run SourceTransform, returning original artifact"
@@ -303,10 +302,11 @@ def run(self, artifact: PipelineArtifact) -> None:
         # Apply torchao quantize_ to each model
         for method_name, model in artifact.data.items():
             # pyre-ignore
-            for config in self._quantization_recipe.ao_base_config:
-                quantize_(model, config)
+            for ao_config in self._quantization_recipe.ao_quantization_configs:
+                quantize_(model, ao_config.ao_base_config, ao_config.filter_fn)
                 unwrap_tensor_subclass(model)
-            self._transformed_models[method_name] = model
+
+            self._transformed_models[method_name] = model

         self._artifact = artifact.copy_with_new_data(self._transformed_models)

@@ -331,6 +331,38 @@ def valid_predecessor_stages(self) -> List["StageType"]:
     def can_start_pipeline(self) -> bool:
         return True

+    def _get_quantizer_for_prepare_pt2e(self, quantizers: List[Any]):
+        torch_ao_quantizers = []
+        torchao_pt2e_quantizers = []
+
+        for quantizer in quantizers:
+            from torchao.quantization.pt2e.quantizer import (
+                Quantizer as TorchAOPT2EQuantizer,
+            )
+
+            if isinstance(quantizer, TorchAOPT2EQuantizer):
+                torchao_pt2e_quantizers.append(quantizer)
+            else:
+                torch_ao_quantizers.append(quantizer)
+
+        if torch_ao_quantizers and torchao_pt2e_quantizers:
+            raise ValueError("Mixed quantizer types are not supported")
+        if len(torch_ao_quantizers) > 1:
+            raise ValueError(
+                "Multiple quantizers of torch.ao.quantization.quantizer not supported"
+            )
+
+        if torch_ao_quantizers:
+            # prepare_pt2e has backward compat with torch.ao quantizer
+            return torch_ao_quantizers[0]
+        elif torchao_pt2e_quantizers:
+            # Multiple torchao quantizers - use ComposableQuantizer
+            from torchao.quantization.pt2e.quantizer import ComposableQuantizer
+
+            return ComposableQuantizer(torchao_pt2e_quantizers)
+        else:
+            raise ValueError("No quantizers detected")
+
     def run(self, artifact: PipelineArtifact) -> None:
         if not self._quantization_recipe or not self._quantization_recipe.quantizers:
             logging.info(
@@ -355,11 +387,10 @@ def run(self, artifact: PipelineArtifact) -> None:
             inputs = example_inputs[method_name][0]
             captured_graph = torch.export.export(model, inputs, strict=True).module()

-            composed_quantizer = ComposableQuantizer(
-                # pyre-ignore
+            quantizer = self._get_quantizer_for_prepare_pt2e(
                 self._quantization_recipe.quantizers
             )
-            prepared_model = prepare_pt2e(captured_graph, composed_quantizer)
+            prepared_model = prepare_pt2e(captured_graph, quantizer)

             for calibration_input in example_inputs[method_name]:
                 prepared_model(*calibration_input)
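
To illustrate what the new helper means for callers, a rough sketch (the XNNPACKQuantizer import path is an assumption, used only as a plausible torchao pt2e quantizer): one or more torchao pt2e quantizers are composed via ComposableQuantizer before prepare_pt2e, a single legacy torch.ao quantizer passes through unchanged, and mixing the two families, or passing more than one legacy quantizer, raises ValueError.

from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (  # assumed path
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)
from executorch.export import QuantizationRecipe

quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config())

# A torchao pt2e quantizer is used as-is (two or more would be composed).
# Adding a legacy torch.ao quantizer to this list would raise ValueError.
recipe = QuantizationRecipe(quantizers=[quantizer])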

export/tests/test_export_session.py

Lines changed: 7 additions & 3 deletions
@@ -12,15 +12,19 @@

 import torch
 from executorch.export import ExportRecipe, ExportSession
-from executorch.export.recipe import LoweringRecipe, QuantizationRecipe
+from executorch.export.recipe import (
+    AOQuantizationConfig,
+    LoweringRecipe,
+    QuantizationRecipe,
+)
 from executorch.export.stages import PipelineArtifact
 from executorch.export.types import StageType


 class SimpleTestModel(torch.nn.Module):
     def __init__(self) -> None:
         super().__init__()
-        self.linear = torch.nn.Linear(10, 5)
+        self.linear: torch.nn.Module = torch.nn.Linear(10, 5)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.linear(x)
@@ -448,7 +452,7 @@ def test_pipeline_building_with_all_recipes(self) -> None:
         """Test pipeline building with quantization and lowering recipes."""
         # Create comprehensive recipes
         quant_recipe = QuantizationRecipe(
-            ao_base_config=[Mock()],
+            ao_quantization_configs=[AOQuantizationConfig(Mock())],
             quantizers=[Mock()],
         )
         lowering_recipe = LoweringRecipe(

export/tests/test_export_stages.py

Lines changed: 14 additions & 7 deletions
@@ -11,7 +11,7 @@

 import torch
 from executorch.exir.program import EdgeProgramManager, ExecutorchProgramManager
-from executorch.export import QuantizationRecipe
+from executorch.export import AOQuantizationConfig, QuantizationRecipe
 from executorch.export.stages import (
     EdgeTransformAndLowerStage,
     ExecutorchStage,
@@ -29,7 +29,7 @@
 class SimpleTestModel(torch.nn.Module):
     def __init__(self) -> None:
         super().__init__()
-        self.linear = torch.nn.Linear(10, 5)
+        self.linear: torch.nn.Module = torch.nn.Linear(10, 5)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.linear(x)
@@ -163,7 +163,7 @@ def setUp(self) -> None:

     def test_source_transform_stage_no_quantization(self) -> None:
         mock_recipe = Mock(spec=QuantizationRecipe)
-        mock_recipe.ao_base_config = None
+        mock_recipe.ao_quantization_configs = None
         stage = SourceTransformStage(mock_recipe)
         artifact = PipelineArtifact(data=self.models_dict, context={})

@@ -174,12 +174,19 @@ def test_source_transform_stage_no_quantization(self) -> None:

     @patch("executorch.export.stages.quantize_")
     @patch("executorch.export.stages.unwrap_tensor_subclass")
-    def test_run_with_ao_base_config(
+    def test_run_with_ao_quantization_configs(
         self, mock_unwrap: Mock, mock_quantize: Mock
     ) -> None:
-        mock_config = Mock()
+        from torchao.core.config import AOBaseConfig
+
+        mock_config = Mock(spec=AOBaseConfig)
+        mock_filter_fn = Mock()
+        # pyre-ignore[28]: Unexpected keyword argument error is a false positive for dataclass
+        mock_ao_config: AOQuantizationConfig = AOQuantizationConfig(
+            ao_base_config=mock_config, filter_fn=mock_filter_fn
+        )
         mock_recipe = Mock(spec=QuantizationRecipe)
-        mock_recipe.ao_base_config = [mock_config]
+        mock_recipe.ao_quantization_configs = [mock_ao_config]

         stage = SourceTransformStage(mock_recipe)

@@ -188,7 +195,7 @@ def test_run_with_ao_base_config(
         stage.run(artifact)

         # Verify quantize_ was called with the model and config
-        mock_quantize.assert_called_once_with(self.model, mock_config)
+        mock_quantize.assert_called_once_with(self.model, mock_config, mock_filter_fn)

         # Verify unwrap_tensor_subclass was called with the model
         mock_unwrap.assert_called_once_with(self.model)
