
Commit 6835f61

Optimize transposes in XNNPACK partition
1 parent 8b91628 commit 6835f61

5 files changed: +188 -17 lines changed

backends/xnnpack/_passes/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -25,6 +25,10 @@
     FuseBatchNormWithConvPass,
 )
 from executorch.backends.xnnpack._passes.prelu_reshape_pass import PReLUReshapePass
+
+from executorch.backends.xnnpack._passes.remove_redundant_ops_pass import (
+    RemoveRedundantOpsPass,
+)
 from executorch.backends.xnnpack._passes.tag_implicit_q_dq_pass import (
     TagImplicitQDqPass,
 )
@@ -70,6 +74,7 @@ def __init__(
                 Conv1dUnsqueezePass,
                 PReLUReshapePass,
                 ChannelsLastTaggedReshapePass,
+                RemoveRedundantOpsPass,
                 TagImplicitQDqPass,
             ]
         else:
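
Note on ordering (inferred from the diffs, not stated in the commit message): RemoveRedundantOpsPass is registered directly after ChannelsLastTaggedReshapePass because it reads the NHWC/NCHW tag that pass leaves in node.meta. A minimal sketch of that tagging convention, mirroring the static helpers in the next file; the literal meta-key value below is an assumption (the real code uses the class attribute ChannelsLastTaggedReshapePass.XNN_NHWC_NODE):

import torch

XNN_NHWC_NODE = "XNN_NHWC_NODE"  # assumed literal; stands in for the class attribute

def is_nhwc_node(node: torch.fx.Node) -> bool:
    # Untagged nodes default to NCHW, matching is_nchw_node below.
    return node.meta.get(XNN_NHWC_NODE, False)

def is_nchw_node(node: torch.fx.Node) -> bool:
    return not is_nhwc_node(node)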

backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py

Lines changed: 35 additions & 16 deletions
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+from enum import Enum
 from typing import Optional, Tuple
 
 import torch
@@ -14,6 +15,11 @@
 from executorch.exir.pass_base import PassResult
 
 
+class InputDimOrder(Enum):
+    NCHW = 1
+    NHWC = 2
+
+
 # TODO(T151254305) use subgraph_rewriter
 class ChannelsLastTaggedReshapePass(XNNPACKPass):
     """
@@ -84,11 +90,13 @@ def mark_as_nhwc_node(self, node: torch.fx.Node) -> None:
     def mark_as_nchw_node(self, node: torch.fx.Node) -> None:
         node.meta[ChannelsLastTaggedReshapePass.XNN_NHWC_NODE] = False
 
-    def is_nhwc_node(self, node: torch.fx.Node) -> bool:
+    @staticmethod
+    def is_nhwc_node(node: torch.fx.Node) -> bool:
         return node.meta.get(ChannelsLastTaggedReshapePass.XNN_NHWC_NODE, False)
 
-    def is_nchw_node(self, node: torch.fx.Node) -> bool:
-        return not self.is_nhwc_node(node)
+    @staticmethod
+    def is_nchw_node(node: torch.fx.Node) -> bool:
+        return not ChannelsLastTaggedReshapePass.is_nhwc_node(node)
 
     def requires_nhwc_input(self, node: torch.fx.Node) -> bool:
         return (
@@ -114,7 +122,7 @@ def can_be_converted_to_nhwc(self, node: torch.fx.Node) -> bool:
         is_nchw_constant = (
            is_param_node(self.exported_program, node)
            and (ChannelsLastTaggedReshapePass.XNN_NHWC_NODE in node.meta)
-           and (self.is_nchw_node(node))
+           and (ChannelsLastTaggedReshapePass.is_nchw_node(node))
         )
         return is_4d and not is_nchw_constant
 
@@ -257,6 +265,22 @@ def insert_copy_and_assign_partner_nodes_quantization_sensitive(
         # in that case
         self.make_partners(original_input, copy_node)
 
+    def input_dim_order(
+        self, input_node: torch.fx.Node, input_order: InputDimOrder
+    ) -> bool:
+        if input_node.name == "x":
+            return (
+                input_node.meta["val"].is_contiguous()
+                if input_order == InputDimOrder.NCHW
+                else not input_node.meta["val"].is_contiguous()
+            )
+        else:
+            return (
+                ChannelsLastTaggedReshapePass.is_nchw_node(input_node)
+                if input_order == InputDimOrder.NCHW
+                else ChannelsLastTaggedReshapePass.is_nhwc_node(input_node)
+            )
+
     def input_to_nhwc(
         self,
         graph_module: torch.fx.GraphModule,
@@ -266,7 +290,7 @@
         if is_param_node(self.exported_program, input_node):
             if (
                 ChannelsLastTaggedReshapePass.XNN_NHWC_NODE in input_node.meta
-                and self.is_nchw_node(input_node)
+                and ChannelsLastTaggedReshapePass.is_nchw_node(input_node)
             ):
                 # This constant data tensor has been used somewhere else
                 # in NCHW format so we can't use it here in NHWC format
@@ -277,10 +301,7 @@
             # serializing graph, but don't do anything else here
             self.mark_as_nhwc_node(input_node)
 
-        if input_node.name == "x":
-            if not input_node.meta["val"][0].is_contiguous():
-                return
-        elif self.is_nhwc_node(input_node):
+        if self.input_dim_order(input_node, InputDimOrder.NHWC):
             return
 
         if not self.can_be_converted_to_nhwc(input_node):
@@ -332,7 +353,7 @@ def input_to_nchw(
         if is_param_node(self.exported_program, input_node):
             if (
                 ChannelsLastTaggedReshapePass.XNN_NHWC_NODE in input_node.meta
-                and self.is_nhwc_node(input_node)
+                and ChannelsLastTaggedReshapePass.is_nhwc_node(input_node)
             ):
                 # This constant data tensor has been used somewhere else
                 # in NHWC format so we can't use it here in NCHW format
@@ -344,10 +365,7 @@
             # do anything else here
             self.mark_as_nchw_node(input_node)
 
-        if input_node.name == "x":
-            if input_node.meta["val"].is_contiguous():
-                return
-        elif self.is_nchw_node(input_node):
+        if self.input_dim_order(input_node, InputDimOrder.NCHW):
             return
 
         if ChannelsLastTaggedReshapePass.PARTNER_NODE in input_node.meta:
@@ -391,7 +409,7 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
                 self.input_to_nhwc(graph_module, node.args[0], node)
 
                 for input_node in node.all_input_nodes[1:]:
-                    if self.is_nhwc_node(input_node):
+                    if ChannelsLastTaggedReshapePass.is_nhwc_node(input_node):
                         raise AssertionError(
                             f"Expected {input_node} to be NCHW in channels last reshape pass"
                         )
@@ -409,7 +427,8 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
                 # The node can have inputs in any format (but all must be the
                 # same format)
                 is_or_isnt_nhwc_node = [
-                    self.is_nhwc_node(input_node) for input_node in node.all_input_nodes
+                    ChannelsLastTaggedReshapePass.is_nhwc_node(input_node)
+                    for input_node in node.all_input_nodes
                 ]
                 if all(is_or_isnt_nhwc_node):
                     # All inputs are nhwc so this node's output is nhwc too
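
The new input_dim_order helper merges two previously separate checks: the graph input (named "x") is classified by the layout of the FakeTensor in its meta["val"], while every other node is classified by the tag helpers above. (Note that the old NHWC branch indexed meta["val"][0]; the helper now reads meta["val"] directly, matching the NCHW branch.) A small self-contained illustration of the layout probe used for the graph input; the shapes are arbitrary:

import torch

x_nchw = torch.randn(1, 3, 6, 4)                       # default contiguous layout, i.e. NCHW
x_nhwc = x_nchw.to(memory_format=torch.channels_last)  # same values, NHWC strides

# is_contiguous() is the signal input_dim_order consults via meta["val"]:
assert x_nchw.is_contiguous()
assert not x_nhwc.is_contiguous()
assert x_nhwc.is_contiguous(memory_format=torch.channels_last)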
backends/xnnpack/_passes/remove_redundant_ops_pass.py

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
+from executorch.backends.xnnpack._passes.channels_last_tagged_reshape_pass import ChannelsLastTaggedReshapePass
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import PassResult
+
+class RemoveRedundantOpsPass(XNNPACKPass):
+    def call(self, graph_module: torch.fx.GraphModule):
+        graph = graph_module.graph
+        original_nodes = list(graph.nodes)
+
+        # Track the most recently visited to_copy node
+        prev = None
+        for node in original_nodes:
+            if len(node.all_input_nodes) == 0:
+                continue
+
+            # If we encounter a to_copy node, check if it is preceded by an opposite to_copy node
+            if node.target == exir_ops.edge.aten._to_copy.default:
+                if prev and ChannelsLastTaggedReshapePass.is_nchw_node(prev) != ChannelsLastTaggedReshapePass.is_nchw_node(node):
+                    # If we find an opposite to_copy node, remove both nodes
+                    prevPrev = prev.args[0]
+
+                    for user in node.users.copy():
+                        user.replace_input_with(node, prevPrev)
+
+                    graph.erase_node(node)
+                    graph.erase_node(prev)
+
+                    prev = None
+                    continue
+                prev = node
+            else:
+                prev = None
+
+        graph_module.recompile()
+
+        # Since we are overriding "call", we need to call the parent's "call"
+        # to retrace the graph and regenerate metadata
+        graph_module = super().call(graph_module).graph_module
+
+        return PassResult(graph_module, True)
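
The rewrite relies on the fact that two adjacent _to_copy nodes with opposite layouts form an identity: users of the second copy can be rewired to the tensor feeding the first (prevPrev above) and both copies erased. A runnable sketch of that invariant at the eager-tensor level (illustrative only; the pass itself edits the FX graph):

import torch

x = torch.randn(1, 3, 6, 4)

# The pattern the pass cancels: a channels_last -> contiguous round trip.
y = x.to(memory_format=torch.channels_last).to(memory_format=torch.contiguous_format)

# Values and layout are unchanged, so the pair of copies is dead weight.
assert torch.equal(x, y) and y.is_contiguous()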

backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py

Lines changed: 0 additions & 1 deletion
@@ -147,7 +147,6 @@ def forward(self, x):
     def test_conv_linear_dim_order_swap_partitioner(self):
         self.run_tester(self.LinearConvDimSwapModule, (torch.randn(1, 3, 6, 4),))
 
-
     def test_qs8_channels_last_tagged_reshape_pass(self):
         for module, num_reshape in self.modules.items():
             (
Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+from executorch.backends.xnnpack._passes.remove_redundant_ops_pass import (
+    RemoveRedundantOpsPass,
+)
+from executorch.backends.xnnpack._passes.channels_last_tagged_reshape_pass import (
+    ChannelsLastTaggedReshapePass,
+)
+from executorch.backends.xnnpack._passes.convert_to_linear import ConvertToLinearPass
+from executorch.exir.passes.memory_format_ops_pass import DimOrderOpsRevertPass
+from executorch.backends.xnnpack.test.tester import RunPasses, Tester
+
+
+class TestChannelsLastTaggedReshapePass(unittest.TestCase):
+    PassStage = RunPasses([DimOrderOpsRevertPass,
+                           ConvertToLinearPass,
+                           ChannelsLastTaggedReshapePass,
+                           RemoveRedundantOpsPass])
+
+    def setUp(self):
+        torch._dynamo.reset()
+
+    def run_tester(self, module, inputs):
+        tester = Tester(
+            module.eval(),
+            inputs,
+        )
+        tester.export().to_edge_transform_and_lower().to_executorch().serialize().run_method_and_compare_outputs()
+
+    class ChannelsLastToContiguous(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv1 = torch.nn.Conv2d(3, 3, 3)
+            self.linear1 = torch.nn.Linear(4, 3)
+
+        def forward(self, x):
+            y = self.linear1(x)
+            y = y.to(memory_format=torch.channels_last)
+            y = y.to(memory_format=torch.contiguous_format)
+            y = y.to(memory_format=torch.channels_last)
+            y = y.to(memory_format=torch.contiguous_format)
+            y = y.to(memory_format=torch.channels_last)
+            y = y.to(memory_format=torch.contiguous_format)
+            return self.conv1(y)
+
+    ChannelsLastToContiguousModule = ChannelsLastToContiguous()
+
+    class ContiguousToChannelsLast(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv1 = torch.nn.Conv2d(3, 3, 3)
+            self.linear1 = torch.nn.Linear(4, 3)
+
+        def forward(self, x):
+            y = self.linear1(x)
+            y = y.to(memory_format=torch.contiguous_format)
+            y = y.to(memory_format=torch.channels_last)
+            y = y.to(memory_format=torch.contiguous_format)
+            y = y.to(memory_format=torch.channels_last)
+            y = y.to(memory_format=torch.contiguous_format)
+            y = y.to(memory_format=torch.channels_last)
+
+            return self.conv1(y)
+
+    ContiguousToChannelsLastModule = ContiguousToChannelsLast()
+
+    def test_redundant_to_copy_op_removal(self):
+        (
+            Tester(self.ChannelsLastToContiguousModule, (torch.randn(1, 3, 6, 4),))
+            .export()
+            .to_edge()
+            .run_passes(self.PassStage)
+            .check_count(
+                {
+                    "executorch_exir_dialects_edge__ops_aten__to_copy_default": 2,
+                }
+            )
+            .run_method_and_compare_outputs()
+        )
+
+    def test_redundant_to_copy_op_removal_2(self):
+        (
+            Tester(self.ContiguousToChannelsLastModule, (torch.randn(1, 3, 6, 4),))
+            .export()
+            .to_edge()
+            .run_passes(self.PassStage)
+            .check_count(
+                {
+                    "executorch_exir_dialects_edge__ops_aten__to_copy_default": 1,
+                }
+            )
+            .run_method_and_compare_outputs()
+        )
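
The expected counts appear to follow from the pass ordering: ChannelsLastTaggedReshapePass first inserts its own layout copies around the conv, then RemoveRedundantOpsPass cancels the six explicit memory-format round trips pairwise. In ChannelsLastToContiguousModule the chain ends contiguous, so the conv still needs a copy to NHWC at its input and a copy back to NCHW at its output (2 remaining); in ContiguousToChannelsLastModule the chain already ends channels-last, so only the copy after the conv survives (1 remaining).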
