Commit 2b4643d

Optimize transposes in XNNPACK partition

1 parent fa7e730 commit 2b4643d

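To make the intent concrete, here is a small eager-mode sketch (not part of the commit) of the pattern the new pass removes: a copy to channels_last immediately followed by a copy back to contiguous leaves the tensor unchanged, so such back-to-back transposes can be dropped from the lowered graph.

import torch

x = torch.randn(1, 3, 8, 8)  # NCHW, contiguous
# Round-tripping the memory format is observationally a no-op:
y = x.to(memory_format=torch.channels_last).to(memory_format=torch.contiguous_format)
assert torch.equal(x, y) and y.is_contiguous()
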
File tree: 4 files changed (+234, -14 lines)

backends/xnnpack/_passes/__init__.py

Lines changed: 5 additions & 0 deletions

@@ -25,6 +25,10 @@
     FuseBatchNormWithConvPass,
 )
 from executorch.backends.xnnpack._passes.prelu_reshape_pass import PReLUReshapePass
+
+from executorch.backends.xnnpack._passes.remove_redundant_ops_pass import (
+    RemoveRedundantOpsPass,
+)
 from executorch.backends.xnnpack._passes.tag_implicit_q_dq_pass import (
     TagImplicitQDqPass,
 )

@@ -70,6 +74,7 @@ def __init__(
                 Conv1dUnsqueezePass,
                 PReLUReshapePass,
                 ChannelsLastTaggedReshapePass,
+                RemoveRedundantOpsPass,
                 TagImplicitQDqPass,
             ]
         else:
backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py

Lines changed: 49 additions & 14 deletions

@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+from enum import Enum
 from typing import Optional, Tuple
 
 import torch

@@ -14,6 +15,11 @@
 from executorch.exir.pass_base import PassResult
 
 
+class InputDimOrder(Enum):
+    NCHW = 1
+    NHWC = 2
+
+
 # TODO(T151254305) use subgraph_rewriter
 class ChannelsLastTaggedReshapePass(XNNPACKPass):
     """

@@ -84,11 +90,19 @@ def mark_as_nhwc_node(self, node: torch.fx.Node) -> None:
     def mark_as_nchw_node(self, node: torch.fx.Node) -> None:
         node.meta[ChannelsLastTaggedReshapePass.XNN_NHWC_NODE] = False
 
-    def is_nhwc_node(self, node: torch.fx.Node) -> bool:
+    def tag_node(self, node: torch.fx.Node) -> None:
+        if node.kwargs["memory_format"] == torch.channels_last:
+            self.mark_as_nhwc_node(node)
+        else:
+            self.mark_as_nchw_node(node)
+
+    @staticmethod
+    def is_nhwc_node(node: torch.fx.Node) -> bool:
         return node.meta.get(ChannelsLastTaggedReshapePass.XNN_NHWC_NODE, False)
 
-    def is_nchw_node(self, node: torch.fx.Node) -> bool:
-        return not self.is_nhwc_node(node)
+    @staticmethod
+    def is_nchw_node(node: torch.fx.Node) -> bool:
+        return not ChannelsLastTaggedReshapePass.is_nhwc_node(node)
 
     def requires_nhwc_input(self, node: torch.fx.Node) -> bool:
         return node.target in self.memory_sensitive_ops_nhwc

@@ -106,7 +120,7 @@ def can_be_converted_to_nhwc(self, node: torch.fx.Node) -> bool:
         is_nchw_constant = (
             is_param_node(self.exported_program, node)
             and (ChannelsLastTaggedReshapePass.XNN_NHWC_NODE in node.meta)
-            and (self.is_nchw_node(node))
+            and (ChannelsLastTaggedReshapePass.is_nchw_node(node))
         )
         return is_4d and not is_nchw_constant

@@ -249,6 +263,22 @@ def insert_copy_and_assign_partner_nodes_quantization_sensitive(
         # in that case
         self.make_partners(original_input, copy_node)
 
+    def input_dim_order(
+        self, input_node: torch.fx.Node, input_order: InputDimOrder
+    ) -> bool:
+        if input_node.op == "placeholder":
+            return (
+                input_node.meta["val"].is_contiguous()
+                if input_order == InputDimOrder.NCHW
+                else not input_node.meta["val"].is_contiguous()
+            )
+        else:
+            return (
+                ChannelsLastTaggedReshapePass.is_nchw_node(input_node)
+                if input_order == InputDimOrder.NCHW
+                else ChannelsLastTaggedReshapePass.is_nhwc_node(input_node)
+            )
+
     def input_to_nhwc(
         self,
         graph_module: torch.fx.GraphModule,

@@ -258,7 +288,7 @@
         if is_param_node(self.exported_program, input_node):
             if (
                 ChannelsLastTaggedReshapePass.XNN_NHWC_NODE in input_node.meta
-                and self.is_nchw_node(input_node)
+                and ChannelsLastTaggedReshapePass.is_nchw_node(input_node)
             ):
                 # This constant data tensor has been used somewhere else
                 # in NCHW format so we can't use it here in NHWC format

@@ -275,6 +305,9 @@
         elif self.is_nhwc_node(input_node):
             return
 
+        if self.input_dim_order(input_node, InputDimOrder.NHWC):
+            return
+
         if not self.can_be_converted_to_nhwc(input_node):
             raise AssertionError(
                 "Attempting to convert non-NHWC compatible node to NHWC"

@@ -302,6 +335,7 @@
                 args=(input_node,),
                 memory_format=torch.channels_last,
             )
+            self.mark_as_nhwc_node(input_node_nhwc)
 
         if is_dynamic_input:
             # Replace downstream input_nodes with NHWC node

@@ -324,7 +358,7 @@ def input_to_nchw(
         if is_param_node(self.exported_program, input_node):
             if (
                 ChannelsLastTaggedReshapePass.XNN_NHWC_NODE in input_node.meta
-                and self.is_nhwc_node(input_node)
+                and ChannelsLastTaggedReshapePass.is_nhwc_node(input_node)
             ):
                 # This constant data tensor has been used somewhere else
                 # in NHWC format so we can't use it here in NCHW format

@@ -342,6 +376,9 @@
         elif self.is_nchw_node(input_node):
             return
 
+        if self.input_dim_order(input_node, InputDimOrder.NCHW):
+            return
+
         if ChannelsLastTaggedReshapePass.PARTNER_NODE in input_node.meta:
             # Already has an associated NCHW node
             input_node_nchw = input_node.meta[

@@ -356,6 +393,7 @@
                 args=(input_node,),
                 memory_format=torch.contiguous_format,
             )
+            self.mark_as_nchw_node(input_node_nchw)
 
         self.insert_copy_and_assign_partner_nodes_quantization_sensitive(
             graph_module=graph_module,

@@ -383,10 +421,9 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
             elif self.requires_nhwc_input(node):
                 # Nodes which enter this branch are ones that require their
                 # first input to be nhwc. This makes this node's output nhwc too
-
                 self.input_to_nhwc(graph_module, node.args[0], node)
-                for input_node in node.all_input_nodes:
-                    if input_node.op == "placeholder" and self.is_nhwc_node(input_node):
+                for input_node in node.all_input_nodes[1:]:
+                    if input_node.op == "placeholder" and ChannelsLastTaggedReshapePass.is_nhwc_node(input_node):
                         raise AssertionError(
                             f"Expected {input_node} to be NCHW in channels last reshape pass"
                         )

@@ -396,15 +433,13 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
                 for input_node in node.all_input_nodes:
                     self.input_to_nchw(graph_module, input_node, node)
             elif node.target == exir_ops.edge.aten._to_copy.default:
-                if node.kwargs["memory_format"] == torch.channels_last:
-                    self.mark_as_nhwc_node(node)
-                else:
-                    self.mark_as_nchw_node(node)
+                self.tag_node(node)
             else:
                 # The node can have inputs in any format (but all must be the
                 # same format)
                 is_or_isnt_nhwc_node = [
-                    self.is_nhwc_node(input_node) for input_node in node.all_input_nodes
+                    ChannelsLastTaggedReshapePass.is_nhwc_node(input_node)
+                    for input_node in node.all_input_nodes
                 ]
                 if all(is_or_isnt_nhwc_node):
                     # All inputs are nhwc so this node's output is nhwc too
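For placeholder nodes, the new input_dim_order helper infers the dim order from the contiguity of the example value in node.meta["val"] rather than from an NHWC tag. A standalone check of that assumption in plain PyTorch (illustrative only, not part of the diff):

import torch

nchw = torch.randn(1, 3, 8, 8)                     # default contiguous (NCHW) layout
nhwc = nchw.to(memory_format=torch.channels_last)  # NHWC-strided copy
assert nchw.is_contiguous() and not nhwc.is_contiguous()
assert nhwc.is_contiguous(memory_format=torch.channels_last)
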
backends/xnnpack/_passes/remove_redundant_ops_pass.py

Lines changed: 48 additions & 0 deletions

@@ -0,0 +1,48 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.backends.xnnpack._passes.channels_last_tagged_reshape_pass import (
+    ChannelsLastTaggedReshapePass,
+)
+from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import PassResult
+
+
+class RemoveRedundantOpsPass(XNNPACKPass):
+    def call(self, graph_module: torch.fx.GraphModule):
+        graph = graph_module.graph
+        original_nodes = list(graph.nodes)
+
+        for node in original_nodes:
+            if len(node.all_input_nodes) == 0:
+                continue
+
+            # If we encounter a to_copy node, check if its input is also a to_copy node with opposite format
+            if node.target == exir_ops.edge.aten._to_copy.default:
+                input_node = node.args[0]
+                if (
+                    input_node.target == exir_ops.edge.aten._to_copy.default
+                    and ChannelsLastTaggedReshapePass.is_nchw_node(input_node)
+                    != ChannelsLastTaggedReshapePass.is_nchw_node(node)
+                ):
+                    # If we find an opposite to_copy node, remove both nodes
+                    original_input = input_node.args[0]
+
+                    for user in node.users.copy():
+                        user.replace_input_with(node, original_input)
+
+                    graph.erase_node(node)
+                    graph.erase_node(input_node)
+
+        graph_module.recompile()
+
+        # Since we are overriding "call", we need to call the parent's "call"
+        # to retrace the graph and regenerate metadata
+        graph_module = super().call(graph_module).graph_module
+
+        return PassResult(graph_module, True)
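The graph surgery above (redirect each user with replace_input_with, then erase both nodes) can be exercised on a plain torch.fx trace. The following is a simplified standalone analogue, not the pass itself: it matches eager .to call_method nodes instead of edge-dialect _to_copy nodes and compares memory_format kwargs directly instead of reading the NHWC tags.

import torch
import torch.fx as fx


class RoundTrip(torch.nn.Module):
    def forward(self, x):
        y = x.to(memory_format=torch.channels_last)
        y = y.to(memory_format=torch.contiguous_format)
        return y + 1


gm = fx.symbolic_trace(RoundTrip())
for node in list(gm.graph.nodes):
    if node.op != "call_method" or node.target != "to":
        continue
    producer = node.args[0]
    if (
        isinstance(producer, fx.Node)
        and producer.op == "call_method"
        and producer.target == "to"
        and producer.kwargs.get("memory_format") != node.kwargs.get("memory_format")
    ):
        original_input = producer.args[0]
        for user in list(node.users):  # redirect consumers to the pre-copy value
            user.replace_input_with(node, original_input)
        gm.graph.erase_node(node)  # erase the later copy first, then its producer
        gm.graph.erase_node(producer)
gm.recompile()

x = torch.randn(1, 3, 4, 4)
assert torch.equal(gm(x), RoundTrip()(x))  # same output, two fewer copies in the graph
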
Lines changed: 132 additions & 0 deletions

@@ -0,0 +1,132 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+from executorch.backends.xnnpack._passes.channels_last_tagged_reshape_pass import (
+    ChannelsLastTaggedReshapePass,
+)
+from executorch.backends.xnnpack._passes.convert_to_linear import ConvertToLinearPass
+from executorch.backends.xnnpack._passes.remove_redundant_ops_pass import (
+    RemoveRedundantOpsPass,
+)
+from executorch.backends.xnnpack.test.tester import RunPasses, Tester
+from executorch.exir.passes.memory_format_ops_pass import DimOrderOpsRevertPass
+
+
+class TestChannelsLastTaggedReshapePass(unittest.TestCase):
+    PassStage = RunPasses(
+        [
+            DimOrderOpsRevertPass,
+            ConvertToLinearPass,
+            ChannelsLastTaggedReshapePass,
+            RemoveRedundantOpsPass,
+        ]
+    )
+
+    def setUp(self):
+        torch._dynamo.reset()
+
+    def run_tester(self, module, inputs):
+        tester = Tester(
+            module.eval(),
+            inputs,
+        )
+        tester.export().to_edge_transform_and_lower().to_executorch().serialize().run_method_and_compare_outputs()
+
+    class ChannelsLastToContiguous(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv1 = torch.nn.Conv2d(3, 3, 3)
+            self.linear1 = torch.nn.Linear(4, 3)
+
+        def forward(self, x):
+            y = self.linear1(x)
+            y = y.to(memory_format=torch.channels_last)
+            y = y.to(memory_format=torch.contiguous_format)
+            y = y.to(memory_format=torch.channels_last)
+            y = y.to(memory_format=torch.contiguous_format)
+            y = y.to(memory_format=torch.channels_last)
+            y = y.to(memory_format=torch.contiguous_format)
+            return self.conv1(y)
+
+    ChannelsLastToContiguousModule = ChannelsLastToContiguous()
+
+    class ContiguousToChannelsLast(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv1 = torch.nn.Conv2d(3, 3, 3)
+            self.linear1 = torch.nn.Linear(4, 3)
+
+        def forward(self, x):
+            y = self.linear1(x)
+            y = y.to(memory_format=torch.contiguous_format)
+            y = y.to(memory_format=torch.channels_last)
+            y = y.to(memory_format=torch.contiguous_format)
+            y = y.to(memory_format=torch.channels_last)
+            y = y.to(memory_format=torch.contiguous_format)
+            y = y.to(memory_format=torch.channels_last)
+
+            return self.conv1(y)
+
+    ContiguousToChannelsLastModule = ContiguousToChannelsLast()
+
+    class ImplicitRedundantOpRemoval(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.upsample = torch.nn.Upsample(scale_factor=2, mode="nearest")
+            self.conv = torch.nn.Conv2d(3, 3, 3)
+
+        def forward(self, x):
+            y = x.to(memory_format=torch.channels_last)
+            y = self.upsample(y)
+            y = y.to(memory_format=torch.contiguous_format)
+            return self.conv(y)
+
+    ImplicitRedundantOpRemovalModule = ImplicitRedundantOpRemoval()
+
+    def test_redundant_to_copy_op_removal(self):
+        (
+            Tester(self.ChannelsLastToContiguousModule, (torch.randn(1, 3, 6, 4),))
+            .export()
+            .to_edge()
+            .run_passes(self.PassStage)
+            .check_count(
+                {
+                    "executorch_exir_dialects_edge__ops_aten__to_copy_default": 2,
+                }
+            )
+            .run_method_and_compare_outputs()
+        )
+
+    def test_redundant_to_copy_op_removal_2(self):
+        (
+            Tester(self.ContiguousToChannelsLastModule, (torch.randn(1, 3, 6, 4),))
+            .export()
+            .to_edge()
+            .run_passes(self.PassStage)
+            .check_count(
+                {
+                    "executorch_exir_dialects_edge__ops_aten__to_copy_default": 1,
+                }
+            )
+            .run_method_and_compare_outputs()
+        )
+
+    def test_implicit_redundant_op_removal(self):
+        (
+            Tester(self.ImplicitRedundantOpRemovalModule, (torch.randn(1, 3, 3, 3),))
+            .export()
+            .to_edge()
+            .run_passes(self.PassStage)
+            .check_count(
+                {
+                    "executorch_exir_dialects_edge__ops_aten__to_copy_default": 2,
+                }
+            )
+            .run_method_and_compare_outputs()
+        )
