12 changes: 12 additions & 0 deletions backends/cadence/aot/TARGETS
@@ -631,3 +631,15 @@ python_unittest(
         "//caffe2:torch",
     ]
 )
+
+python_unittest(
+    name = "test_quantizer_ops",
+    srcs = [
+        "tests/test_quantizer_ops.py",
+    ],
+    typing = True,
+    deps = [
+        "//caffe2:torch",
+        "//executorch/backends/cadence/aot/quantizer:quantizer",
+    ],
+)
34 changes: 9 additions & 25 deletions backends/cadence/aot/compiler.py
@@ -24,7 +24,6 @@
 from executorch.backends.cadence.aot.quantizer.quantizer import (
     CadenceDefaultQuantizer,
     CadenceQuantizer,
-    CadenceW8A32MixedQuantizer,
 )
 from executorch.backends.cadence.aot.utils import (
     get_default_memory_config,
@@ -51,36 +50,17 @@
 default_quantizer = CadenceDefaultQuantizer()
 
 
-# Note: this is not meant as a primary API since it can create inconsistencies
-# if the quantizer here is different from the quantizer used to convert. It is
-# however useful for unit tests to separate the converted model from the fused
-# model, to be able to get reference numerics.
-# If this does not apply, please use quantize_pt2 instead.
 def trace(
     model: torch.nn.Module,
     inputs: tuple[object, ...],
     dump_graphs: bool = False,
-    quantizer: Optional[CadenceQuantizer] = None,
+    ops_to_keep: Optional[list[torch._ops.OpOverload]] = None,
 ) -> ExportedProgram:
     """
     Trace the model with export and return an ExportedProgram.
     """
-
-    ops_to_keep = [
-        torch.ops.aten.conv1d.default,
-        torch.ops.aten.conv2d.default,
-        torch.ops.aten.layer_norm.default,
-        torch.ops.aten.linear.default,
-        torch.ops.aten.matmul.default,
-        torch.ops.aten.rms_norm.default,
-    ]
-
-    if isinstance(quantizer, CadenceW8A32MixedQuantizer):
-        ops_to_keep += [
-            torch.ops.aten.gru.input,
-            torch.ops.aten.gru.data,
-        ]
-
+    if ops_to_keep is None:
+        ops_to_keep = []
     program = trace_fn(
         model, inputs, is_qat=False, strict=True, ops_to_keep=ops_to_keep
     )
@@ -107,7 +87,10 @@ def prepare_pt2(
     Returns a GraphModule with the prepared model.
     """
 
-    traced_program = trace(model, inputs, dump_graphs=dump_graphs, quantizer=quantizer)
+    ops_to_keep = quantizer.get_ops_to_preserve_from_decomposition()
+    traced_program = trace(
+        model, inputs, dump_graphs=dump_graphs, ops_to_keep=ops_to_keep
+    )
     prepared_program = prepare_traced_pt2(
         traced_program, quantizer, dump_graphs=dump_graphs
     )
@@ -192,7 +175,8 @@ def get_fake_quant_model(
     # Make the model inference mode by calling model.eval()
     model.eval()
 
-    program = trace(model, inputs, dump_graphs=dump_graphs, quantizer=quantizer)
+    ops_to_keep = quantizer.get_ops_to_preserve_from_decomposition()
+    program = trace(model, inputs, dump_graphs=dump_graphs, ops_to_keep=ops_to_keep)
 
     if dump_graphs:
         logging.info("Graph after trace:")
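With this change, callers derive ops_to_keep from the quantizer rather than passing the quantizer into trace(). A minimal usage sketch, assuming the post-change signatures above (the toy module and shapes are illustrative, not part of this PR):

import torch

from executorch.backends.cadence.aot.compiler import trace
from executorch.backends.cadence.aot.quantizer.quantizer import CadenceDefaultQuantizer


class TinyModel(torch.nn.Module):  # illustrative module, not from this PR
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


model = TinyModel().eval()
inputs = (torch.randn(1, 4),)

# The quantizer now declares which aten ops must survive decomposition;
# trace() no longer hardcodes the list or special-cases quantizer types.
quantizer = CadenceDefaultQuantizer()
ops_to_keep = quantizer.get_ops_to_preserve_from_decomposition()
program = trace(model, inputs, ops_to_keep=ops_to_keep)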
19 changes: 18 additions & 1 deletion backends/cadence/aot/quantizer/quantizer.py
@@ -7,7 +7,7 @@
 # pyre-strict
 
 from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import final, List, Optional, Tuple, Union
 
 import torch
 from executorch.backends.cadence.aot.quantizer.patterns import (
@@ -244,6 +244,23 @@ class for explicitly defined quantizers (like CadenceDefaultQuantizer).
     def __init__(self, quantizers: List[Quantizer]) -> None:
         super().__init__(quantizers)
 
+    @final
+    def get_ops_to_preserve_from_decomposition(self) -> List[torch._ops.OpOverload]:
+        """
+        Get the complete list of ops to preserve from decomposition.
+
+        Delegates preservation choices to QuantizationPattern by aggregating
+        each pattern's partition_types(), which explicitly declares the root
+        ops that compose the pattern and should be preserved.
+        """
+        ops: set[torch._ops.OpOverload] = set()
+        for q in self.quantizers:
+            if isinstance(q, CadenceAtenQuantizer):
+                ops.update(q.pattern.partition_types())
+            elif isinstance(q, CadenceQuantizer):
+                ops.update(q.get_ops_to_preserve_from_decomposition())
+        return list(ops)
+
 
 class CadenceDefaultQuantizer(CadenceQuantizer):
     """
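A short sketch of what the new aggregation yields per quantizer (the expected op sets mirror the tests below); this is what lets compiler.py drop its isinstance(CadenceW8A32MixedQuantizer) branch:

import torch

from executorch.backends.cadence.aot.quantizer.quantizer import (
    CadenceDefaultQuantizer,
    CadenceW8A32MixedQuantizer,
)

# Each quantizer now carries its own preservation list, derived from its
# patterns' partition_types(), so trace() needs no per-type branching.
default_ops = set(CadenceDefaultQuantizer().get_ops_to_preserve_from_decomposition())
mixed_ops = set(CadenceW8A32MixedQuantizer().get_ops_to_preserve_from_decomposition())

# Only the W8A32 mixed quantizer asks to preserve the GRU op.
assert torch.ops.aten.gru.input in mixed_ops
assert torch.ops.aten.gru.input not in default_ops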
73 changes: 73 additions & 0 deletions backends/cadence/aot/tests/test_quantizer_ops.py
@@ -0,0 +1,73 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+import unittest
+
+import torch
+from executorch.backends.cadence.aot.quantizer.patterns import AddmmPattern
+
+from executorch.backends.cadence.aot.quantizer.quantizer import (
+    CadenceAtenQuantizer,
+    CadenceDefaultQuantizer,
+    CadenceW8A32MixedQuantizer,
+    qconfig_A8W8,
+)
+
+
+class QuantizerOpsPreserveTest(unittest.TestCase):
+    def test_mixed_w8a32_ops_to_preserve(self) -> None:
+        q = CadenceW8A32MixedQuantizer()
+        actual = q.get_ops_to_preserve_from_decomposition()
+        expected = [
+            torch.ops.aten.linear.default,
+            torch.ops.aten.conv1d.default,
+            torch.ops.aten.gru.input,
+        ]
+        self.assertCountEqual(actual, expected)
+
+    def test_default_quantizer_ops_to_preserve(self) -> None:
+        q = CadenceDefaultQuantizer()
+        actual = q.get_ops_to_preserve_from_decomposition()
+        expected = [
+            torch.ops.aten.addmm.default,
+            torch.ops.aten.bmm.default,
+            torch.ops.aten.conv1d.default,
+            torch.ops.aten.conv2d.default,
+            torch.ops.aten.linear.default,
+            torch.ops.aten.matmul.default,
+            torch.ops.aten.relu.default,
+            torch.ops.aten.relu_.default,
+        ]
+        self.assertCountEqual(actual, expected)
+
+    def test_nested_quantizer_ops_to_preserve(self) -> None:
+        # Setup: Create a nested CadenceQuantizer-like structure by composing
+        # - CadenceW8A32MixedQuantizer (which preserves linear, conv1d, gru.input)
+        # - A CadenceAtenQuantizer with AddmmPattern (which preserves addmm)
+        nested = CadenceDefaultQuantizer(
+            quantizers=[
+                CadenceW8A32MixedQuantizer(),
+                CadenceAtenQuantizer(AddmmPattern(), qconfig_A8W8),
+            ]
+        )
+
+        # Execute
+        actual = nested.get_ops_to_preserve_from_decomposition()
+
+        # Assert: union of both sets without duplicates
+        expected = [
+            torch.ops.aten.linear.default,
+            torch.ops.aten.conv1d.default,
+            torch.ops.aten.gru.input,
+            torch.ops.aten.addmm.default,
+        ]
+        self.assertCountEqual(actual, expected)
+
+
+if __name__ == "__main__":
+    unittest.main()