12 changes: 12 additions & 0 deletions backends/cadence/aot/TARGETS
@@ -631,3 +631,15 @@ python_unittest(
"//caffe2:torch",
]
)

python_unittest(
    name = "test_quantizer_ops",
    srcs = [
        "tests/test_quantizer_ops.py",
    ],
    typing = True,
    deps = [
        "//caffe2:torch",
        "//executorch/backends/cadence/aot/quantizer:quantizer",
    ],
)
18 changes: 3 additions & 15 deletions backends/cadence/aot/compiler.py
@@ -65,22 +65,10 @@ def trace(
"""
Trace the model with export and return an ExportedProgram.
"""
if quantizer is None:
quantizer = default_quantizer

ops_to_keep = [
torch.ops.aten.conv1d.default,
torch.ops.aten.conv2d.default,
torch.ops.aten.layer_norm.default,
torch.ops.aten.linear.default,
torch.ops.aten.matmul.default,
torch.ops.aten.rms_norm.default,
]

if isinstance(quantizer, CadenceW8A32MixedQuantizer):
ops_to_keep += [
torch.ops.aten.gru.input,
torch.ops.aten.gru.data,
]

ops_to_keep = quantizer.get_ops_to_preserve_from_decomposition()
program = trace_fn(
model, inputs, is_qat=False, strict=True, ops_to_keep=ops_to_keep
)
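For context on the compiler.py change above: the hardcoded ops_to_keep list and the CadenceW8A32MixedQuantizer special case are removed from trace(), which now asks the quantizer for the ops it needs preserved. A minimal sketch of the resulting flow (illustrative only, not part of this diff; the default-quantizer fallback and the trace_fn signature are assumed from the context shown above):

if quantizer is None:
    quantizer = default_quantizer

# The quantizer itself reports which aten ops must survive decomposition.
ops_to_keep = quantizer.get_ops_to_preserve_from_decomposition()
program = trace_fn(
    model, inputs, is_qat=False, strict=True, ops_to_keep=ops_to_keep
)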
52 changes: 51 additions & 1 deletion backends/cadence/aot/quantizer/quantizer.py
@@ -7,7 +7,8 @@
# pyre-strict

from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
from typing import List, Optional, Tuple, Union, final


import torch
from executorch.backends.cadence.aot.quantizer.patterns import (
@@ -233,6 +234,17 @@ def get_cadence_default_quantizers() -> List[Quantizer]:
    ]


def get_cadence_default_ops() -> List[torch._ops.OpOverload]:
    return [
        torch.ops.aten.conv1d.default,
        torch.ops.aten.conv2d.default,
        torch.ops.aten.layer_norm.default,
        torch.ops.aten.linear.default,
        torch.ops.aten.matmul.default,
        torch.ops.aten.rms_norm.default,
    ]


# Note: need dataclass to be used in CI configs through OmegaConf and Hydra
@dataclass
class CadenceQuantizer(ComposableQuantizer):
@@ -244,6 +256,38 @@ class for explicitly defined quantizers (like CadenceDefaultQuantizer).
    def __init__(self, quantizers: List[Quantizer]) -> None:
        super().__init__(quantizers)

    # Class-level additive configuration: subclasses may contribute ops via this attribute
    ADDITIONAL_OPS_TO_PRESERVE: Tuple[torch._ops.OpOverload, ...] = ()

    @classmethod
    def _collect_additional_ops(cls) -> List[torch._ops.OpOverload]:
        """
        Union all ADDITIONAL_OPS_TO_PRESERVE across the class hierarchy (MRO).
        Ensures additive inheritance.
        """
        ops: set[torch._ops.OpOverload] = set()
        for klass in cls.__mro__:
            attr = getattr(klass, "ADDITIONAL_OPS_TO_PRESERVE", ())
            # Support tuple/list definitions
            ops.update(attr)
        return list(ops)

    @final
    def get_ops_to_preserve_from_decomposition(self) -> List[torch._ops.OpOverload]:
        """
        Get complete list of ops to preserve from decomposition.

        Combines base Cadence ops with quantizer-specific additional ops aggregated
        across the inheritance chain.

        Returns:
            Deduplicated list of all ops to preserve
        """
        base_ops = get_cadence_default_ops()

        additional_ops = type(self)._collect_additional_ops()
        return list(set(base_ops) | set(additional_ops))


class CadenceDefaultQuantizer(CadenceQuantizer):
"""
@@ -331,6 +375,12 @@ def __init__(self) -> None:
        )
        super().__init__(quantizers)

    # Additional ops contributed by this quantizer
    ADDITIONAL_OPS_TO_PRESERVE: Tuple[torch._ops.OpOverload, ...] = (
        torch.ops.aten.gru.input,
        torch.ops.aten.gru.data,
    )


class CadenceWithSoftmaxQuantizer(CadenceQuantizer):
"""
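For context on the quantizer.py change above: ADDITIONAL_OPS_TO_PRESERVE is aggregated additively across the MRO, so a subclass declares only the ops it adds on top of its parents. A minimal sketch (illustrative only, not part of this diff; HypotheticalQuantizer is a made-up name):

import torch

from executorch.backends.cadence.aot.quantizer.quantizer import (
    CadenceW8A32MixedQuantizer,
)


class HypotheticalQuantizer(CadenceW8A32MixedQuantizer):
    # Unioned (via the MRO walk in _collect_additional_ops) with the GRU ops
    # declared by CadenceW8A32MixedQuantizer and the base Cadence ops.
    ADDITIONAL_OPS_TO_PRESERVE = (torch.ops.aten.batch_norm.default,)


ops = HypotheticalQuantizer().get_ops_to_preserve_from_decomposition()
assert torch.ops.aten.conv2d.default in ops      # from get_cadence_default_ops()
assert torch.ops.aten.gru.input in ops           # inherited from the parent class
assert torch.ops.aten.batch_norm.default in ops  # contributed by this subclass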
66 changes: 66 additions & 0 deletions backends/cadence/aot/tests/test_quantizer_ops.py
@@ -0,0 +1,66 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import unittest

import torch

from executorch.backends.cadence.aot.quantizer.quantizer import (
    CadenceDefaultQuantizer,
    CadenceW8A32MixedQuantizer,
    get_cadence_default_ops,
)


class DerivedMixedQuantizer(CadenceW8A32MixedQuantizer):
"""
Test-only subclass to validate MRO aggregation:
contributes one additional op beyond CadenceW8A32MixedQuantizer.
"""

ADDITIONAL_OPS_TO_PRESERVE: tuple[torch._ops.OpOverload, ...] = (
torch.ops.aten.batch_norm.default,
)


class QuantizerOpsPreserveTest(unittest.TestCase):
    def test_mixed_w8a32_ops_to_preserve(self) -> None:
        q = CadenceW8A32MixedQuantizer()
        actual = q.get_ops_to_preserve_from_decomposition()
        expected = get_cadence_default_ops()
        expected += [
            torch.ops.aten.gru.input,
            torch.ops.aten.gru.data,
        ]
        self.assertCountEqual(actual, expected)

    def test_default_quantizer_ops_to_preserve(self) -> None:
        q = CadenceDefaultQuantizer()
        actual = q.get_ops_to_preserve_from_decomposition()
        expected = get_cadence_default_ops()
        self.assertCountEqual(actual, expected)

    def test_mro_aggregation_includes_subclass_ops(self) -> None:
        """
        Validate MRO aggregation: DerivedMixedQuantizer should include
        base Cadence ops, GRU ops from CadenceW8A32MixedQuantizer, and
        the subclass-contributed batch_norm op.
        """
        q = DerivedMixedQuantizer()
        actual = q.get_ops_to_preserve_from_decomposition()
        expected = get_cadence_default_ops()
        expected += [
            torch.ops.aten.gru.input,
            torch.ops.aten.gru.data,
            torch.ops.aten.batch_norm.default,
        ]
        self.assertCountEqual(actual, expected)


if __name__ == "__main__":
    unittest.main()