From a1ede0e256c4e38d274d8e111374c2e84f8c0f7e Mon Sep 17 00:00:00 2001
From: Andrew Grebenisan
Date: Thu, 9 Oct 2025 14:19:06 -0700
Subject: [PATCH] Pass which replaces torch quantized embedding byte with cadence variant (#14906)

Summary: As titled

Reviewed By: mcremon-meta, zonglinpeng

Differential Revision: D84109801
---
 backends/cadence/aot/replace_ops.py      | 47 +++++++++++++++++++
 .../aot/tests/test_replace_ops_passes.py | 46 ++++++++++++++++++
 2 files changed, 93 insertions(+)

diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py
index 7025159e443..3cfc059e75b 100644
--- a/backends/cadence/aot/replace_ops.py
+++ b/backends/cadence/aot/replace_ops.py
@@ -2156,6 +2156,52 @@ def call_operator(self, op, args, kwargs, meta):
         )
 
 
+@register_cadence_pass(CadencePassAttribute(opt_level=0))
+class ReplaceTorchQuantizedEmbeddingWithCadenceQuantizedEmbedding(ExportPass):
+    """
+    Replace quantized_decomposed.embedding_byte (.default and .dtype overloads)
+    with cadence.quantized_embedding_byte.
+    """
+
+    def call_operator(
+        self,
+        op: torch._ops.OpOverload,
+        args: Tuple[Argument, ...],
+        kwargs: Dict[str, Argument],
+        meta: NodeMetadata,
+    ) -> ProxyValue:
+        # Match either overload of quantized_decomposed.embedding_byte
+        if (
+            op == exir_ops.edge.quantized_decomposed.embedding_byte.default
+            or op == exir_ops.edge.quantized_decomposed.embedding_byte.dtype
+        ):
+            # Rewrite to the cadence op; args[3] and args[4] are not forwarded
+            if len(args) < 6:
+                raise AssertionError(
+                    f"Expected 6 arguments for embedding_byte, got {len(args)}"
+                )
+            embedding = args[0]
+            scales = args[1]
+            weight_zero_points = args[2]
+            indices = args[5]
+            if op == exir_ops.edge.quantized_decomposed.embedding_byte.dtype:
+                dtype = kwargs.get("dtype", None)
+                if dtype is not None and dtype != torch.float32:
+                    raise AssertionError(
+                        f"Unsupported output dtype for embedding_byte: {dtype}"
+                    )
+
+            new_args = (embedding, scales, weight_zero_points, indices, False)
+            new_kwargs = {}
+            return super().call_operator(
+                exir_ops.edge.cadence.quantized_embedding_byte.default,
+                new_args,
+                new_kwargs,
+                meta,
+            )
+        return super().call_operator(op, args, kwargs, meta)
+
+
 class CommonReplacePasses:
     passes = [
         ReplaceSqueezeAndUnsqueezeWithViewPass,
@@ -2168,6 +2214,7 @@ class CommonReplacePasses:
         ReplacePT2QuantWithCadenceQuantPass,
         ReplacePT2DequantWithCadenceDequantPass,
         ReplacePowWithMulPass,
+        ReplaceTorchQuantizedEmbeddingWithCadenceQuantizedEmbedding,
     ]
 
 
diff --git a/backends/cadence/aot/tests/test_replace_ops_passes.py b/backends/cadence/aot/tests/test_replace_ops_passes.py
index c15755f58c5..e2fbd516757 100644
--- a/backends/cadence/aot/tests/test_replace_ops_passes.py
+++ b/backends/cadence/aot/tests/test_replace_ops_passes.py
@@ -45,6 +45,7 @@
     ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass,
     ReplaceSplitWithSlicePass,
     ReplaceSqueezeAndUnsqueezeWithViewPass,
+    ReplaceTorchQuantizedEmbeddingWithCadenceQuantizedEmbedding,
     ReplaceTransposedConvWithLinearPass,
     ReplaceTrivialConvWithLinear,
     ReplaceWhereWithFullArgsWithWhereScalar,
@@ -2269,3 +2270,48 @@ def test_replace_aten_linalg_svd_with_cadence_linalg_svd(
             count_node(graph_after_passes, exir_ops.edge.cadence.linalg_svd.default),
             1,
         )
+
+    @expand([("dtype",), ("default",)])
+    @torch.no_grad()
+    def test_replace_quantized_embedding(
+        self,
+        name: str,
+    ) -> None:
+        embedding = torch.ones(5, 6, dtype=torch.int8)
+        indices = torch.tensor([0, 2], dtype=torch.int32)
+        scales = torch.ones(5, 2, dtype=torch.float32)
+        zero_points = None
+
+        original_gm = single_op_builder(
+            placeholders=(embedding, scales, indices),
+            op=(
+                exir_ops.edge.quantized_decomposed.embedding_byte.dtype
+                if name == "dtype"
+                else exir_ops.edge.quantized_decomposed.embedding_byte.default
+            ),
+            args=(embedding, scales, zero_points, -128, 127, indices),
+            kwargs={"dtype": torch.float32} if name == "dtype" else {},
+        )
+
+        p = ReplaceTorchQuantizedEmbeddingWithCadenceQuantizedEmbedding()
+        graph_after_passes = cast(PassResult, p(original_gm)).graph_module
+
+        self.assertEqual(
+            count_node(
+                graph_after_passes,
+                (
+                    exir_ops.edge.quantized_decomposed.embedding_byte.dtype
+                    if name == "dtype"
+                    else exir_ops.edge.quantized_decomposed.embedding_byte.default
+                ),
+            ),
+            0,
+        )
+
+        self.assertEqual(
+            count_node(
+                graph_after_passes,
+                exir_ops.edge.cadence.quantized_embedding_byte.default,
+            ),
+            1,
+        )