
Commit 6a0833f

Choosing ops_to_preserve by delegating to pattern
Differential Revision: D84524714
Pull Request resolved: #15121
1 parent fc7d03b commit 6a0833f

File tree

4 files changed (+112, -26 lines)

backends/cadence/aot/TARGETS
backends/cadence/aot/compiler.py
backends/cadence/aot/quantizer/quantizer.py
backends/cadence/aot/tests/test_quantizer_ops.py

backends/cadence/aot/TARGETS

Lines changed: 12 additions & 0 deletions
@@ -631,3 +631,15 @@ python_unittest(
         "//caffe2:torch",
     ]
 )
+
+python_unittest(
+    name = "test_quantizer_ops",
+    srcs = [
+        "tests/test_quantizer_ops.py",
+    ],
+    typing = True,
+    deps = [
+        "//caffe2:torch",
+        "//executorch/backends/cadence/aot/quantizer:quantizer",
+    ],
+)

backends/cadence/aot/compiler.py

Lines changed: 9 additions & 25 deletions
@@ -24,7 +24,6 @@
 from executorch.backends.cadence.aot.quantizer.quantizer import (
     CadenceDefaultQuantizer,
     CadenceQuantizer,
-    CadenceW8A32MixedQuantizer,
 )
 from executorch.backends.cadence.aot.utils import (
     get_default_memory_config,
@@ -51,36 +50,17 @@
 default_quantizer = CadenceDefaultQuantizer()


-# Note: this is not meant as a primary API since it can create inconsistencies
-# if the quantizer here is different from the quantizer used to convert. It is
-# however useful for unit tests to separate the converted model from the fused
-# model, to be able to get reference numerics.
-# If this does not apply, please use quantize_pt2 instead.
 def trace(
     model: torch.nn.Module,
     inputs: tuple[object, ...],
     dump_graphs: bool = False,
-    quantizer: Optional[CadenceQuantizer] = None,
+    ops_to_keep: Optional[list[torch._ops.OpOverload]] = None,
 ) -> ExportedProgram:
     """
     Trace the model with export and return an ExportedProgram.
     """
-
-    ops_to_keep = [
-        torch.ops.aten.conv1d.default,
-        torch.ops.aten.conv2d.default,
-        torch.ops.aten.layer_norm.default,
-        torch.ops.aten.linear.default,
-        torch.ops.aten.matmul.default,
-        torch.ops.aten.rms_norm.default,
-    ]
-
-    if isinstance(quantizer, CadenceW8A32MixedQuantizer):
-        ops_to_keep += [
-            torch.ops.aten.gru.input,
-            torch.ops.aten.gru.data,
-        ]
-
+    if ops_to_keep is None:
+        ops_to_keep = []
     program = trace_fn(
         model, inputs, is_qat=False, strict=True, ops_to_keep=ops_to_keep
     )
@@ -107,7 +87,10 @@ def prepare_pt2(
     Returns a GraphModule with the prepared model.
     """

-    traced_program = trace(model, inputs, dump_graphs=dump_graphs, quantizer=quantizer)
+    ops_to_keep = quantizer.get_ops_to_preserve_from_decomposition()
+    traced_program = trace(
+        model, inputs, dump_graphs=dump_graphs, ops_to_keep=ops_to_keep
+    )
     prepared_program = prepare_traced_pt2(
         traced_program, quantizer, dump_graphs=dump_graphs
     )
@@ -192,7 +175,8 @@ def get_fake_quant_model(
     # Make the model inference mode by calling model.eval()
     model.eval()

-    program = trace(model, inputs, dump_graphs=dump_graphs, quantizer=quantizer)
+    ops_to_keep = quantizer.get_ops_to_preserve_from_decomposition()
+    program = trace(model, inputs, dump_graphs=dump_graphs, ops_to_keep=ops_to_keep)

     if dump_graphs:
         logging.info("Graph after trace:")

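With the quantizer argument removed from trace(), callers choose the preserved ops explicitly. Below is a minimal sketch of the new calling convention, mirroring what prepare_pt2 and get_fake_quant_model now do in this diff; the import paths are inferred from the file layout, and model/example_inputs are placeholders:

# Sketch only: the quantizer, not trace(), now decides which aten ops
# survive decomposition. Import paths and placeholders are assumptions.
import torch

from executorch.backends.cadence.aot.compiler import trace
from executorch.backends.cadence.aot.quantizer.quantizer import CadenceDefaultQuantizer

model = torch.nn.Linear(16, 16).eval()  # placeholder model
example_inputs = (torch.randn(1, 16),)  # placeholder inputs

quantizer = CadenceDefaultQuantizer()
# Aggregated from each pattern's partition_types(); see quantizer.py below.
ops_to_keep = quantizer.get_ops_to_preserve_from_decomposition()
exported_program = trace(model, example_inputs, ops_to_keep=ops_to_keep)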
backends/cadence/aot/quantizer/quantizer.py

Lines changed: 18 additions & 1 deletion
@@ -7,7 +7,7 @@
 # pyre-strict

 from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import final, List, Optional, Tuple, Union

 import torch
 from executorch.backends.cadence.aot.quantizer.patterns import (
@@ -244,6 +244,23 @@ class for explicitly defined quantizers (like CadenceDefaultQuantizer).
     def __init__(self, quantizers: List[Quantizer]) -> None:
         super().__init__(quantizers)

+    @final
+    def get_ops_to_preserve_from_decomposition(self) -> List[torch._ops.OpOverload]:
+        """
+        Get complete list of ops to preserve from decomposition.
+
+        Delegates preservation choices to QuantizationPattern by aggregating
+        the pattern's partition_types(), which explicitly declares the root
+        ops that compose the pattern and should be preserved.
+        """
+        ops: set[torch._ops.OpOverload] = set()
+        for q in self.quantizers:
+            if isinstance(q, CadenceAtenQuantizer):
+                ops.update(q.pattern.partition_types())
+            elif isinstance(q, CadenceQuantizer):
+                ops.update(q.get_ops_to_preserve_from_decomposition())
+        return list(ops)
+

 class CadenceDefaultQuantizer(CadenceQuantizer):
     """
backends/cadence/aot/tests/test_quantizer_ops.py

Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+import unittest
+
+import torch
+from executorch.backends.cadence.aot.quantizer.patterns import AddmmPattern
+
+from executorch.backends.cadence.aot.quantizer.quantizer import (
+    CadenceAtenQuantizer,
+    CadenceDefaultQuantizer,
+    CadenceW8A32MixedQuantizer,
+    qconfig_A8W8,
+)
+
+
+class QuantizerOpsPreserveTest(unittest.TestCase):
+    def test_mixed_w8a32_ops_to_preserve(self) -> None:
+        q = CadenceW8A32MixedQuantizer()
+        actual = q.get_ops_to_preserve_from_decomposition()
+        expected = [
+            torch.ops.aten.linear.default,
+            torch.ops.aten.conv1d.default,
+            torch.ops.aten.gru.input,
+        ]
+        self.assertCountEqual(actual, expected)
+
+    def test_default_quantizer_ops_to_preserve(self) -> None:
+        q = CadenceDefaultQuantizer()
+        actual = q.get_ops_to_preserve_from_decomposition()
+        expected = [
+            torch.ops.aten.addmm.default,
+            torch.ops.aten.bmm.default,
+            torch.ops.aten.conv1d.default,
+            torch.ops.aten.conv2d.default,
+            torch.ops.aten.linear.default,
+            torch.ops.aten.matmul.default,
+            torch.ops.aten.relu.default,
+            torch.ops.aten.relu_.default,
+        ]
+        self.assertCountEqual(actual, expected)
+
+    def test_nested_quantizer_ops_to_preserve(self) -> None:
+        # Setup: Create a nested CadenceQuantizer-like structure by composing
+        # - CadenceW8A32MixedQuantizer (which preserves linear, conv1d, gru.input)
+        # - A CadenceAtenQuantizer with AddmmPattern (which preserves addmm)
+        nested = CadenceDefaultQuantizer(
+            quantizers=[
+                CadenceW8A32MixedQuantizer(),
+                CadenceAtenQuantizer(AddmmPattern(), qconfig_A8W8),
+            ]
+        )
+
+        # Execute
+        actual = nested.get_ops_to_preserve_from_decomposition()
+
+        # Assert: union of both sets without duplicates
+        expected = [
+            torch.ops.aten.linear.default,
+            torch.ops.aten.conv1d.default,
+            torch.ops.aten.gru.input,
+            torch.ops.aten.addmm.default,
+        ]
+        self.assertCountEqual(actual, expected)
+
+
+if __name__ == "__main__":
+    unittest.main()
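The suite is wired up through the test_quantizer_ops target added to TARGETS above; it should also run standalone via the __main__ guard (for example, python backends/cadence/aot/tests/test_quantizer_ops.py from the repository root, with the path assumed from the srcs entry).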
