Commit 9156fff

Add RmsNormNopQuantizer and Pattern
Differential Revision: D88520820
Pull Request resolved: #16117
Parent: 8511b30

File tree: 2 files changed, 25 insertions(+), 2 deletions(-)
backends/cadence/aot/quantizer/patterns.py

Lines changed: 15 additions & 0 deletions
@@ -721,3 +721,18 @@ def __init__(self, args, meta):
 
     def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_w8a32_gru.default
+
+
+class RmsNormPattern(QuantizationPattern):
+    """Pattern that preserves rms_norm from decomposition without matching anything."""
+
+    def partition_types(self) -> list[torch._ops.OpOverload]:
+        return [torch.ops.aten.rms_norm.default]
+
+    def get_anchors(
+        self, gm: torch.fx.GraphModule, fused_partition: List[fx.GraphModule]
+    ) -> Tuple[PartitionAnchors, fx.Node]:
+        return PartitionAnchors(empty=True), None  # pyre-ignore[7]
+
+    def replacement_op(self) -> torch._ops.OpOverload:
+        return torch.ops.aten.rms_norm.default
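
For reference, aten.rms_norm.default (the op this nop pattern matches and keeps from being decomposed) scales its input by the reciprocal root-mean-square over the normalized dimensions. The short check below is plain PyTorch background, not part of this commit, and assumes a PyTorch build recent enough to expose torch.ops.aten.rms_norm (roughly 2.4+):

import torch

x = torch.randn(2, 8)
weight = torch.ones(8)
eps = 1e-6

# Reference RMSNorm: x * rsqrt(mean(x^2 over the last dim) + eps) * weight
ref = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * weight
out = torch.ops.aten.rms_norm.default(x, [8], weight, eps)
torch.testing.assert_close(out, ref)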

backends/cadence/aot/quantizer/quantizer.py

Lines changed: 10 additions & 2 deletions
@@ -30,16 +30,15 @@
     QuantizationPattern,
     ReluPattern0,
     ReluPattern1,
+    RmsNormPattern,
     SoftmaxPattern,
 )
 from executorch.backends.cadence.aot.quantizer.utils import (
     find_sequential_partitions_aten,
     is_annotated,
     no_outside_users,
 )
-
 from torch import fx
-
 from torchao.quantization.pt2e import HistogramObserver, MinMaxObserver
 from torchao.quantization.pt2e.quantizer import (
     ComposableQuantizer,
@@ -285,6 +284,15 @@ def __init__(
         super().__init__([])
 
 
+class CadenceRmsNormNopQuantizer(CadenceQuantizer):
+    """
+    Nop quantizer that preserves rms_norm from decomposition.
+    """
+
+    def __init__(self) -> None:
+        super().__init__([CadenceAtenQuantizer(RmsNormPattern(), qconfig_A8W8)])
+
+
 class CadenceWithLayerNormQuantizer(CadenceQuantizer):
     """
     Quantizer including layer norm
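
For context, a Cadence quantizer like this is normally handed to the pt2e prepare/convert flow. The sketch below is a usage assumption rather than part of this commit: the export entry point and the prepare_pt2e/convert_pt2e import path vary across PyTorch, torchao, and ExecuTorch versions, and the one-op Model module is made up for illustration.

# Hypothetical usage sketch (not part of this commit). Assumes a torchao build
# that ships the pt2e flow under torchao.quantization.pt2e.quantize_pt2e and a
# PyTorch that provides torch.export.export_for_training.
import torch
from executorch.backends.cadence.aot.quantizer.quantizer import (
    CadenceRmsNormNopQuantizer,
)
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e


class Model(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.rms_norm(x, (16,))


example = (torch.randn(1, 16),)
gm = torch.export.export_for_training(Model(), example).module()

quantizer = CadenceRmsNormNopQuantizer()
prepared = prepare_pt2e(gm, quantizer)   # RmsNormPattern matches rms_norm, but
prepared(*example)                       # with empty anchors no observers are added
converted = convert_pt2e(prepared)
# Per the pattern's docstring, the intent is that downstream Cadence passes use
# this annotation to keep aten.rms_norm.default from being decomposed.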
