
Commit c1f05f3

Merge branch 'main' into improve-vgf-runtime-and-update-mlsdk-url
2 parents: a22813f + 414fc32

File tree (4 files changed: +183, -16 lines):

    backends/transforms/remove_clone_ops.py
    backends/transforms/targets.bzl
    backends/transforms/test/test_remove_clone_ops.py
    runtime/executor/method.cpp

backends/transforms/remove_clone_ops.py

Lines changed: 30 additions & 11 deletions
@@ -6,26 +6,45 @@
 
 # pyre-strict
 
+from typing import Set
+
 import torch
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
+from executorch.exir.passes import dead_code_elimination_pass
+from executorch.exir.passes.remove_noop_pass import _DEQUANT_OPS, eliminate_dq_q
 
 
-def remove_clone_ops(graph: torch.fx.Graph) -> torch.fx.Graph:
+class RemoveCloneOpsTransform(ExportPass):
     """
-    Remove clone op nodes and replace uses with parent node.
+    Trim 'identity' operators to reduce unnecessary copy overhead.
     """
-    clone_op = exir_ops.edge.aten.clone.default
-    for node in graph.nodes:
-        if node.op == "call_function" and node.target == clone_op:
-            with graph.inserting_after(node):
-                node.replace_all_uses_with(node.args[0])
 
-    graph.eliminate_dead_code()
-    return graph
+    clone_ops: Set[torch._ops.OpOverload] = {
+        exir_ops.edge.aten.clone.default,
+    }
 
+    def __init__(self) -> None:
+        super().__init__()
+
+    def _remove(self, graph_module: torch.fx.GraphModule) -> None:
+        dequant_nodes = []
+
+        for n in graph_module.graph.nodes:
+            if n.target not in self.clone_ops:
+                continue
+
+            to_be_removed = n
+            for user_n in list(n.users.keys()):
+                user_n.replace_input_with(n, n.args[0])
+            if n.args[0].target in _DEQUANT_OPS:
+                dequant_nodes += [n.args[0]]
+            graph_module.graph.erase_node(to_be_removed)
+
+        eliminate_dq_q(graph_module, dequant_nodes)
 
-class RemoveCloneOpsTransform(ExportPass):
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
-        graph_module.graph = remove_clone_ops(graph_module.graph)
+        self._remove(graph_module)
+        graph_module.recompile()
+        dead_code_elimination_pass(graph_module)
         return PassResult(graph_module, True)
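For reference, the rewritten pass is invoked like any other ExportPass: calling an instance on a torch.fx.GraphModule returns a PassResult whose graph_module field holds the updated graph (this is exactly how the new test applies it). A minimal sketch, assuming gm is an edge-dialect GraphModule that may contain clone nodes:

    from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform

    # Assumption: `gm` is a torch.fx.GraphModule in the edge dialect, e.g.
    # taken from to_edge(...).exported_program().graph_module.
    result = RemoveCloneOpsTransform()(gm)
    gm = result.graph_module  # clones removed; adjacent dq/q pairs folded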

backends/transforms/targets.bzl

Lines changed: 13 additions & 0 deletions
@@ -109,6 +109,7 @@ def define_common_targets():
         srcs = ["remove_clone_ops.py"],
         visibility = [
             "//executorch/backends/...",
+            "@EXECUTORCH_CLIENTS",
         ],
         deps = [
             "//caffe2:torch",
@@ -242,3 +243,15 @@ def define_common_targets():
             ":rank_0_to_rank_1",
         ],
     )
+
+    runtime.python_test(
+        name = "test_remove_clone_ops",
+        srcs = [
+            "test/test_remove_clone_ops.py",
+        ],
+        deps = [
+            "//caffe2:torch",
+            "//executorch/exir:lib",
+            ":remove_clone_ops",
+        ],
+    )
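Assuming a Buck-based checkout, the new test target should then be runnable with something like buck test //executorch/backends/transforms:test_remove_clone_ops; the exact invocation depends on the local Buck setup.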
backends/transforms/test/test_remove_clone_ops.py

Lines changed: 128 additions & 0 deletions
@@ -0,0 +1,128 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
+from executorch.exir.dialects._ops import ops as exir_ops
+from torch.fx import GraphModule
+from torch.testing import FileCheck
+from torch.testing._internal.common_utils import TestCase
+
+
+class TestRemoveCloneOpsTransform(TestCase):
+    def test_dq_clone_q_linear(self):
+        """
+        Test RemoveCloneOpsTransform on a graph with a dq -> clone -> q -> linear pattern.
+
+        Before: should contain all of these nodes.
+        After: should only have the linear operation.
+        """
+
+        # Create a graph module directly with the pattern: dequant -> clone -> quant -> fp linear
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(10, 5)
+
+            def forward(self, x):
+                # This will be replaced with our custom graph
+                return self.linear(x)
+
+        # Create a module instance
+        module = TestModule()
+
+        # Create a new graph with our desired pattern
+        graph = torch.fx.Graph()
+
+        # Add placeholders
+        input_node = graph.placeholder("x")
+
+        # Create nodes for our pattern: dequant -> clone -> quant -> fp linear
+        # Constants for quantization parameters
+        scale = graph.create_node(
+            "call_function", torch.tensor, args=([0.1],), kwargs={}
+        )
+        zero_point = graph.create_node(
+            "call_function", torch.tensor, args=([0],), kwargs={}
+        )
+
+        # Dequantize node
+        dequant_node = graph.create_node(
+            "call_function",
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            args=(input_node, scale, zero_point, torch.int8),
+            kwargs={},
+        )
+
+        # Clone node.
+        # Use the edge op, as this is an ExecuTorch pass
+        clone_node = graph.create_node(
+            "call_function",
+            exir_ops.edge.aten.clone.default,
+            args=(dequant_node,),
+            kwargs={},
+        )
+
+        # Quantize node
+        quant_node = graph.create_node(
+            "call_function",
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+            args=(clone_node, scale, zero_point, torch.int8),
+            kwargs={},
+        )
+
+        # Linear node (using the module's linear layer).
+        # Technically, this should use a quantized weight and bias,
+        # but we are just inspecting graph patterns in this test.
+        weight = graph.create_node("get_attr", "linear.weight")
+        bias = graph.create_node("get_attr", "linear.bias")
+        linear_node = graph.create_node(
+            "call_function",
+            torch.nn.functional.linear,
+            args=(quant_node, weight, bias),
+            kwargs={},
+        )
+
+        # Output
+        graph.output(linear_node)
+
+        # Create a GraphModule with our custom graph
+        gm = GraphModule(module, graph)
+
+        # Verify we have the expected nodes before transformation using FileCheck
+        FileCheck().check(
+            "torch.ops.quantized_decomposed.dequantize_per_tensor.default",
+        ).check(
+            "executorch_exir_dialects_edge__ops_aten_clone_default",
+        ).check(
+            "torch.ops.quantized_decomposed.quantize_per_tensor.default",
+        ).check(
+            "torch._C._nn.linear",
+        ).run(
+            gm.code
+        )
+
+        # Apply the transform
+        transformed_gm = RemoveCloneOpsTransform()(gm).graph_module
+
+        # Verify the dq -> clone -> q pattern is removed and the linear op is still present
+        FileCheck().check_not(
+            "executorch_exir_dialects_edge__ops_aten_clone_default"
+        ).check_not("quantized_decomposed.dequantize_per_tensor.default").check_not(
+            "quantized_decomposed.quantize_per_tensor.default"
+        ).check_count(
+            "torch._C._nn.linear",
+            1,
+            exactly=True,
+        ).run(
+            transformed_gm.code
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
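A note on why both the quantize and dequantize nodes vanish rather than just the clone: _remove rewires the clone's users onto the clone's input and, since that input is a _DEQUANT_OPS node, records it in dequant_nodes; eliminate_dq_q then folds the now-adjacent dq/q pair. In informal shorthand (not actual pass output), the test graph evolves as:

    before the pass:        x -> dq -> clone -> q -> linear
    after _remove:          x -> dq -> q -> linear
    after eliminate_dq_q:   x -> linear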

runtime/executor/method.cpp

Lines changed: 12 additions & 5 deletions
@@ -670,7 +670,7 @@ Error Method::resolve_operator(
     size_t kernel_index,
     InstructionArgs args,
     size_t n_args) {
-  // TODO(T153505381, T153506819) Investigate optimizing this function for both
+  // TODO(T153506819) Investigate optimizing this function for both
   // space and time.
 
   // resolve name
@@ -691,8 +691,16 @@ Error Method::resolve_operator(
   }
 
   // resolve tensor meta
-  auto method_allocator = memory_manager_->method_allocator();
-  TensorMeta* meta = method_allocator->allocateList<TensorMeta>(n_args);
+  // Since the temp allocator can be freed, we optimistically
+  // try to use that allocator first.
+  auto allocator = memory_manager_->temp_allocator();
+  // However, it does not have to be provided, so if it
+  // is not provided (or an empty one is provided), we
+  // fall back to the method allocator.
+  if (allocator == nullptr || allocator->size() == 0) {
+    allocator = memory_manager_->method_allocator();
+  }
+  TensorMeta* meta = allocator->allocateList<TensorMeta>(n_args);
   if (meta == nullptr) {
     return Error::MemoryAllocationFailed;
   }
@@ -705,8 +713,7 @@ Error Method::resolve_operator(
     auto tensor = eval->toTensor();
     meta[count].dtype_ = tensor.scalar_type();
     executorch::aten::DimOrderType* dim_order_ptr =
-        method_allocator->allocateList<executorch::aten::DimOrderType>(
-            tensor.dim());
+        allocator->allocateList<executorch::aten::DimOrderType>(tensor.dim());
     if (dim_order_ptr == nullptr) {
       return Error::MemoryAllocationFailed;
     }
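The design choice behind the allocator change: temp-allocator memory can be reclaimed once execution finishes, so it is preferred for this scratch metadata, and the method allocator (whose allocations live as long as the Method) is only a fallback. A minimal sketch of the same selection logic, in Python with hypothetical names purely for illustration:

    def pick_allocator(temp_allocator, method_allocator):
        # Mirrors the fallback added in Method::resolve_operator: prefer the
        # temp allocator, whose memory may be reclaimed after execution.
        if temp_allocator is None or temp_allocator.size() == 0:
            # The temp allocator is optional; if absent or empty, fall back
            # to the method allocator, which lives for the Method's lifetime.
            return method_allocator
        return temp_allocator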
