Skip to content

Commit 50b366d

Browse files
JakeStevens authored and facebook-github-bot committed
Update remove clone to drop no-op q/dq (#10920)
Summary: After removing clone, we may be left with no-op quantize operations. This diff updates the pass in backends/transforms to remove these, if they exist. Differential Revision: D74832417
1 parent c64a7fd commit 50b366d

File tree

3 files changed

+172
-12
lines changed

3 files changed

+172
-12
lines changed

backends/transforms/remove_clone_ops.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,23 +9,40 @@
99
import torch
1010
from executorch.exir.dialects._ops import ops as exir_ops
1111
from executorch.exir.pass_base import ExportPass, PassResult
12+
from executorch.exir.passes import dead_code_elimination_pass
13+
from executorch.exir.passes.remove_noop_pass import _DEQUANT_OPS, eliminate_dq_q
1214

1315

14-
def remove_clone_ops(graph: torch.fx.Graph) -> torch.fx.Graph:
16+
class RemoveCloneOpsTransform(ExportPass):
1517
"""
16-
Remove clone op nodes and replace uses with parent node.
18+
Trim the 'identity' operators to reduce the unnecessary copy overhead.
1719
"""
18-
clone_op = exir_ops.edge.aten.clone.default
19-
for node in graph.nodes:
20-
if node.op == "call_function" and node.target == clone_op:
21-
with graph.inserting_after(node):
22-
node.replace_all_uses_with(node.args[0])
2320

24-
graph.eliminate_dead_code()
25-
return graph
21+
clone_ops = {
22+
exir_ops.edge.aten.clone.default,
23+
}
2624

25+
def __init__(self):
26+
super().__init__()
2727

28-
class RemoveCloneOpsTransform(ExportPass):
29-
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
30-
graph_module.graph = remove_clone_ops(graph_module.graph)
28+
def _remove(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
29+
dequant_nodes = []
30+
31+
for n in graph_module.graph.nodes:
32+
if n.target not in self.clone_ops:
33+
continue
34+
35+
to_be_remove = n
36+
for user_n in list(n.users.keys()):
37+
user_n.replace_input_with(n, n.args[0])
38+
if n.args[0].target in _DEQUANT_OPS:
39+
dequant_nodes += [n.args[0]]
40+
graph_module.graph.erase_node(to_be_remove)
41+
42+
eliminate_dq_q(graph_module, dequant_nodes)
43+
44+
def call(self, graph_module: torch.fx.GraphModule):
45+
self._remove(graph_module)
46+
graph_module.recompile()
47+
dead_code_elimination_pass(graph_module)
3148
return PassResult(graph_module, True)

backends/transforms/targets.bzl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ def define_common_targets():
109109
srcs = ["remove_clone_ops.py"],
110110
visibility = [
111111
"//executorch/backends/...",
112+
"@EXECUTORCH_CLIENTS",
112113
],
113114
deps = [
114115
"//caffe2:torch",
@@ -242,3 +243,15 @@ def define_common_targets():
242243
":rank_0_to_rank_1",
243244
],
244245
)
246+
247+
runtime.python_test(
248+
name = "test_remove_clone_ops",
249+
srcs = [
250+
"test/test_remove_clone_ops.py",
251+
],
252+
deps = [
253+
"//caffe2:torch",
254+
"//executorch/exir:lib",
255+
":remove_clone_ops",
256+
],
257+
)
backends/transforms/test/test_remove_clone_ops.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
import torch
10+
from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
11+
from executorch.exir import EdgeCompileConfig, to_edge
12+
from executorch.exir.dialects._ops import ops as exir_ops
13+
from torch.export import export
14+
from torch.fx import GraphModule
15+
from torch.testing import FileCheck
16+
from torch.testing._internal.common_utils import TestCase
17+
18+
19+
class TestRemoveCloneOpsTransform(TestCase):
20+
def test_dq_clone_q_linear(self):
21+
"""
22+
Test RemoveCloneOpsTransform on a graph with dq -> clone -> q -> linear pattern
23+
24+
Before: Should contain all nodes
25+
After: Should only have the linear operation
26+
"""
27+
28+
# Create a graph module directly with the pattern: dequant -> clone -> quant -> fp linear
29+
class TestModule(torch.nn.Module):
30+
def __init__(self):
31+
super().__init__()
32+
self.linear = torch.nn.Linear(10, 5)
33+
34+
def forward(self, x):
35+
# This will be replaced with our custom graph
36+
return self.linear(x)
37+
38+
# Create a module instance
39+
module = TestModule()
40+
41+
# Create a new graph with our desired pattern
42+
graph = torch.fx.Graph()
43+
44+
# Add placeholders
45+
input_node = graph.placeholder("x")
46+
47+
# Create nodes for our pattern: dequant -> clone -> quant -> fp linear
48+
# Constants for quantization parameters
49+
scale = graph.create_node(
50+
"call_function", torch.tensor, args=([0.1],), kwargs={}
51+
)
52+
zero_point = graph.create_node(
53+
"call_function", torch.tensor, args=([0],), kwargs={}
54+
)
55+
56+
# Dequantize node
57+
dequant_node = graph.create_node(
58+
"call_function",
59+
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
60+
args=(input_node, scale, zero_point, torch.int8),
61+
kwargs={},
62+
)
63+
64+
# Clone node.
65+
# Use Edge op as this is an executorch pass
66+
clone_node = graph.create_node(
67+
"call_function",
68+
exir_ops.edge.aten.clone.default,
69+
args=(dequant_node,),
70+
kwargs={},
71+
)
72+
73+
# Quantize node
74+
quant_node = graph.create_node(
75+
"call_function",
76+
torch.ops.quantized_decomposed.quantize_per_tensor.default,
77+
args=(clone_node, scale, zero_point, torch.int8),
78+
kwargs={},
79+
)
80+
81+
# Linear node (using the module's linear layer)
82+
# Technically, should use quantized weight and bias
83+
# but we are just inspecting graph patterns in this test
84+
weight = graph.create_node("get_attr", "linear.weight")
85+
bias = graph.create_node("get_attr", "linear.bias")
86+
linear_node = graph.create_node(
87+
"call_function",
88+
torch.nn.functional.linear,
89+
args=(quant_node, weight, bias),
90+
kwargs={},
91+
)
92+
93+
# Output
94+
graph.output(linear_node)
95+
96+
# Create a GraphModule with our custom graph
97+
gm = GraphModule(module, graph)
98+
99+
# Verify we have the expected nodes before transformation using FileCheck
100+
FileCheck().check(
101+
"torch.ops.quantized_decomposed.dequantize_per_tensor.default",
102+
).check(
103+
"executorch_exir_dialects_edge__ops_aten_clone_default",
104+
).check(
105+
"torch.ops.quantized_decomposed.quantize_per_tensor.default",
106+
).check(
107+
"torch._C._nn.linear",
108+
).run(
109+
gm.code
110+
)
111+
112+
# Apply the transform
113+
transformed_gm = RemoveCloneOpsTransform()(gm).graph_module
114+
115+
# Verify the dq -> clone -> q pattern is removed and linear op is still present using FileCheck
116+
FileCheck().check_not(
117+
"executorch_exir_dialects_edge__ops_aten_clone_default"
118+
).check_not("quantized_decomposed.dequantize_per_tensor.default").check_not(
119+
"quantized_decomposed.quantize_per_tensor.default"
120+
).check_count(
121+
"torch._C._nn.linear",
122+
1,
123+
exactly=True,
124+
).run(
125+
transformed_gm.code
126+
)
127+
128+
129+
if __name__ == "__main__":
130+
unittest.main()

0 commit comments

Comments
 (0)