Commit dd07f80

Optimize transposes in XNNPACK partition
1 parent 083663b commit dd07f80

4 files changed: 274 additions, 10 deletions

backends/xnnpack/_passes/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -25,6 +25,10 @@
     FuseBatchNormWithConvPass,
 )
 from executorch.backends.xnnpack._passes.prelu_reshape_pass import PReLUReshapePass
+
+from executorch.backends.xnnpack._passes.remove_redundant_copy_pass import (
+    RemoveRedundantCopyPass,
+)
 from executorch.backends.xnnpack._passes.tag_implicit_q_dq_pass import (
     TagImplicitQDqPass,
 )
@@ -70,6 +74,7 @@ def __init__(
             Conv1dUnsqueezePass,
             PReLUReshapePass,
             ChannelsLastTaggedReshapePass,
+            RemoveRedundantCopyPass,
             TagImplicitQDqPass,
         ]
     else:

backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py

Lines changed: 53 additions & 10 deletions
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+from enum import Enum
 from typing import Optional, Tuple
 
 import torch
@@ -14,6 +15,11 @@
 from executorch.exir.pass_base import PassResult
 
 
+class InputDimOrder(Enum):
+    NCHW = 1
+    NHWC = 2
+
+
 # TODO(T151254305) use subgraph_rewriter
 class ChannelsLastTaggedReshapePass(XNNPACKPass):
     """
@@ -84,11 +90,19 @@ def mark_as_nhwc_node(self, node: torch.fx.Node) -> None:
     def mark_as_nchw_node(self, node: torch.fx.Node) -> None:
         node.meta[ChannelsLastTaggedReshapePass.XNN_NHWC_NODE] = False
 
-    def is_nhwc_node(self, node: torch.fx.Node) -> bool:
+    def tag_node(self, node: torch.fx.Node) -> None:
+        if node.kwargs["memory_format"] == torch.channels_last:
+            self.mark_as_nhwc_node(node)
+        else:
+            self.mark_as_nchw_node(node)
+
+    @staticmethod
+    def is_nhwc_node(node: torch.fx.Node) -> bool:
         return node.meta.get(ChannelsLastTaggedReshapePass.XNN_NHWC_NODE, False)
 
-    def is_nchw_node(self, node: torch.fx.Node) -> bool:
-        return not self.is_nhwc_node(node)
+    @staticmethod
+    def is_nchw_node(node: torch.fx.Node) -> bool:
+        return not ChannelsLastTaggedReshapePass.is_nhwc_node(node)
 
     def requires_nhwc_input(self, node: torch.fx.Node) -> bool:
         return node.target in self.memory_sensitive_ops_nhwc
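
Note: the new tag_node helper keys off the memory_format kwarg of an explicit copy and reuses the pass's existing NHWC/NCHW meta flag. A minimal sketch of that convention on a plain torch.fx trace (the string meta key below is illustrative; the real pass uses the ChannelsLastTaggedReshapePass.XNN_NHWC_NODE constant and matches edge-dialect _to_copy nodes, not Tensor.to):

    import torch
    from torch import fx

    class ToChannelsLast(torch.nn.Module):
        def forward(self, x):
            return x.to(memory_format=torch.channels_last)

    gm = fx.symbolic_trace(ToChannelsLast())
    for n in gm.graph.nodes:
        if n.op == "call_method" and n.target == "to":
            # Mirrors tag_node: a channels_last copy is tagged NHWC, anything else NCHW.
            n.meta["XNN_NHWC_NODE"] = n.kwargs.get("memory_format") == torch.channels_last
            print(n.name, n.meta["XNN_NHWC_NODE"])  # prints: to True
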
@@ -106,7 +120,7 @@ def can_be_converted_to_nhwc(self, node: torch.fx.Node) -> bool:
         is_nchw_constant = (
             is_param_node(self.exported_program, node)
             and (ChannelsLastTaggedReshapePass.XNN_NHWC_NODE in node.meta)
-            and (self.is_nchw_node(node))
+            and (ChannelsLastTaggedReshapePass.is_nchw_node(node))
         )
         return is_4d and not is_nchw_constant
 
@@ -249,6 +263,22 @@ def insert_copy_and_assign_partner_nodes_quantization_sensitive(
         # in that case
         self.make_partners(original_input, copy_node)
 
+    def input_dim_order(
+        self, input_node: torch.fx.Node, input_order: InputDimOrder
+    ) -> bool:
+        if input_node.op == "placeholder":
+            return (
+                input_node.meta["val"].is_contiguous()
+                if input_order == InputDimOrder.NCHW
+                else not input_node.meta["val"].is_contiguous()
+            )
+        else:
+            return (
+                ChannelsLastTaggedReshapePass.is_nchw_node(input_node)
+                if input_order == InputDimOrder.NCHW
+                else ChannelsLastTaggedReshapePass.is_nhwc_node(input_node)
+            )
+
     def input_to_nhwc(
         self,
         graph_module: torch.fx.GraphModule,
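
The placeholder branch of input_dim_order relies on a standard PyTorch property: for a 4-D tensor whose channels_last strides genuinely differ from the default layout, is_contiguous() returns False, so contiguity distinguishes NCHW from NHWC inputs. A quick sanity check:

    import torch

    x = torch.randn(1, 3, 4, 4)                  # default contiguous (NCHW) layout
    y = x.to(memory_format=torch.channels_last)  # NHWC strides, same values

    print(x.is_contiguous())  # True  -> counts as NCHW
    print(y.is_contiguous())  # False -> counts as NHWC
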
@@ -258,7 +288,7 @@ def input_to_nhwc(
         if is_param_node(self.exported_program, input_node):
             if (
                 ChannelsLastTaggedReshapePass.XNN_NHWC_NODE in input_node.meta
-                and self.is_nchw_node(input_node)
+                and ChannelsLastTaggedReshapePass.is_nchw_node(input_node)
             ):
                 # This constant data tensor has been used somewhere else
                 # in NCHW format so we can't use it here in NHWC format
@@ -275,6 +305,9 @@ def input_to_nhwc(
         elif self.is_nhwc_node(input_node):
             return
 
+        if self.input_dim_order(input_node, InputDimOrder.NHWC):
+            return
+
         if not self.can_be_converted_to_nhwc(input_node):
             raise AssertionError(
                 "Attempting to convert non-NHWC compatible node to NHWC"
@@ -302,6 +335,7 @@ def input_to_nhwc(
                 args=(input_node,),
                 memory_format=torch.channels_last,
             )
+            self.mark_as_nhwc_node(input_node_nhwc)
 
         if is_dynamic_input:
             # Replace downstream input_nodes with NHWC node
@@ -324,7 +358,7 @@ def input_to_nchw(
         if is_param_node(self.exported_program, input_node):
             if (
                 ChannelsLastTaggedReshapePass.XNN_NHWC_NODE in input_node.meta
-                and self.is_nhwc_node(input_node)
+                and ChannelsLastTaggedReshapePass.is_nhwc_node(input_node)
             ):
                 # This constant data tensor has been used somewhere else
                 # in NHWC format so we can't use it here in NCHW format
@@ -342,6 +376,9 @@ def input_to_nchw(
         elif self.is_nchw_node(input_node):
             return
 
+        if self.input_dim_order(input_node, InputDimOrder.NCHW):
+            return
+
         if ChannelsLastTaggedReshapePass.PARTNER_NODE in input_node.meta:
             # Already has an associated NCHW node
             input_node_nchw = input_node.meta[
@@ -356,6 +393,7 @@ def input_to_nchw(
                 args=(input_node,),
                 memory_format=torch.contiguous_format,
             )
+            self.mark_as_nchw_node(input_node_nchw)
 
         self.insert_copy_and_assign_partner_nodes_quantization_sensitive(
             graph_module=graph_module,
@@ -383,10 +421,12 @@ def call(self, graph_module: torch.fx.GraphModule):  # noqa: C901
             elif self.requires_nhwc_input(node):
                 # Nodes which enter this branch are ones that require their
                 # first input to be nhwc. This makes this node's output nhwc too
-
                 self.input_to_nhwc(graph_module, node.args[0], node)
-                for input_node in node.all_input_nodes:
-                    if input_node.op == "placeholder" and self.is_nhwc_node(input_node):
+                for input_node in node.all_input_nodes[1:]:
+                    if (
+                        input_node.op == "placeholder"
+                        and ChannelsLastTaggedReshapePass.is_nhwc_node(input_node)
+                    ):
                         raise AssertionError(
                             f"Expected {input_node} to be NCHW in channels last reshape pass"
                         )
@@ -395,11 +435,14 @@ def call(self, graph_module: torch.fx.GraphModule):  # noqa: C901
                 # The node requires nchw inputs
                 for input_node in node.all_input_nodes:
                     self.input_to_nchw(graph_module, input_node, node)
+            elif node.target == exir_ops.edge.aten._to_copy.default:
+                self.tag_node(node)
             else:
                 # The node can have inputs in any format (but all must be the
                 # same format)
                 is_or_isnt_nhwc_node = [
-                    self.is_nhwc_node(input_node) for input_node in node.all_input_nodes
+                    ChannelsLastTaggedReshapePass.is_nhwc_node(input_node)
+                    for input_node in node.all_input_nodes
                 ]
                 if all(is_or_isnt_nhwc_node):
                     # All inputs are nhwc so this node's output is nhwc too
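
The copies this pass inserts (and now tags) only reorder memory; they never change values, which is what makes the later removal safe. For example, a Conv2d gives the same result on a channels_last input as on the default layout:

    import torch

    conv = torch.nn.Conv2d(3, 3, 3)
    x = torch.randn(1, 3, 8, 8)

    y_default = conv(x)
    y_channels_last = conv(x.to(memory_format=torch.channels_last))
    # Same math, different strides; XNNPACK simply prefers the NHWC layout.
    assert torch.allclose(y_default, y_channels_last.contiguous())
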
backends/xnnpack/_passes/remove_redundant_copy_pass.py

Lines changed: 50 additions & 0 deletions

@@ -0,0 +1,50 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.backends.xnnpack._passes.channels_last_tagged_reshape_pass import (
+    ChannelsLastTaggedReshapePass,
+)
+from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import PassResult
+
+
+class RemoveRedundantCopyPass(XNNPACKPass):
+    def call(self, graph_module: torch.fx.GraphModule):
+        graph = graph_module.graph
+        original_nodes = list(graph.nodes)
+
+        for node in original_nodes:
+            if len(node.all_input_nodes) == 0:
+                continue
+
+            # If we encounter a to_copy node, check if its input is also a to_copy node with the opposite format
+            if node.target == exir_ops.edge.aten._to_copy.default:
+                input_node = node.args[0]
+                if (
+                    input_node.target == exir_ops.edge.aten._to_copy.default
+                    and ChannelsLastTaggedReshapePass.is_nchw_node(input_node)
+                    != ChannelsLastTaggedReshapePass.is_nchw_node(node)
+                    and len(input_node.users)
+                    == 1  # Ensure the first copy has no other users
+                ):
+                    # If we find an opposite to_copy node, remove both nodes
+                    original_input = input_node.args[0]
+
+                    for user in node.users.copy():
+                        user.replace_input_with(node, original_input)
+
+                    graph.erase_node(node)
+                    graph.erase_node(input_node)
+
+        graph_module.recompile()
+
+        # Since we are overriding "call", we need to call the parent's "call"
+        # to retrace the graph and regenerate metadata
+        graph_module = super().call(graph_module).graph_module
+
+        return PassResult(graph_module, True)
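
As a standalone illustration of the rewrite this pass performs, here is the same cancel-opposite-copies idea applied to a plain torch.fx trace (a simplified analogue: the real pass matches edge-dialect _to_copy nodes via their NHWC/NCHW tags rather than raw Tensor.to calls):

    import torch
    from torch import fx

    class M(torch.nn.Module):
        def forward(self, x):
            y = x.to(memory_format=torch.channels_last)
            y = y.to(memory_format=torch.contiguous_format)
            return y + 1

    gm = fx.symbolic_trace(M())
    for node in list(gm.graph.nodes):
        if node.op == "call_method" and node.target == "to":
            prev = node.args[0]
            if (
                isinstance(prev, fx.Node)
                and prev.op == "call_method"
                and prev.target == "to"
                and prev.kwargs.get("memory_format") != node.kwargs.get("memory_format")
                and len(prev.users) == 1  # the first copy has no other users
            ):
                # Route users of the second copy back to the original input,
                # then drop both copies.
                node.replace_all_uses_with(prev.args[0])
                gm.graph.erase_node(node)
                gm.graph.erase_node(prev)

    gm.recompile()
    print(gm.code)  # both .to(memory_format=...) calls are gone
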
Lines changed: 166 additions & 0 deletions
@@ -0,0 +1,166 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+from executorch.backends.xnnpack._passes.channels_last_tagged_reshape_pass import (
+    ChannelsLastTaggedReshapePass,
+)
+from executorch.backends.xnnpack._passes.convert_to_linear import ConvertToLinearPass
+from executorch.backends.xnnpack._passes.remove_redundant_copy_pass import (
+    RemoveRedundantCopyPass,
+)
+from executorch.backends.xnnpack.test.tester import RunPasses, Tester
+from executorch.exir.passes.memory_format_ops_pass import DimOrderOpsRevertPass
+from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import XNNPACKQuantizer
+
+
+class TestChannelsLastTaggedReshapePass(unittest.TestCase):
+    PassStage = RunPasses(
+        [
+            DimOrderOpsRevertPass,
+            ConvertToLinearPass,
+            ChannelsLastTaggedReshapePass,
+            RemoveRedundantCopyPass,
+        ]
+    )
+
+    def setUp(self):
+        torch._dynamo.reset()
+
+    def run_tester(self, module, inputs):
+        tester = Tester(
+            module.eval(),
+            inputs,
+        )
+        tester.export().to_edge_transform_and_lower().to_executorch().serialize().run_method_and_compare_outputs()
+
+    class ChannelsLastToContiguous(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv1 = torch.nn.Conv2d(3, 3, 3)
+            self.linear1 = torch.nn.Linear(4, 3)
+
+        def forward(self, x):
+            y = self.linear1(x)
+            y = y.to(memory_format=torch.channels_last)
+            y = y.to(memory_format=torch.contiguous_format)
+            y = y.to(memory_format=torch.channels_last)
+            y = y.to(memory_format=torch.contiguous_format)
+            y = y.to(memory_format=torch.channels_last)
+            y = y.to(memory_format=torch.contiguous_format)
+            return self.conv1(y)
+
+    ChannelsLastToContiguousModule = ChannelsLastToContiguous()
+
+    class ContiguousToChannelsLast(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv1 = torch.nn.Conv2d(3, 3, 3)
+            self.linear1 = torch.nn.Linear(4, 3)
+
+        def forward(self, x):
+            y = self.linear1(x)
+            y = y.to(memory_format=torch.contiguous_format)
+            y = y.to(memory_format=torch.channels_last)
+            y = y.to(memory_format=torch.contiguous_format)
+            y = y.to(memory_format=torch.channels_last)
+            y = y.to(memory_format=torch.contiguous_format)
+            y = y.to(memory_format=torch.channels_last)
+
+            return self.conv1(y)
+
+    ContiguousToChannelsLastModule = ContiguousToChannelsLast()
+
+    class ImplicitRedundantOpRemoval(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.upsample = torch.nn.Upsample(scale_factor=2, mode="nearest")
+            self.conv = torch.nn.Conv2d(3, 3, 3)
+
+        def forward(self, x):
+            y = x.to(memory_format=torch.channels_last)
+            y = self.upsample(y)
+            y = y.to(memory_format=torch.contiguous_format)
+            return self.conv(y)
+
+    ImplicitRedundantOpRemovalModule = ImplicitRedundantOpRemoval()
+
+    class QuantizableRedundantCopyModel(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv1 = torch.nn.Conv2d(3, 16, 3, padding=1)
+            self.conv2 = torch.nn.Conv2d(16, 16, 3, padding=1)
+
+        def forward(self, x):
+            x = self.conv1(x)
+
+            x = x.to(memory_format=torch.channels_last)
+            x = x.to(memory_format=torch.contiguous_format)
+            x = x.to(memory_format=torch.channels_last)
+
+            x = self.conv2(x)
+            return x
+
+    QuantizableRedundantCopyModule = QuantizableRedundantCopyModel()
+
+    def test_redundant_to_copy_op_removal(self):
+        (
+            Tester(self.ChannelsLastToContiguousModule, (torch.randn(1, 3, 6, 4),))
+            .export()
+            .to_edge()
+            .run_passes(self.PassStage)
+            .check_count(
+                {
+                    "executorch_exir_dialects_edge__ops_aten__to_copy_default": 2,
+                }
+            )
+            .run_method_and_compare_outputs()
+        )
+
+    def test_redundant_to_copy_op_removal_2(self):
+        (
+            Tester(self.ContiguousToChannelsLastModule, (torch.randn(1, 3, 6, 4),))
+            .export()
+            .to_edge()
+            .run_passes(self.PassStage)
+            .check_count(
+                {
+                    "executorch_exir_dialects_edge__ops_aten__to_copy_default": 1,
+                }
+            )
+            .run_method_and_compare_outputs()
+        )
+
+    def test_implicit_redundant_op_removal(self):
+        (
+            Tester(self.ImplicitRedundantOpRemovalModule, (torch.randn(1, 3, 3, 3),))
+            .export()
+            .to_edge()
+            .run_passes(self.PassStage)
+            .check_count(
+                {
+                    "executorch_exir_dialects_edge__ops_aten__to_copy_default": 2,
+                }
+            )
+            .run_method_and_compare_outputs()
+        )
+
+    def test_quantized_redundant_copy_removal(self):
+        (
+            Tester(self.QuantizableRedundantCopyModule, (torch.randn(1, 3, 32, 32),))
+            .quantize()
+            .export()
+            .to_edge()
+            .run_passes(self.PassStage)
+            .check_count(
+                {
+                    "executorch_exir_dialects_edge__ops_aten__to_copy_default": 2,
+                }
+            )
+            .run_method_and_compare_outputs()
+        )
