Commit 5c2a162
Optimize transposes in XNNPACK partition
1 parent 22b9e59

4 files changed: +434 −16 lines
backends/xnnpack/_passes/__init__.py (5 additions, 0 deletions)

@@ -25,6 +25,10 @@
     FuseBatchNormWithConvPass,
 )
 from executorch.backends.xnnpack._passes.prelu_reshape_pass import PReLUReshapePass
+
+from executorch.backends.xnnpack._passes.remove_redundant_copy_pass import (
+    RemoveRedundantCopyPass,
+)
 from executorch.backends.xnnpack._passes.tag_implicit_q_dq_pass import (
     TagImplicitQDqPass,
 )

@@ -70,6 +74,7 @@ def __init__(
                 Conv1dUnsqueezePass,
                 PReLUReshapePass,
                 ChannelsLastTaggedReshapePass,
+                RemoveRedundantCopyPass,
                 TagImplicitQDqPass,
             ]
         else:
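
The registration order matters: ChannelsLastTaggedReshapePass is the pass that inserts _to_copy dim-order conversions, and RemoveRedundantCopyPass deletes the conversions that cancel out, so it must run afterwards. A minimal sketch of applying the two passes in that order; the run_transpose_passes helper, the exported_program variable, and the construct-then-call pattern are illustrative assumptions, not ExecuTorch's actual pass-manager API:

from executorch.backends.xnnpack._passes.channels_last_tagged_reshape_pass import (
    ChannelsLastTaggedReshapePass,
)
from executorch.backends.xnnpack._passes.remove_redundant_copy_pass import (
    RemoveRedundantCopyPass,
)


def run_transpose_passes(exported_program):
    # Hypothetical driver loop: tag dim orders and insert _to_copy nodes
    # first, then strip the copies that turned out to be redundant.
    graph_module = exported_program.graph_module
    for pass_cls in (ChannelsLastTaggedReshapePass, RemoveRedundantCopyPass):
        result = pass_cls(exported_program)(graph_module)  # each pass returns a PassResult
        graph_module = result.graph_module
    return graph_module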

backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py (89 additions, 16 deletions)

@@ -4,16 +4,22 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+from enum import Enum
 from typing import Optional, Tuple
 
 import torch
 from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
-from executorch.backends.xnnpack.utils.quant_utils import is_dynamic_qdq
+from executorch.backends.xnnpack.utils.quant_utils import is_dequant, is_dynamic_qdq
 from executorch.backends.xnnpack.utils.utils import is_param_node
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import PassResult
 
 
+class InputDimOrder(Enum):
+    NCHW = 1
+    NHWC = 2
+
+
 # TODO(T151254305) use subgraph_rewriter
 class ChannelsLastTaggedReshapePass(XNNPACKPass):
     """

@@ -78,17 +84,49 @@ class ChannelsLastTaggedReshapePass(XNNPACKPass):
     # is done
     PARTNER_NODE = "XNN_CHANNELS_LAST_TAGGED_RESHAPE_PARTNER_NODE"
 
-    def mark_as_nhwc_node(self, node: torch.fx.Node) -> None:
+    @staticmethod
+    def mark_as_nhwc_node(node: torch.fx.Node) -> None:
         node.meta[ChannelsLastTaggedReshapePass.XNN_NHWC_NODE] = True
 
-    def mark_as_nchw_node(self, node: torch.fx.Node) -> None:
+    @staticmethod
+    def mark_as_nchw_node(node: torch.fx.Node) -> None:
         node.meta[ChannelsLastTaggedReshapePass.XNN_NHWC_NODE] = False
 
-    def is_nhwc_node(self, node: torch.fx.Node) -> bool:
+    def tag_node(self, node: torch.fx.Node) -> None:
+        if node.kwargs["memory_format"] == torch.channels_last:
+            self.mark_as_nhwc_node(node)
+        else:
+            self.mark_as_nchw_node(node)
+
+    @staticmethod
+    def is_nhwc_node(node: torch.fx.Node) -> bool:
+        if is_dequant(node) and len(node.all_input_nodes) > 0:
+            quantize_node = node.args[0]
+            if len(quantize_node.all_input_nodes) > 0:
+                actual_node = quantize_node.args[0]
+                if actual_node.op == "placeholder":
+                    return not actual_node.meta["val"][0].is_contiguous()
+                else:
+                    return actual_node.meta.get(
+                        ChannelsLastTaggedReshapePass.XNN_NHWC_NODE, False
+                    )
+
         return node.meta.get(ChannelsLastTaggedReshapePass.XNN_NHWC_NODE, False)
 
-    def is_nchw_node(self, node: torch.fx.Node) -> bool:
-        return not self.is_nhwc_node(node)
+    @staticmethod
+    def is_nchw_node(node: torch.fx.Node) -> bool:
+        if is_dequant(node) and len(node.all_input_nodes) > 0:
+            quantize_node = node.args[0]
+            if len(quantize_node.all_input_nodes) > 0:
+                actual_node = quantize_node.args[0]
+                if actual_node.op == "placeholder":
+                    return actual_node.meta["val"][0].is_contiguous()
+                else:
+                    return not actual_node.meta.get(
+                        ChannelsLastTaggedReshapePass.XNN_NHWC_NODE, False
+                    )
+
+        return not ChannelsLastTaggedReshapePass.is_nhwc_node(node)
 
     def requires_nhwc_input(self, node: torch.fx.Node) -> bool:
         return node.target in self.memory_sensitive_ops_nhwc

@@ -106,7 +144,7 @@ def can_be_converted_to_nhwc(self, node: torch.fx.Node) -> bool:
         is_nchw_constant = (
             is_param_node(self.exported_program, node)
             and (ChannelsLastTaggedReshapePass.XNN_NHWC_NODE in node.meta)
-            and (self.is_nchw_node(node))
+            and (ChannelsLastTaggedReshapePass.is_nchw_node(node))
         )
         return is_4d and not is_nchw_constant

@@ -249,6 +287,22 @@ def insert_copy_and_assign_partner_nodes_quantization_sensitive(
         # in that case
         self.make_partners(original_input, copy_node)
 
+    def input_dim_order(
+        self, input_node: torch.fx.Node, input_order: InputDimOrder
+    ) -> bool:
+        if input_node.op == "placeholder":
+            return (
+                input_node.meta["val"].is_contiguous()
+                if input_order == InputDimOrder.NCHW
+                else not input_node.meta["val"].is_contiguous()
+            )
+        else:
+            return (
+                ChannelsLastTaggedReshapePass.is_nchw_node(input_node)
+                if input_order == InputDimOrder.NCHW
+                else ChannelsLastTaggedReshapePass.is_nhwc_node(input_node)
+            )
+
     def input_to_nhwc(
         self,
         graph_module: torch.fx.GraphModule,

@@ -258,7 +312,7 @@
         if is_param_node(self.exported_program, input_node):
             if (
                 ChannelsLastTaggedReshapePass.XNN_NHWC_NODE in input_node.meta
-                and self.is_nchw_node(input_node)
+                and ChannelsLastTaggedReshapePass.is_nchw_node(input_node)
             ):
                 # This constant data tensor has been used somewhere else
                 # in NCHW format so we can't use it here in NHWC format

@@ -272,7 +326,10 @@
         if input_node.op == "placeholder":
             if not input_node.meta["val"][0].is_contiguous():
                 return
-        elif self.is_nhwc_node(input_node):
+        elif ChannelsLastTaggedReshapePass.is_nhwc_node(input_node):
+            return
+
+        if self.input_dim_order(input_node, InputDimOrder.NHWC):
             return
 
         if not self.can_be_converted_to_nhwc(input_node):

@@ -302,6 +359,8 @@
                 args=(input_node,),
                 memory_format=torch.channels_last,
             )
+            # Use static method for consistency
+            ChannelsLastTaggedReshapePass.mark_as_nhwc_node(input_node_nhwc)
 
         if is_dynamic_input:
             # Replace downstream input_nodes with NHWC node

@@ -324,7 +383,7 @@ def input_to_nchw(
         if is_param_node(self.exported_program, input_node):
             if (
                 ChannelsLastTaggedReshapePass.XNN_NHWC_NODE in input_node.meta
-                and self.is_nhwc_node(input_node)
+                and ChannelsLastTaggedReshapePass.is_nhwc_node(input_node)
             ):
                 # This constant data tensor has been used somewhere else
                 # in NHWC format so we can't use it here in NCHW format

@@ -339,7 +398,10 @@
         if input_node.op == "placeholder":
             if input_node.meta["val"].is_contiguous():
                 return
-        elif self.is_nchw_node(input_node):
+        elif ChannelsLastTaggedReshapePass.is_nchw_node(input_node):
+            return
+
+        if self.input_dim_order(input_node, InputDimOrder.NCHW):
             return
 
         if ChannelsLastTaggedReshapePass.PARTNER_NODE in input_node.meta:

@@ -356,6 +418,7 @@
                 args=(input_node,),
                 memory_format=torch.contiguous_format,
             )
+            ChannelsLastTaggedReshapePass.mark_as_nchw_node(input_node_nchw)
 
         self.insert_copy_and_assign_partner_nodes_quantization_sensitive(
             graph_module=graph_module,

@@ -369,7 +432,12 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
         original_nodes = list(graph.nodes)
         for node in original_nodes:
             if len(node.all_input_nodes) == 0:
-                # This node has no inputs so we don't need to change anything
+                # This node has no inputs so we don't need to change anything, but we still need to tag it with a dim order
+                if "val" in node.meta and isinstance(node.meta["val"], torch.Tensor):
+                    if node.meta["val"].is_contiguous():
+                        self.mark_as_nchw_node(node)
+                    else:
+                        self.mark_as_nhwc_node(node)
                 continue
 
             # Need special case for output node because it can have multiple output dim orders as we can output a tuple of multiple nodes

@@ -383,10 +451,12 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
             elif self.requires_nhwc_input(node):
                 # Nodes which enter this branch are ones that require their
                 # first input to be nhwc. This makes this node's output nhwc too
-
                 self.input_to_nhwc(graph_module, node.args[0], node)
-                for input_node in node.all_input_nodes:
-                    if input_node.op == "placeholder" and self.is_nhwc_node(input_node):
+                for input_node in node.all_input_nodes[1:]:
+                    if (
+                        input_node.op == "placeholder"
+                        and ChannelsLastTaggedReshapePass.is_nhwc_node(input_node)
+                    ):
                         raise AssertionError(
                             f"Expected {input_node} to be NCHW in channels last reshape pass"
                         )

@@ -395,11 +465,14 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
                 # The node requires nchw inputs
                 for input_node in node.all_input_nodes:
                     self.input_to_nchw(graph_module, input_node, node)
+            elif node.target == exir_ops.edge.aten._to_copy.default:
+                self.tag_node(node)
             else:
                 # The node can have inputs in any format (but all must be the
                 # same format)
                 is_or_isnt_nhwc_node = [
-                    self.is_nhwc_node(input_node) for input_node in node.all_input_nodes
+                    ChannelsLastTaggedReshapePass.is_nhwc_node(input_node)
+                    for input_node in node.all_input_nodes
                 ]
                 if all(is_or_isnt_nhwc_node):
                     # All inputs are nhwc so this node's output is nhwc too
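
Both input_dim_order and the placeholder branches of is_nhwc_node/is_nchw_node bottom out in the same invariant: a tensor that is contiguous in the default memory format is treated as NCHW, anything else as NHWC. A small self-contained check of that invariant in plain PyTorch (no ExecuTorch required; the tensor shape is an arbitrary example):

import torch

# A freshly allocated 4-D tensor is contiguous, i.e. NCHW dim order.
t = torch.randn(1, 3, 4, 4)
assert t.is_contiguous()

# After a channels-last _to_copy it is no longer contiguous in the default
# format, which is exactly what the pass reads off meta["val"].
nhwc = t.to(memory_format=torch.channels_last)
assert not nhwc.is_contiguous()
assert nhwc.is_contiguous(memory_format=torch.channels_last)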
backends/xnnpack/_passes/remove_redundant_copy_pass.py (new file, 158 additions)

@@ -0,0 +1,158 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.backends.xnnpack._passes.channels_last_tagged_reshape_pass import (
+    ChannelsLastTaggedReshapePass,
+)
+from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
+from executorch.backends.xnnpack.utils.quant_utils import is_dequant, is_quant
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import PassResult
+
+
+class RemoveRedundantCopyPass(XNNPACKPass):
+    def _safe_remove_node(self, node, graph):
+        if len(node.users) == 0:
+            graph.erase_node(node)
+
+    def _try_remove_regular_redundant_to_copy(self, node, graph):
+        """
+        Try to remove redundant regular to_copy operations with pattern
+        to_copy1 -> to_copy2 with opposite memory formats
+        """
+        input_node = node.args[0]
+
+        # Check if input is a to_copy with opposite memory format
+        if (
+            input_node.target == exir_ops.edge.aten._to_copy.default
+            and ChannelsLastTaggedReshapePass.is_nchw_node(input_node)
+            != ChannelsLastTaggedReshapePass.is_nchw_node(node)
+            and len(input_node.users) == 1
+        ):  # Ensure the first copy has no other users
+            # Get the original input (before the first to_copy)
+            original_input = input_node.args[0]
+
+            # Replace all users of the second to_copy with the original input
+            for user in node.users.copy():
+                user.replace_input_with(node, original_input)
+
+            # Remove both to_copy nodes
+            self._safe_remove_node(node, graph)
+            self._safe_remove_node(input_node, graph)
+
+            return True
+        elif (
+            ChannelsLastTaggedReshapePass.is_nhwc_node(input_node)
+            and ChannelsLastTaggedReshapePass.is_nhwc_node(node)
+        ) or (
+            ChannelsLastTaggedReshapePass.is_nchw_node(input_node)
+            and ChannelsLastTaggedReshapePass.is_nchw_node(node)
+        ):
+            # Replace all users of the second to_copy with the original input
+            for user in node.users.copy():
+                user.replace_input_with(node, input_node)
+            self._safe_remove_node(node, graph)
+            return True
+
+        return False
+
+    def _try_remove_quantized_redundant_to_copy(self, node, graph):
+        """
+        Try to remove redundant to_copy operations in quantized graphs with
+        pattern dq1 -> to_copy1 -> q1 -> dq2 -> to_copy2 -> q2
+        """
+        # Check if this to_copy is followed by a quantize node
+        if len(node.users) != 1:
+            return False
+        q_node = next(iter(node.users))
+        if not is_quant(q_node):
+            return False
+
+        # Check if this to_copy is preceded by a dequantize node
+        dq_node = node.args[0]
+        if not is_dequant(dq_node):
+            return False
+
+        # Get the input to the dequantize node
+        if len(dq_node.all_input_nodes) != 1:
+            return False
+
+        prev_q_node = dq_node.args[0]
+
+        # Check if there's another dequantize -> to_copy -> quantize chain
+        if not is_quant(prev_q_node) or len(prev_q_node.all_input_nodes) != 1:
+            return False
+
+        # Check if there's a to_copy before the previous quantize
+        prev_to_copy = prev_q_node.args[0]
+        if (
+            prev_to_copy.target == exir_ops.edge.aten._to_copy.default
+            and ChannelsLastTaggedReshapePass.is_nchw_node(prev_to_copy)
+            != ChannelsLastTaggedReshapePass.is_nchw_node(node)
+            and len(prev_to_copy.users) == 1
+        ):  # Ensure the first copy has no other users
+            prev_dq_node = prev_to_copy.args[0]
+            if not is_dequant(prev_dq_node) or len(prev_dq_node.all_input_nodes) != 1:
+                return False
+
+            # Get the original input (before the first to_copy)
+            original_input = prev_dq_node.args[0]
+
+            # Replace all users of the second to_copy with the original input
+            for user in q_node.users.copy():
+                user.replace_input_with(q_node, original_input)
+
+            # Remove nodes safely (only if they have no other users)
+            self._safe_remove_node(q_node, graph)
+            self._safe_remove_node(node, graph)
+            self._safe_remove_node(dq_node, graph)
+            self._safe_remove_node(prev_q_node, graph)
+            self._safe_remove_node(prev_to_copy, graph)
+            self._safe_remove_node(prev_dq_node, graph)
+        elif (
+            ChannelsLastTaggedReshapePass.is_nhwc_node(prev_to_copy)
+            and ChannelsLastTaggedReshapePass.is_nhwc_node(node)
+        ) or (
+            ChannelsLastTaggedReshapePass.is_nchw_node(prev_to_copy)
+            and ChannelsLastTaggedReshapePass.is_nchw_node(node)
+        ):
+            # Remove node and the q/dq around it only
+            # Get the original quantized tensor (input to dq_node)
+            original_q_tensor = dq_node.args[0]
+
+            # Replace all users of q_node with the original quantized tensor
+            for user in q_node.users.copy():
+                user.replace_input_with(q_node, original_q_tensor)
+
+            self._safe_remove_node(q_node, graph)
+            self._safe_remove_node(node, graph)
+            self._safe_remove_node(dq_node, graph)
+
+        return True
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        graph = graph_module.graph
+        original_nodes = list(graph.nodes)
+
+        for node in original_nodes:
+            if len(node.all_input_nodes) == 0:
+                continue
+
+            # Only process to_copy nodes
+            if node.target != exir_ops.edge.aten._to_copy.default:
+                continue
+
+            if is_dequant(node.args[0]):
+                self._try_remove_quantized_redundant_to_copy(node, graph)
+            else:
+                self._try_remove_regular_redundant_to_copy(node, graph)
+
+        graph_module.recompile()
+
+        # Since we are overriding "call", we need to call the parent's "call"
+        # to retrace the graph and regenerate metadata
+        graph_module = super().call(graph_module).graph_module
+
+        return PassResult(graph_module, True)
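
Both rewrites rely on the same observation: two _to_copy ops with opposite memory formats cancel out (optionally through a quantize/dequantize pair), so users of the second copy can be rewired to the tensor that fed the first one. A quick eager-mode illustration of the non-quantized case; the tensor and shape are arbitrary, and the pass itself performs this rewrite on the Edge-dialect FX graph rather than on eager tensors:

import torch

x = torch.randn(1, 3, 8, 8)
y = x.to(memory_format=torch.channels_last)       # to_copy1: NCHW -> NHWC
z = y.to(memory_format=torch.contiguous_format)   # to_copy2: NHWC -> NCHW

# z matches the original input in both values and dim order, so in the graph
# every user of to_copy2 can read x directly and both copies can be erased.
assert torch.equal(x, z)
assert z.is_contiguous()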
