
Commit 824a753

Optimize transposes in XNNPACK partition

1 parent a8e4be4

File tree

8 files changed: +237 -16 lines

backends/xnnpack/_passes/__init__.py

Lines changed: 5 additions & 0 deletions

@@ -25,6 +25,10 @@
     FuseBatchNormWithConvPass,
 )
 from executorch.backends.xnnpack._passes.prelu_reshape_pass import PReLUReshapePass
+
+from executorch.backends.xnnpack._passes.remove_redundant_ops_pass import (
+    RemoveRedundantOpsPass,
+)
 from executorch.backends.xnnpack._passes.tag_implicit_q_dq_pass import (
     TagImplicitQDqPass,
 )
@@ -70,6 +74,7 @@ def __init__(
                 Conv1dUnsqueezePass,
                 PReLUReshapePass,
                 ChannelsLastTaggedReshapePass,
+                RemoveRedundantOpsPass,
                 TagImplicitQDqPass,
             ]
         else:

backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py

Lines changed: 48 additions & 12 deletions

@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+from enum import Enum
 from typing import Optional, Tuple
 
 import torch
@@ -14,6 +15,11 @@
 from executorch.exir.pass_base import PassResult
 
 
+class InputDimOrder(Enum):
+    NCHW = 1
+    NHWC = 2
+
+
 # TODO(T151254305) use subgraph_rewriter
 class ChannelsLastTaggedReshapePass(XNNPACKPass):
     """
@@ -84,11 +90,19 @@ def mark_as_nhwc_node(self, node: torch.fx.Node) -> None:
     def mark_as_nchw_node(self, node: torch.fx.Node) -> None:
         node.meta[ChannelsLastTaggedReshapePass.XNN_NHWC_NODE] = False
 
-    def is_nhwc_node(self, node: torch.fx.Node) -> bool:
+    def tag_node(self, node: torch.fx.Node) -> None:
+        if node.meta["val"].is_contiguous():
+            self.mark_as_nchw_node(node)
+        else:
+            self.mark_as_nhwc_node(node)
+
+    @staticmethod
+    def is_nhwc_node(node: torch.fx.Node) -> bool:
         return node.meta.get(ChannelsLastTaggedReshapePass.XNN_NHWC_NODE, False)
 
-    def is_nchw_node(self, node: torch.fx.Node) -> bool:
-        return not self.is_nhwc_node(node)
+    @staticmethod
+    def is_nchw_node(node: torch.fx.Node) -> bool:
+        return not ChannelsLastTaggedReshapePass.is_nhwc_node(node)
 
     def requires_nhwc_input(self, node: torch.fx.Node) -> bool:
         return (
@@ -114,7 +128,7 @@ def can_be_converted_to_nhwc(self, node: torch.fx.Node) -> bool:
         is_nchw_constant = (
             is_param_node(self.exported_program, node)
             and (ChannelsLastTaggedReshapePass.XNN_NHWC_NODE in node.meta)
-            and (self.is_nchw_node(node))
+            and (ChannelsLastTaggedReshapePass.is_nchw_node(node))
         )
         return is_4d and not is_nchw_constant
 
@@ -257,6 +271,22 @@ def insert_copy_and_assign_partner_nodes_quantization_sensitive(
         # in that case
         self.make_partners(original_input, copy_node)
 
+    def input_dim_order(
+        self, input_node: torch.fx.Node, input_order: InputDimOrder
+    ) -> bool:
+        if input_node.op == "placeholder":
+            return (
+                input_node.meta["val"].is_contiguous()
+                if input_order == InputDimOrder.NCHW
+                else not input_node.meta["val"].is_contiguous()
+            )
+        else:
+            return (
+                ChannelsLastTaggedReshapePass.is_nchw_node(input_node)
+                if input_order == InputDimOrder.NCHW
+                else ChannelsLastTaggedReshapePass.is_nhwc_node(input_node)
+            )
+
     def input_to_nhwc(
         self,
         graph_module: torch.fx.GraphModule,
@@ -266,7 +296,7 @@ def input_to_nhwc(
         if is_param_node(self.exported_program, input_node):
             if (
                 ChannelsLastTaggedReshapePass.XNN_NHWC_NODE in input_node.meta
-                and self.is_nchw_node(input_node)
+                and ChannelsLastTaggedReshapePass.is_nchw_node(input_node)
             ):
                 # This constant data tensor has been used somewhere else
                 # in NCHW format so we can't use it here in NHWC format
@@ -283,6 +313,9 @@ def input_to_nhwc(
         elif self.is_nhwc_node(input_node):
             return
 
+        if self.input_dim_order(input_node, InputDimOrder.NHWC):
+            return
+
         if not self.can_be_converted_to_nhwc(input_node):
             raise AssertionError(
                 "Attempting to convert non-NHWC compatible node to NHWC"
@@ -310,6 +343,7 @@ def input_to_nhwc(
                 args=(input_node,),
                 memory_format=torch.channels_last,
             )
+            self.mark_as_nhwc_node(input_node_nhwc)
 
         if is_dynamic_input:
             # Replace downstream input_nodes with NHWC node
@@ -332,7 +366,7 @@ def input_to_nchw(
         if is_param_node(self.exported_program, input_node):
             if (
                 ChannelsLastTaggedReshapePass.XNN_NHWC_NODE in input_node.meta
-                and self.is_nhwc_node(input_node)
+                and ChannelsLastTaggedReshapePass.is_nhwc_node(input_node)
             ):
                 # This constant data tensor has been used somewhere else
                 # in NHWC format so we can't use it here in NCHW format
@@ -350,6 +384,9 @@ def input_to_nchw(
         elif self.is_nchw_node(input_node):
             return
 
+        if self.input_dim_order(input_node, InputDimOrder.NCHW):
+            return
+
         if ChannelsLastTaggedReshapePass.PARTNER_NODE in input_node.meta:
             # Already has an associated NCHW node
             input_node_nchw = input_node.meta[
@@ -364,6 +401,7 @@ def input_to_nchw(
                 args=(input_node,),
                 memory_format=torch.contiguous_format,
             )
+            self.mark_as_nchw_node(input_node_nchw)
 
         self.insert_copy_and_assign_partner_nodes_quantization_sensitive(
             graph_module=graph_module,
@@ -391,7 +429,7 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
                self.input_to_nhwc(graph_module, node.args[0], node)
 
                for input_node in node.all_input_nodes[1:]:
-                   if self.is_nhwc_node(input_node):
+                   if ChannelsLastTaggedReshapePass.is_nhwc_node(input_node):
                        raise AssertionError(
                            f"Expected {input_node} to be NCHW in channels last reshape pass"
                        )
@@ -401,15 +439,13 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
                for input_node in node.all_input_nodes:
                    self.input_to_nchw(graph_module, input_node, node)
            elif node.target == exir_ops.edge.aten._to_copy.default:
-               if node.meta["val"].is_contiguous():
-                   self.mark_as_nchw_node(node)
-               else:
-                   self.mark_as_nhwc_node(node)
+               self.tag_node(node)
            else:
                # The node can have inputs in any format (but all must be the
                # same format)
                is_or_isnt_nhwc_node = [
-                   self.is_nhwc_node(input_node) for input_node in node.all_input_nodes
+                   ChannelsLastTaggedReshapePass.is_nhwc_node(input_node)
+                   for input_node in node.all_input_nodes
                ]
                if all(is_or_isnt_nhwc_node):
                    # All inputs are nhwc so this node's output is nhwc too
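
Note: the new tag_node and input_dim_order helpers key off Tensor.is_contiguous() on the node's "val" meta, which reports contiguity in the default NCHW sense. A minimal eager-mode sketch of that behavior (illustrative only, not part of this commit; tensor names are arbitrary):

import torch

# A default tensor is contiguous (NCHW); a channels_last copy of it is not,
# which is the signal the pass uses to tag nodes as NCHW vs. NHWC.
nchw = torch.randn(1, 3, 4, 4)
nhwc = nchw.to(memory_format=torch.channels_last)

assert nchw.is_contiguous()
assert not nhwc.is_contiguous()
assert nhwc.is_contiguous(memory_format=torch.channels_last)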
Lines changed: 48 additions & 0 deletions

@@ -0,0 +1,48 @@

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.backends.xnnpack._passes.channels_last_tagged_reshape_pass import (
    ChannelsLastTaggedReshapePass,
)
from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import PassResult


class RemoveRedundantOpsPass(XNNPACKPass):
    def call(self, graph_module: torch.fx.GraphModule):
        graph = graph_module.graph
        original_nodes = list(graph.nodes)

        for node in original_nodes:
            if len(node.all_input_nodes) == 0:
                continue

            # If we encounter a to_copy node, check if its input is also a to_copy node with opposite format
            if node.target == exir_ops.edge.aten._to_copy.default:
                input_node = node.args[0]
                if (
                    input_node.target == exir_ops.edge.aten._to_copy.default
                    and ChannelsLastTaggedReshapePass.is_nchw_node(input_node)
                    != ChannelsLastTaggedReshapePass.is_nchw_node(node)
                ):
                    # If we find an opposite to_copy node, remove both nodes
                    original_input = input_node.args[0]

                    for user in node.users.copy():
                        user.replace_input_with(node, original_input)

                    graph.erase_node(node)
                    graph.erase_node(input_node)

        graph_module.recompile()

        # Since we are overriding "call", we need to call the parent's "call"
        # to retrace the graph and regenerate metadata
        graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, True)
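
For intuition, a minimal sketch of the pattern RemoveRedundantOpsPass targets, assuming the two casts lower to back-to-back aten._to_copy nodes with opposite memory formats after export (as the tests below exercise); the tensor names are illustrative:

import torch

# Two adjacent memory-format casts that undo each other; the pass erases the
# corresponding _to_copy pair and rewires users to the original input.
x = torch.randn(1, 3, 6, 4)
y = x.to(memory_format=torch.channels_last)
z = y.to(memory_format=torch.contiguous_format)

assert z.is_contiguous()
assert torch.equal(x, z)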
Lines changed: 132 additions & 0 deletions

@@ -0,0 +1,132 @@

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack._passes.channels_last_tagged_reshape_pass import (
    ChannelsLastTaggedReshapePass,
)
from executorch.backends.xnnpack._passes.convert_to_linear import ConvertToLinearPass
from executorch.backends.xnnpack._passes.remove_redundant_ops_pass import (
    RemoveRedundantOpsPass,
)
from executorch.backends.xnnpack.test.tester import RunPasses, Tester
from executorch.exir.passes.memory_format_ops_pass import DimOrderOpsRevertPass


class TestChannelsLastTaggedReshapePass(unittest.TestCase):
    PassStage = RunPasses(
        [
            DimOrderOpsRevertPass,
            ConvertToLinearPass,
            ChannelsLastTaggedReshapePass,
            RemoveRedundantOpsPass,
        ]
    )

    def setUp(self):
        torch._dynamo.reset()

    def run_tester(self, module, inputs):
        tester = Tester(
            module.eval(),
            inputs,
        )
        tester.export().to_edge_transform_and_lower().to_executorch().serialize().run_method_and_compare_outputs()

    class ChannelsLastToContiguous(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = torch.nn.Conv2d(3, 3, 3)
            self.linear1 = torch.nn.Linear(4, 3)

        def forward(self, x):
            y = self.linear1(x)
            y = y.to(memory_format=torch.channels_last)
            y = y.to(memory_format=torch.contiguous_format)
            y = y.to(memory_format=torch.channels_last)
            y = y.to(memory_format=torch.contiguous_format)
            y = y.to(memory_format=torch.channels_last)
            y = y.to(memory_format=torch.contiguous_format)
            return self.conv1(y)

    ChannelsLastToContiguousModule = ChannelsLastToContiguous()

    class ContiguousToChannelsLast(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = torch.nn.Conv2d(3, 3, 3)
            self.linear1 = torch.nn.Linear(4, 3)

        def forward(self, x):
            y = self.linear1(x)
            y = y.to(memory_format=torch.contiguous_format)
            y = y.to(memory_format=torch.channels_last)
            y = y.to(memory_format=torch.contiguous_format)
            y = y.to(memory_format=torch.channels_last)
            y = y.to(memory_format=torch.contiguous_format)
            y = y.to(memory_format=torch.channels_last)

            return self.conv1(y)

    ContiguousToChannelsLastModule = ContiguousToChannelsLast()

    class ImplicitRedundantOpRemoval(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.upsample = torch.nn.Upsample(scale_factor=2, mode="nearest")
            self.conv = torch.nn.Conv2d(3, 3, 3)

        def forward(self, x):
            y = x.to(memory_format=torch.channels_last)
            y = self.upsample(y)
            y = y.to(memory_format=torch.contiguous_format)
            return self.conv(y)

    ImplicitRedundantOpRemovalModule = ImplicitRedundantOpRemoval()

    def test_redundant_to_copy_op_removal(self):
        (
            Tester(self.ChannelsLastToContiguousModule, (torch.randn(1, 3, 6, 4),))
            .export()
            .to_edge()
            .run_passes(self.PassStage)
            .check_count(
                {
                    "executorch_exir_dialects_edge__ops_aten__to_copy_default": 2,
                }
            )
            .run_method_and_compare_outputs()
        )

    def test_redundant_to_copy_op_removal_2(self):
        (
            Tester(self.ContiguousToChannelsLastModule, (torch.randn(1, 3, 6, 4),))
            .export()
            .to_edge()
            .run_passes(self.PassStage)
            .check_count(
                {
                    "executorch_exir_dialects_edge__ops_aten__to_copy_default": 1,
                }
            )
            .run_method_and_compare_outputs()
        )

    def test_implicit_redundant_op_removal(self):
        (
            Tester(self.ImplicitRedundantOpRemovalModule, (torch.randn(1, 3, 3, 3),))
            .export()
            .to_edge()
            .run_passes(self.PassStage)
            .check_count(
                {
                    "executorch_exir_dialects_edge__ops_aten__to_copy_default": 2,
                }
            )
            .run_method_and_compare_outputs()
        )

extension/llm/tokenizers

Submodule eigen updated from a39ade4 to 7294434

third-party/ao

Submodule ao updated 100 files
