Skip to content

Commit ba8a474

Browse files
JakeStevensfacebook-github-bot
authored andcommitted
Update remove clone to drop no-op q/dq (#10920)
Summary: Pull Request resolved: #10920 After removing clone, we may be left with no-op quantize operations. This diff updates the pass in backend/transforms to remove these, if they exist Differential Revision: D74832417
1 parent 2240ee1 commit ba8a474

File tree

3 files changed

+195
-12
lines changed

3 files changed

+195
-12
lines changed

backends/transforms/remove_clone_ops.py

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,23 +9,42 @@
99
import torch
1010
from executorch.exir.dialects._ops import ops as exir_ops
1111
from executorch.exir.pass_base import ExportPass, PassResult
12+
from executorch.exir.passes import dead_code_elimination_pass
13+
from executorch.exir.passes.remove_noop_pass import _DEQUANT_OPS, eliminate_dq_q
1214

1315

14-
def remove_clone_ops(graph: torch.fx.Graph) -> torch.fx.Graph:
16+
class RemoveCloneOpsTransform(ExportPass):
1517
"""
16-
Remove clone op nodes and replace uses with parent node.
18+
Trim the 'identity' operators to reduce the unnecessary copy overhead.
1719
"""
18-
clone_op = exir_ops.edge.aten.clone.default
19-
for node in graph.nodes:
20-
if node.op == "call_function" and node.target == clone_op:
21-
with graph.inserting_after(node):
22-
node.replace_all_uses_with(node.args[0])
2320

24-
graph.eliminate_dead_code()
25-
return graph
21+
clone_ops = {
22+
torch.clone,
23+
torch.ops.aten.clone.default,
24+
exir_ops.edge.aten.clone.default,
25+
}
2626

27+
def __init__(self):
28+
super().__init__()
2729

28-
class RemoveCloneOpsTransform(ExportPass):
29-
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
30-
graph_module.graph = remove_clone_ops(graph_module.graph)
30+
def _remove(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
31+
dequant_nodes = []
32+
33+
for n in graph_module.graph.nodes:
34+
if n.target not in self.clone_ops:
35+
continue
36+
37+
to_be_remove = n
38+
for user_n in list(n.users.keys()):
39+
user_n.replace_input_with(n, n.args[0])
40+
if n.args[0].target in _DEQUANT_OPS:
41+
dequant_nodes += [n.args[0]]
42+
graph_module.graph.erase_node(to_be_remove)
43+
44+
eliminate_dq_q(graph_module, dequant_nodes)
45+
46+
def call(self, graph_module: torch.fx.GraphModule):
47+
self._remove(graph_module)
48+
graph_module.recompile()
49+
dead_code_elimination_pass(graph_module)
3150
return PassResult(graph_module, True)

backends/transforms/targets.bzl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ def define_common_targets():
109109
srcs = ["remove_clone_ops.py"],
110110
visibility = [
111111
"//executorch/backends/...",
112+
"@EXECUTORCH_CLIENTS",
112113
],
113114
deps = [
114115
"//caffe2:torch",
@@ -242,3 +243,15 @@ def define_common_targets():
242243
":rank_0_to_rank_1",
243244
],
244245
)
246+
247+
runtime.python_test(
248+
name = "test_remove_clone_ops",
249+
srcs = [
250+
"test/test_remove_clone_ops.py",
251+
],
252+
deps = [
253+
"//caffe2:torch",
254+
"//executorch/exir:lib",
255+
":remove_clone_ops",
256+
],
257+
)
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
import torch
10+
from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
11+
from torch.fx import GraphModule
12+
from torch.testing._internal.common_utils import TestCase
13+
14+
15+
class TestRemoveCloneOpsTransform(TestCase):
16+
def test_dq_clone_q_linear(self):
17+
"""
18+
Test RemoveCloneOpsTransform on a graph with d/q -> clone -> q -> linear pattern
19+
20+
Before: Should contain all nodes
21+
After: Should only have the linear operation
22+
"""
23+
24+
# Create a graph module directly with the pattern: quant -> clone -> dequant -> fp linear
25+
class TestModule(torch.nn.Module):
26+
def __init__(self):
27+
super().__init__()
28+
self.linear = torch.nn.Linear(10, 5)
29+
30+
def forward(self, x):
31+
# This will be replaced with our custom graph
32+
return self.linear(x)
33+
34+
# Create a module instance
35+
module = TestModule()
36+
37+
# Create a new graph with our desired pattern
38+
graph = torch.fx.Graph()
39+
40+
# Add placeholders
41+
input_node = graph.placeholder("x")
42+
43+
# Create nodes for our pattern: quant -> clone -> dequant -> fp linear
44+
# Constants for quantization parameters
45+
scale = graph.create_node(
46+
"call_function", torch.tensor, args=([0.1],), kwargs={}
47+
)
48+
zero_point = graph.create_node(
49+
"call_function", torch.tensor, args=([0],), kwargs={}
50+
)
51+
52+
# Dequantize node
53+
dequant_node = graph.create_node(
54+
"call_function",
55+
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
56+
args=(input_node, scale, zero_point, torch.int8),
57+
kwargs={},
58+
)
59+
60+
# Clone node
61+
clone_node = graph.create_node(
62+
"call_function",
63+
torch.ops.aten.clone.default,
64+
args=(dequant_node,),
65+
kwargs={},
66+
)
67+
68+
# Quantize node
69+
quant_node = graph.create_node(
70+
"call_function",
71+
torch.ops.quantized_decomposed.quantize_per_tensor.default,
72+
args=(clone_node, scale, zero_point, torch.int8),
73+
kwargs={},
74+
)
75+
76+
# Linear node (using the module's linear layer)
77+
# Technically, should use quantized weight and bias
78+
# but we are just inspecting graph patterns in this test
79+
weight = graph.create_node("get_attr", "linear.weight")
80+
bias = graph.create_node("get_attr", "linear.bias")
81+
linear_node = graph.create_node(
82+
"call_function",
83+
torch.nn.functional.linear,
84+
args=(quant_node, weight, bias),
85+
kwargs={},
86+
)
87+
88+
# Output
89+
graph.output(linear_node)
90+
91+
# Create a GraphModule with our custom graph
92+
gm = GraphModule(module, graph)
93+
94+
# Print the graph before transformation
95+
print("Before transformation:")
96+
print(gm.graph)
97+
98+
# Check node counts before transformation
99+
node_counts_before = {}
100+
for node in gm.graph.nodes:
101+
if node.op == "call_function":
102+
target_name = str(node.target)
103+
if target_name not in node_counts_before:
104+
node_counts_before[target_name] = 0
105+
node_counts_before[target_name] += 1
106+
107+
# Verify we have the expected nodes before transformation
108+
self.assertIn(str(torch.ops.aten.clone.default), node_counts_before)
109+
self.assertIn(
110+
str(torch.ops.quantized_decomposed.quantize_per_tensor.default),
111+
node_counts_before,
112+
)
113+
self.assertIn(
114+
str(torch.ops.quantized_decomposed.dequantize_per_tensor.default),
115+
node_counts_before,
116+
)
117+
self.assertIn(str(torch.nn.functional.linear), node_counts_before)
118+
119+
# Apply the transform
120+
transformed_gm = RemoveCloneOpsTransform()(gm).graph_module
121+
122+
# Print the graph after transformation
123+
print("After transformation:")
124+
print(transformed_gm.graph)
125+
126+
# Check node counts after transformation
127+
node_counts_after = {}
128+
for node in transformed_gm.graph.nodes:
129+
if node.op == "call_function":
130+
target_name = str(node.target)
131+
if target_name not in node_counts_after:
132+
node_counts_after[target_name] = 0
133+
node_counts_after[target_name] += 1
134+
135+
# Verify the dq -> clone -> q pattern is removed
136+
self.assertNotIn(str(torch.ops.aten.clone.default), node_counts_after)
137+
self.assertNotIn(
138+
str(torch.ops.quantized_decomposed.dequantize_per_tensor.default),
139+
node_counts_after,
140+
)
141+
self.assertNotIn(
142+
str(torch.ops.quantized_decomposed.quantize_per_tensor.default),
143+
node_counts_after,
144+
)
145+
146+
# Verify the linear op is still present
147+
self.assertIn(str(torch.nn.functional.linear), node_counts_after)
148+
149+
150+
if __name__ == "__main__":
151+
unittest.main()

0 commit comments

Comments
 (0)