diff --git a/backends/cadence/aot/TARGETS b/backends/cadence/aot/TARGETS
index b0e7101c9d2..527bd6646a1 100644
--- a/backends/cadence/aot/TARGETS
+++ b/backends/cadence/aot/TARGETS
@@ -631,3 +631,15 @@ python_unittest(
         "//caffe2:torch",
     ]
 )
+
+python_unittest(
+    name = "test_quantizer_ops",
+    srcs = [
+        "tests/test_quantizer_ops.py",
+    ],
+    typing = True,
+    deps = [
+        "//caffe2:torch",
+        "//executorch/backends/cadence/aot/quantizer:quantizer",
+    ],
+)
diff --git a/backends/cadence/aot/compiler.py b/backends/cadence/aot/compiler.py
index 91abd94b89f..9c454f4339f 100644
--- a/backends/cadence/aot/compiler.py
+++ b/backends/cadence/aot/compiler.py
@@ -24,7 +24,6 @@
 from executorch.backends.cadence.aot.quantizer.quantizer import (
     CadenceDefaultQuantizer,
     CadenceQuantizer,
-    CadenceW8A32MixedQuantizer,
 )
 from executorch.backends.cadence.aot.utils import (
     get_default_memory_config,
@@ -51,36 +50,17 @@
 default_quantizer = CadenceDefaultQuantizer()
 
 
-# Note: this is not meant as a primary API since it can create inconsistencies
-# if the quantizer here is different from the quantizer used to convert. It is
-# however useful for unit tests to separate the converted model from the fused
-# model, to be able to get reference numerics.
-# If this does not apply, please use quantize_pt2 instead.
 def trace(
     model: torch.nn.Module,
     inputs: tuple[object, ...],
     dump_graphs: bool = False,
-    quantizer: Optional[CadenceQuantizer] = None,
+    ops_to_keep: Optional[list[torch._ops.OpOverload]] = None,
 ) -> ExportedProgram:
     """
     Trace the model with export and return an ExportedProgram.
     """
-
-    ops_to_keep = [
-        torch.ops.aten.conv1d.default,
-        torch.ops.aten.conv2d.default,
-        torch.ops.aten.layer_norm.default,
-        torch.ops.aten.linear.default,
-        torch.ops.aten.matmul.default,
-        torch.ops.aten.rms_norm.default,
-    ]
-
-    if isinstance(quantizer, CadenceW8A32MixedQuantizer):
-        ops_to_keep += [
-            torch.ops.aten.gru.input,
-            torch.ops.aten.gru.data,
-        ]
-
+    if ops_to_keep is None:
+        ops_to_keep = []
     program = trace_fn(
         model, inputs, is_qat=False, strict=True, ops_to_keep=ops_to_keep
     )
@@ -107,7 +87,10 @@ def prepare_pt2(
 
     Returns a GraphModule with the prepared model.
     """
-    traced_program = trace(model, inputs, dump_graphs=dump_graphs, quantizer=quantizer)
+    ops_to_keep = quantizer.get_ops_to_preserve_from_decomposition()
+    traced_program = trace(
+        model, inputs, dump_graphs=dump_graphs, ops_to_keep=ops_to_keep
+    )
     prepared_program = prepare_traced_pt2(
         traced_program, quantizer, dump_graphs=dump_graphs
     )
@@ -192,7 +175,8 @@ def get_fake_quant_model(
     # Make the model inference mode by calling model.eval()
     model.eval()
 
-    program = trace(model, inputs, dump_graphs=dump_graphs, quantizer=quantizer)
+    ops_to_keep = quantizer.get_ops_to_preserve_from_decomposition()
+    program = trace(model, inputs, dump_graphs=dump_graphs, ops_to_keep=ops_to_keep)
 
     if dump_graphs:
         logging.info("Graph after trace:")
diff --git a/backends/cadence/aot/quantizer/quantizer.py b/backends/cadence/aot/quantizer/quantizer.py
index 786b7d6cdf2..70b16b86fda 100644
--- a/backends/cadence/aot/quantizer/quantizer.py
+++ b/backends/cadence/aot/quantizer/quantizer.py
@@ -7,7 +7,7 @@
 # pyre-strict
 
 from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import final, List, Optional, Tuple, Union
 
 import torch
 from executorch.backends.cadence.aot.quantizer.patterns import (
@@ -244,6 +244,23 @@ class for explicitly defined quantizers (like CadenceDefaultQuantizer).
     def __init__(self, quantizers: List[Quantizer]) -> None:
         super().__init__(quantizers)
 
+    @final
+    def get_ops_to_preserve_from_decomposition(self) -> List[torch._ops.OpOverload]:
+        """
+        Get complete list of ops to preserve from decomposition.
+
+        Delegates preservation choices to QuantizationPattern by aggregating
+        the pattern's partition_types(), which explicitly declares the root
+        ops that compose the pattern and should be preserved.
+        """
+        ops: set[torch._ops.OpOverload] = set()
+        for q in self.quantizers:
+            if isinstance(q, CadenceAtenQuantizer):
+                ops.update(q.pattern.partition_types())
+            elif isinstance(q, CadenceQuantizer):
+                ops.update(q.get_ops_to_preserve_from_decomposition())
+        return list(ops)
+
 
 class CadenceDefaultQuantizer(CadenceQuantizer):
     """
diff --git a/backends/cadence/aot/tests/test_quantizer_ops.py b/backends/cadence/aot/tests/test_quantizer_ops.py
new file mode 100644
index 00000000000..f0df592558f
--- /dev/null
+++ b/backends/cadence/aot/tests/test_quantizer_ops.py
@@ -0,0 +1,73 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+import unittest
+
+import torch
+from executorch.backends.cadence.aot.quantizer.patterns import AddmmPattern
+
+from executorch.backends.cadence.aot.quantizer.quantizer import (
+    CadenceAtenQuantizer,
+    CadenceDefaultQuantizer,
+    CadenceW8A32MixedQuantizer,
+    qconfig_A8W8,
+)
+
+
+class QuantizerOpsPreserveTest(unittest.TestCase):
+    def test_mixed_w8a32_ops_to_preserve(self) -> None:
+        q = CadenceW8A32MixedQuantizer()
+        actual = q.get_ops_to_preserve_from_decomposition()
+        expected = [
+            torch.ops.aten.linear.default,
+            torch.ops.aten.conv1d.default,
+            torch.ops.aten.gru.input,
+        ]
+        self.assertCountEqual(actual, expected)
+
+    def test_default_quantizer_ops_to_preserve(self) -> None:
+        q = CadenceDefaultQuantizer()
+        actual = q.get_ops_to_preserve_from_decomposition()
+        expected = [
+            torch.ops.aten.addmm.default,
+            torch.ops.aten.bmm.default,
+            torch.ops.aten.conv1d.default,
+            torch.ops.aten.conv2d.default,
+            torch.ops.aten.linear.default,
+            torch.ops.aten.matmul.default,
+            torch.ops.aten.relu.default,
+            torch.ops.aten.relu_.default,
+        ]
+        self.assertCountEqual(actual, expected)
+
+    def test_nested_quantizer_ops_to_preserve(self) -> None:
+        # Setup: Create a nested CadenceQuantizer-like structure by composing
+        # - CadenceW8A32MixedQuantizer (which preserves linear, conv1d, gru.input)
+        # - A CadenceAtenQuantizer with AddmmPattern (which preserves addmm)
+        nested = CadenceDefaultQuantizer(
+            quantizers=[
+                CadenceW8A32MixedQuantizer(),
+                CadenceAtenQuantizer(AddmmPattern(), qconfig_A8W8),
+            ]
+        )
+
+        # Execute
+        actual = nested.get_ops_to_preserve_from_decomposition()
+
+        # Assert: union of both sets without duplicates
+        expected = [
+            torch.ops.aten.linear.default,
+            torch.ops.aten.conv1d.default,
+            torch.ops.aten.gru.input,
+            torch.ops.aten.addmm.default,
+        ]
+        self.assertCountEqual(actual, expected)
+
+
+if __name__ == "__main__":
+    unittest.main()
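Usage note: with this change, trace() no longer hardcodes the preserved-op list and no longer special-cases CadenceW8A32MixedQuantizer; the quantizer owns that policy via get_ops_to_preserve_from_decomposition(), and prepare_pt2()/get_fake_quant_model() thread it through. Below is a minimal sketch of the resulting call flow, assuming executorch with the Cadence AOT backend is importable; the toy model and inputs are hypothetical stand-ins, not part of this PR.

```python
import torch

from executorch.backends.cadence.aot.compiler import trace
from executorch.backends.cadence.aot.quantizer.quantizer import (
    CadenceDefaultQuantizer,
)

# Hypothetical toy model/inputs, used only to illustrate the call flow.
model = torch.nn.Linear(16, 8).eval()
inputs = (torch.randn(1, 16),)

# The quantizer now owns the preservation policy: it aggregates each
# pattern's partition_types(), recursing into nested CadenceQuantizers.
quantizer = CadenceDefaultQuantizer()
ops_to_keep = quantizer.get_ops_to_preserve_from_decomposition()

# trace() simply consumes whatever list it is given (defaulting to []).
program = trace(model, inputs, ops_to_keep=ops_to_keep)
```

Most callers never need to call trace() directly: prepare_pt2() and get_fake_quant_model() now derive ops_to_keep from their quantizer argument the same way internally.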