Commit 9764269

Pass which replaces torch quantized embedding byte with cadence variant

Differential Revision: D84109801
Pull Request resolved: pytorch#14906
Parent: 66c3dea
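In short, the new pass rewrites both edge overloads of quantized_decomposed.embedding_byte into cadence.quantized_embedding_byte, keeping the weight, scales, zero-point, and indices arguments, dropping the quant min/max bounds, and appending a trailing False flag. A minimal sketch of that argument remapping (remap_embedding_byte_args is a hypothetical helper written for illustration; the positions mirror the pass, the names are illustrative):

    def remap_embedding_byte_args(args):
        # Hypothetical helper mirroring the pass: keep positions 0-2
        # (weight, scales, zero_points) and 5 (indices); drop the quant
        # min/max bounds at positions 3-4; append the literal False flag
        # the pass adds for the Cadence op.
        weight, scales, zero_points, _qmin, _qmax, indices = args[:6]
        return (weight, scales, zero_points, indices, False)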

File tree

backends/cadence/aot/replace_ops.py
backends/cadence/aot/tests/test_replace_ops_passes.py

2 files changed: +93 −0 lines changed

backends/cadence/aot/replace_ops.py

Lines changed: 47 additions & 0 deletions
@@ -2156,6 +2156,52 @@ def call_operator(self, op, args, kwargs, meta):
         )


+@register_cadence_pass(CadencePassAttribute(opt_level=0))
+class ReplaceTorchQuantizedEmbeddingWithCadenceQuantizedEmbedding(ExportPass):
+    """
+    Replace torch.ops.quantized_decomposed.embedding_byte (the default and
+    dtype overloads) with torch.ops.cadence.quantized_embedding_byte.
+    """
+
+    def call_operator(
+        self,
+        op: torch._ops.OpOverload,
+        args: Tuple[Argument, ...],
+        kwargs: Dict[str, Argument],
+        meta: NodeMetadata,
+    ) -> ProxyValue:
+        # Match both overloads of quantized_decomposed.embedding_byte.
+        if (
+            op == exir_ops.edge.quantized_decomposed.embedding_byte.default
+            or op == exir_ops.edge.quantized_decomposed.embedding_byte.dtype
+        ):
+            # Replace with cadence.quantized_embedding_byte.
+            if len(args) < 6:
+                raise AssertionError(
+                    f"Expected 6 arguments for embedding_byte, got {len(args)}"
+                )
+            embedding = args[0]
+            scales = args[1]
+            weight_zero_points = args[2]
+            indices = args[5]
+            # The dtype overload only supports float32 output.
+            if op == exir_ops.edge.quantized_decomposed.embedding_byte.dtype:
+                dtype = kwargs.get("dtype", None)
+                if dtype is not None and dtype != torch.float32:
+                    raise AssertionError(
+                        f"Unsupported output dtype for embedding_byte: {dtype}"
+                    )
+
+            new_args = (embedding, scales, weight_zero_points, indices, False)
+            new_kwargs = {}
+            return super().call_operator(
+                exir_ops.edge.cadence.quantized_embedding_byte.default,
+                new_args,
+                new_kwargs,
+                meta,
+            )
+        return super().call_operator(op, args, kwargs, meta)
+
+
 class CommonReplacePasses:
     passes = [
         ReplaceSqueezeAndUnsqueezeWithViewPass,
@@ -2168,6 +2214,7 @@ class CommonReplacePasses:
         ReplacePT2QuantWithCadenceQuantPass,
         ReplacePT2DequantWithCadenceDequantPass,
         ReplacePowWithMulPass,
+        ReplaceTorchQuantizedEmbeddingWithCadenceQuantizedEmbedding,
     ]


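As a usage sketch: like the unit test below, the pass can also be applied to a graph module on its own, outside CommonReplacePasses. The import paths here are assumptions based on the repository layout (the file above lives at backends/cadence/aot/replace_ops.py); adjust them to wherever the pass lives in your tree:

    from typing import cast

    # Assumed import path for the pass defined in the diff above.
    from executorch.backends.cadence.aot.replace_ops import (
        ReplaceTorchQuantizedEmbeddingWithCadenceQuantizedEmbedding,
    )
    from torch.fx.passes.infra.pass_base import PassResult

    def replace_quantized_embedding(gm):
        # Run the single pass and return the rewritten graph module,
        # exactly as the unit test does.
        p = ReplaceTorchQuantizedEmbeddingWithCadenceQuantizedEmbedding()
        return cast(PassResult, p(gm)).graph_module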
backends/cadence/aot/tests/test_replace_ops_passes.py

Lines changed: 46 additions & 0 deletions
@@ -45,6 +45,7 @@
     ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass,
     ReplaceSplitWithSlicePass,
     ReplaceSqueezeAndUnsqueezeWithViewPass,
+    ReplaceTorchQuantizedEmbeddingWithCadenceQuantizedEmbedding,
     ReplaceTransposedConvWithLinearPass,
     ReplaceTrivialConvWithLinear,
     ReplaceWhereWithFullArgsWithWhereScalar,
@@ -2269,3 +2270,48 @@ def test_replace_aten_linalg_svd_with_cadence_linalg_svd(
             count_node(graph_after_passes, exir_ops.edge.cadence.linalg_svd.default),
             1,
         )
+
+    @expand([("dtype",), ("default",)])
+    @torch.no_grad()
+    def test_replace_quantized_embedding(
+        self,
+        name: str,
+    ) -> None:
+        embedding = torch.ones(5, 6, dtype=torch.int8)
+        indices = torch.tensor([0, 2], dtype=torch.int32)
+        scales = torch.ones(5, 2, dtype=torch.float32)
+        zero_points = None
+
+        original_gm = single_op_builder(
+            placeholders=(embedding, scales, indices),
+            op=(
+                exir_ops.edge.quantized_decomposed.embedding_byte.dtype
+                if name == "dtype"
+                else exir_ops.edge.quantized_decomposed.embedding_byte.default
+            ),
+            args=(embedding, scales, zero_points, -128, 127, indices),
+            kwargs={"dtype": torch.float32} if name == "dtype" else {},
+        )
+
+        p = ReplaceTorchQuantizedEmbeddingWithCadenceQuantizedEmbedding()
+        graph_after_passes = cast(PassResult, p(original_gm)).graph_module
+
+        self.assertEqual(
+            count_node(
+                graph_after_passes,
+                (
+                    exir_ops.edge.quantized_decomposed.embedding_byte.dtype
+                    if name == "dtype"
+                    else exir_ops.edge.quantized_decomposed.embedding_byte.default
+                ),
+            ),
+            0,
+        )
+
+        self.assertEqual(
+            count_node(
+                graph_after_passes,
+                exir_ops.edge.cadence.quantized_embedding_byte.default,
+            ),
+            1,
+        )
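Design note: because the pass is registered at opt_level=0 and appended to CommonReplacePasses, it runs as part of the shared Cadence replacement pipeline rather than as an opt-in optimization. The parameterized test exercises both overloads and asserts that every torch embedding_byte node disappears while exactly one cadence.quantized_embedding_byte node takes its place; the quant min/max arguments (-128 and 127 here) are simply dropped by the pass.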
