Commit b3b7a98

Remove no-op clones in xnnpack (#15884)
ATen clone ops can end up in the graph from a few sources. Since the graph is functional, we don't actually need them, and they are slow. This PR runs the no-op clone removal pass for XNNPACK.

In addition, I ran into an issue where the XNNPACK delegate doesn't currently handle inputs being directly forwarded to partition outputs; there has to be at least one operator. To solve this, I updated the removal pass to leave these clone ops in and added copy support in the XNN delegate to copy directly to the output.

In the long run, I want to remove these no-ops higher up as part of to_edge, but that requires alignment and changes across a few more backends. See #15838. Resolving this for XNNPACK will at least mitigate the issue for CPU models.

Differential Revision: D87405074
1 parent 351815f commit b3b7a98
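
For context, the edge case described above arises when a clone is the only operator between a graph input and a graph output. A minimal sketch (the PassThrough module and shapes are illustrative, not from this PR; the exact printed graph depends on the export pipeline):

import torch

# Hypothetical repro: after export, the clone is the only node between the
# placeholder and the output. Stripping it would leave the XNNPACK partition
# with zero operators, which the delegate cannot currently handle.
class PassThrough(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.clone()  # shows up as an ATen clone in the exported graph

ep = torch.export.export(PassThrough(), (torch.randn(2, 3),))
print(ep.graph)  # placeholder -> aten.clone.default -> output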

File tree

12 files changed: +293 −1

backends/transforms/remove_clone_ops.py

Lines changed: 20 additions & 1 deletion

@@ -25,8 +25,9 @@ class RemoveCloneOpsTransform(ExportPass):
         exir_ops.edge.dim_order_ops._clone_dim_order.default,
     }
 
-    def __init__(self) -> None:
+    def __init__(self, preserve_input_output_copies: bool = False) -> None:
         super().__init__()
+        self._preserve_input_output_copies = preserve_input_output_copies
 
     def _remove(self, graph_module: torch.fx.GraphModule) -> None:
         dequant_nodes = []
@@ -38,6 +39,11 @@ def _remove(self, graph_module: torch.fx.GraphModule) -> None:
             if self._is_non_identity_clone(n):
                 continue
 
+            # If preserve_input_output_copies is set, don't remove clones that directly
+            # copy from input to output.
+            if self._is_input_output_copy(n) and self._preserve_input_output_copies:
+                continue
+
             to_be_removed = n
             for user_n in list(n.users.keys()):
                 user_n.replace_input_with(n, n.args[0])
@@ -76,3 +82,16 @@ def _is_non_identity_clone(self, node: torch.fx.Node) -> bool:
         )
 
         return False
+
+    def _is_input_output_copy(self, node: torch.fx.Node) -> bool:
+        """Return True if the node's input is a graph input and its output feeds an output node."""
+
+        input_node = node.args[0]
+        if input_node.op != "placeholder":
+            return False
+
+        for user in node.users:
+            if user.op == "output":
+                return True
+
+        return False
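
A quick usage sketch of the updated pass (assuming the standard ExportPass calling convention, where the pass is applied to a GraphModule and returns a PassResult; the exported_program variable is illustrative):

from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform

# Default behavior: strip every identity clone from the functional graph.
RemoveCloneOpsTransform()(exported_program.graph_module)

# XNNPACK behavior: keep clones that directly forward a graph input to a
# graph output, so a delegated partition never ends up empty.
RemoveCloneOpsTransform(preserve_input_output_copies=True)(
    exported_program.graph_module
)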

backends/xnnpack/_passes/TARGETS

Lines changed: 1 addition & 0 deletions

@@ -8,6 +8,7 @@ runtime.python_library(
     deps = [
         "//caffe2:torch",
         "//executorch/backends/transforms:addmm_mm_to_linear",
+        "//executorch/backends/transforms:remove_clone_ops",
         "//executorch/backends/transforms:lib",
         "//executorch/backends/xnnpack/partition:partitioner_graphs",
         "//executorch/backends/xnnpack/serialization:xnnpack_schema",

backends/xnnpack/_passes/__init__.py

Lines changed: 10 additions & 0 deletions

@@ -4,8 +4,12 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-unsafe
+
 from typing import List, Optional, Type
 
+from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
+
 from executorch.backends.transforms.remove_getitem_op import RemoveGetItemPass
 
 from executorch.backends.xnnpack._passes.channels_last_tagged_reshape_pass import (
@@ -42,6 +46,11 @@
 from torch.export import ExportedProgram
 
 
+class XNNPACKRemoveCloneOpsTransform(RemoveCloneOpsTransform):
+    def __init__(self):
+        super().__init__(preserve_input_output_copies=True)
+
+
 class XNNPACKPassManager:
     def __init__(
         self,
@@ -58,6 +67,7 @@ def __init__(
         if not passes:
             # All the XNNPACK passes
             self.passes = [
+                XNNPACKRemoveCloneOpsTransform,
                 # TODO - remove this pass once we have a better support for dim_order ops lowering
                 DimOrderOpsRevertPass,
                 ConvertToUpsampleBilinear2d,

backends/xnnpack/operators/__init__.py

Lines changed: 3 additions & 0 deletions

@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-unsafe
+
 from . import (  # noqa
     node_visitor,
     op_abs,
@@ -14,6 +16,7 @@
     op_cat,
     op_ceiling,
     op_clamp,
+    op_clone,
     op_conv2d,
     op_div,
     op_dynamic_dequantize_ops,

backends/xnnpack/operators/op_clone.py

Lines changed: 60 additions & 0 deletions

@@ -0,0 +1,60 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+from typing import Dict
+
+import torch
+from executorch.backends.xnnpack.operators.node_visitor import (
+    NodeVisitor,
+    register_node_visitor,
+)
+from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
+    XNNCopy,
+    XNNGraph,
+    XNode,
+)
+from executorch.backends.xnnpack.utils.utils import get_input_node
+
+
+@register_node_visitor
+class CloneVisitor(NodeVisitor):
+    target = "aten.clone.default"
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        xnn_graph: XNNGraph,
+        vals_to_ids: Dict[torch.fx.Node, int],
+        debug_handle: int,
+    ) -> None:
+        self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)
+
+        # Sanity check that the input and output dim order are the same. We don't
+        # handle dim order conversions yet.
+        dim_order = node.kwargs.get("dim_order", None)
+        input_meta = node.args[0].meta["val"]
+        assert dim_order is None or list(input_meta.dim_order()) == dim_order
+
+        # input
+        input_id = vals_to_ids[get_input_node(node, 0)]
+
+        # output
+        output_id = vals_to_ids[node]
+
+        ser_node = XNode(
+            xnode_union=XNNCopy(
+                input_id=input_id,
+                output_id=output_id,
+                flags=0,
+            ),
+            debug_handle=debug_handle,
+        )
+        xnn_graph.xnodes.append(ser_node)
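
For intuition on the dim-order sanity check in define_node above, here is a small illustration of Tensor.dim_order() (a sketch; values assume a 4-D tensor):

import torch

# A contiguous NCHW tensor reports dim order (0, 1, 2, 3); a channels-last
# tensor reports (0, 2, 3, 1). The visitor only accepts clones whose requested
# dim_order matches the input's existing order, i.e. true no-op copies.
x = torch.randn(1, 2, 4, 4)
print(list(x.dim_order()))  # [0, 1, 2, 3]

y = x.to(memory_format=torch.channels_last)
print(list(y.dim_order()))  # [0, 2, 3, 1]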

backends/xnnpack/partition/config/__init__.py

Lines changed: 3 additions & 0 deletions

@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-unsafe
 
 from typing import List, Type
 
@@ -22,6 +23,7 @@
     CatConfig,
     CeilConfig,
     ClampConfig,
+    CloneDimOrderConfig,
     ConstantPadConfig,
     DeQuantizedPerTensorConfig,
     DivConfig,
@@ -77,6 +79,7 @@
     BMMConfig,
     CatConfig,
     CeilConfig,
+    CloneDimOrderConfig,
     ConstantPadConfig,
     ConvolutionConfig,
     ClampConfig,

backends/xnnpack/partition/config/generic_node_configs.py

Lines changed: 22 additions & 0 deletions

@@ -643,3 +643,25 @@ class SinConfig(GenericNodePartitionerConfig):
 
     def supported_precision_types(self) -> List[ConfigPrecisionType]:
         return [ConfigPrecisionType.FP32]
+
+
+class CloneDimOrderConfig(GenericNodePartitionerConfig):
+    target_name = "_clone_dim_order.default"
+
+    def supported_precision_types(self) -> List[ConfigPrecisionType]:
+        return [ConfigPrecisionType.FP32]
+
+    def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
+        if not self.check_common_constraints(node, ep):
+            return False
+
+        # Only partition no-op _clone_dim_order nodes (output dim order = input).
+        # We can relax this in the future.
+        # This is also a conservative check and doesn't consider ambiguity.
+        dim_order = node.kwargs.get("dim_order", None)
+        input_meta = node.args[0].meta["val"]
+        if dim_order is not None and list(input_meta.dim_order()) != dim_order:
+            why(node, reason="Only dim-order preserving clones are supported.")
+            return False
+
+        return True
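
Conversely, an example of a clone this config would reject (a hypothetical sketch; the assumption is that a memory-format-changing clone lowers to a _clone_dim_order node whose target dim order differs from the input's):

import torch

class ToChannelsLast(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Changes dim order from (0, 1, 2, 3) to (0, 2, 3, 1), so the
        # resulting clone is not a no-op and is left out of the partition.
        return x.clone(memory_format=torch.channels_last)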

backends/xnnpack/runtime/XNNCompiler.cpp

Lines changed: 29 additions & 0 deletions

@@ -1459,6 +1459,34 @@ Error defineBatchMatrixMultiplyNode(
   return Error::Ok;
 }
 
+/*
+ * Defines a copy node in the XNN subgraph.
+ */
+Error defineCopyNode(
+    xnn_subgraph_t subgraph_ptr,
+    const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
+    const NodePtr node,
+    const fb_xnnpack::XNNGraph* graph) noexcept {
+  MAYBE_UNUSED(graph);
+
+  auto graph_node = node->xnode_union_as_XNNCopy();
+
+  xnn_status status = xnn_define_copy(
+      subgraph_ptr,
+      remapped_ids.at(graph_node->input_id()),
+      remapped_ids.at(graph_node->output_id()),
+      graph_node->flags());
+
+  ET_CHECK_OR_RETURN_ERROR(
+      status == xnn_status_success,
+      Internal,
+      "Failed to create copy node %i with code: %s",
+      node->debug_handle(),
+      xnn_status_to_string(status));
+
+  return Error::Ok;
+}
+
 /*
 Returns a Not Implemented error code. This function is meant to be
 called when the compiler encounters an XNodeType from the flatbuffer
@@ -1763,6 +1791,7 @@ DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) {
   _DEFINE(Concatenate5)
   _DEFINE(StaticSlice)
   _DEFINE(BatchMatrixMultiply)
+  _DEFINE(Copy)
   case fb_xnnpack::XNodeUnion::NONE:
   default: // Adding here as a catch all, just in case
     return &defineNotImplementedNode;

backends/xnnpack/serialization/runtime_schema.fbs

Lines changed: 1 addition & 0 deletions

@@ -157,6 +157,7 @@ union XNodeUnion {
   XNNTanh: _XNNNode1x1,
   XNNExp: _XNNNode1x1,
   XNNSin: _XNNNode1x1,
+  XNNCopy: _XNNNode1x1,
 }
 
 union XValueUnion {

backends/xnnpack/serialization/schema.fbs

Lines changed: 1 addition & 0 deletions

@@ -153,6 +153,7 @@ union XNodeUnion {
   XNNTanh: _XNNNode1x1,
   XNNExp: _XNNNode1x1,
   XNNSin: _XNNNode1x1,
+  XNNCopy: _XNNNode1x1,
 }
 
 union XValueUnion {
