
Commit 4e29bc9

Milestone2.2: Optimize transposes in XNNPACK partition by removing redundant to_copy ops (#11316)
### Summary

Optimize transposes in the XNNPACK partition by adding a new remove_redundant_copy_pass that checks for dim order conversion ops that cancel each other out. The pass handles both non-quantized and quantized graphs; in the quantized case, the conversion nodes and the q/dq nodes wrapping them are removed as well. I also refactored the channels_last_tagged_reshape_pass code by modularizing some functions and adding setter/getter helpers. This change improves runtime speed and memory by not executing redundant to_copy ops that would otherwise remain in the graph.

### Test plan

Created a TestChannelsLastTaggedReshapePass class which constructs graphs containing multiple redundant to_copy ops, in different positions and in both quantized and non-quantized graphs. These redundant ops are either written explicitly or generated by other passes. I asserted their removal after the passes finished.
1 parent 3950872 commit 4e29bc9
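For intuition, here is a minimal standalone sketch (an illustration, not code from this commit) of the kind of canceling dim order round trip the new pass deletes:

```python
import torch

# Two opposite memory-format conversions cancel each other. Once
# ChannelsLastTaggedReshapePass has tagged each node's dim order, the new
# RemoveRedundantCopyPass can drop both _to_copy ops from the graph.
x = torch.randn(1, 3, 8, 8)                      # NCHW, contiguous
y = x.to(memory_format=torch.channels_last)      # to_copy 1 (NCHW -> NHWC)
z = y.to(memory_format=torch.contiguous_format)  # to_copy 2 (NHWC -> NCHW)
assert torch.equal(x, z)  # the round trip is a value-level no-op
```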

File tree

4 files changed: +432 −15 lines changed


backends/xnnpack/_passes/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -23,6 +23,9 @@
 from executorch.backends.xnnpack._passes.fuse_activation_pass import FuseActivationPass
 from executorch.backends.xnnpack._passes.fuse_batch_norm import FuseBatchNormPass
 from executorch.backends.xnnpack._passes.prelu_reshape_pass import PReLUReshapePass
+from executorch.backends.xnnpack._passes.remove_redundant_copy_pass import (
+    RemoveRedundantCopyPass,
+)
 from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass

 from executorch.exir.pass_base import ExportPass
@@ -65,6 +68,7 @@ def __init__(
                 Conv1dUnsqueezePass,
                 PReLUReshapePass,
                 ChannelsLastTaggedReshapePass,
+                RemoveRedundantCopyPass,
             ]
         else:
             self.passes = passes

backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py

Lines changed: 88 additions & 15 deletions
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+from enum import Enum
 from typing import Optional, Tuple

 import torch
@@ -19,6 +20,11 @@
 from executorch.exir.pass_base import PassResult


+class InputDimOrder(Enum):
+    NCHW = 1
+    NHWC = 2
+
+
 # TODO(T151254305) use subgraph_rewriter
 class ChannelsLastTaggedReshapePass(XNNPACKPass):
     """
@@ -83,17 +89,49 @@ class ChannelsLastTaggedReshapePass(XNNPACKPass):
     # is done
     PARTNER_NODE = "XNN_CHANNELS_LAST_TAGGED_RESHAPE_PARTNER_NODE"

-    def mark_as_nhwc_node(self, node: torch.fx.Node) -> None:
+    @staticmethod
+    def mark_as_nhwc_node(node: torch.fx.Node) -> None:
         node.meta[ChannelsLastTaggedReshapePass.XNN_NHWC_NODE] = True

-    def mark_as_nchw_node(self, node: torch.fx.Node) -> None:
+    @staticmethod
+    def mark_as_nchw_node(node: torch.fx.Node) -> None:
         node.meta[ChannelsLastTaggedReshapePass.XNN_NHWC_NODE] = False

-    def is_nhwc_node(self, node: torch.fx.Node) -> bool:
+    def tag_node(self, node: torch.fx.Node) -> None:
+        if node.kwargs["memory_format"] == torch.channels_last:
+            self.mark_as_nhwc_node(node)
+        else:
+            self.mark_as_nchw_node(node)
+
+    @staticmethod
+    def is_nhwc_node(node: torch.fx.Node) -> bool:
+        if is_dequant(node) and len(node.all_input_nodes) > 0:
+            quantize_node = node.args[0]
+            if len(quantize_node.all_input_nodes) > 0:
+                actual_node = quantize_node.args[0]
+                if actual_node.op == "placeholder":
+                    return not actual_node.meta["val"][0].is_contiguous()
+                else:
+                    return actual_node.meta.get(
+                        ChannelsLastTaggedReshapePass.XNN_NHWC_NODE, False
+                    )
+
         return node.meta.get(ChannelsLastTaggedReshapePass.XNN_NHWC_NODE, False)

-    def is_nchw_node(self, node: torch.fx.Node) -> bool:
-        return not self.is_nhwc_node(node)
+    @staticmethod
+    def is_nchw_node(node: torch.fx.Node) -> bool:
+        if is_dequant(node) and len(node.all_input_nodes) > 0:
+            quantize_node = node.args[0]
+            if len(quantize_node.all_input_nodes) > 0:
+                actual_node = quantize_node.args[0]
+                if actual_node.op == "placeholder":
+                    return actual_node.meta["val"][0].is_contiguous()
+                else:
+                    return not actual_node.meta.get(
+                        ChannelsLastTaggedReshapePass.XNN_NHWC_NODE, False
+                    )
+
+        return not ChannelsLastTaggedReshapePass.is_nhwc_node(node)

     def requires_nhwc_input(self, node: torch.fx.Node) -> bool:
         return node.target in self.memory_sensitive_ops_nhwc
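The tagging convention these helpers implement is simple metadata on FX nodes. A small standalone sketch (the real key is the class attribute XNN_NHWC_NODE, whose string value is defined outside this hunk; the traced module here is arbitrary):

```python
import torch
import torch.fx

XNN_NHWC_NODE = "XNN_NHWC_NODE"  # stand-in for the class attribute's actual value

gm = torch.fx.symbolic_trace(torch.nn.Conv2d(3, 8, 3))
conv = next(n for n in gm.graph.nodes if n.op == "call_module")

conv.meta[XNN_NHWC_NODE] = True                 # mark_as_nhwc_node
assert conv.meta.get(XNN_NHWC_NODE, False)      # is_nhwc_node, simple case
conv.meta[XNN_NHWC_NODE] = False                # mark_as_nchw_node
assert not conv.meta.get(XNN_NHWC_NODE, False)  # is_nchw_node, simple case
```

The dequantize special case above exists because a dq node carries no useful tag of its own: the check looks through the wrapping q/dq pair and classifies the dequantize by the node that actually produced the tensor (or by contiguity, if that producer is a placeholder).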
@@ -111,7 +149,7 @@ def can_be_converted_to_nhwc(self, node: torch.fx.Node) -> bool:
         is_nchw_constant = (
             is_param_node(self.exported_program, node)
            and (ChannelsLastTaggedReshapePass.XNN_NHWC_NODE in node.meta)
-            and (self.is_nchw_node(node))
+            and (ChannelsLastTaggedReshapePass.is_nchw_node(node))
         )
         return is_4d and not is_nchw_constant

@@ -273,6 +311,22 @@ def insert_copy_and_assign_partner_nodes_quantization_sensitive(
         # in that case
         self.make_partners(original_input, copy_node)

+    def input_dim_order(
+        self, input_node: torch.fx.Node, input_order: InputDimOrder
+    ) -> bool:
+        if input_node.op == "placeholder":
+            return (
+                input_node.meta["val"].is_contiguous()
+                if input_order == InputDimOrder.NCHW
+                else not input_node.meta["val"].is_contiguous()
+            )
+        else:
+            return (
+                ChannelsLastTaggedReshapePass.is_nchw_node(input_node)
+                if input_order == InputDimOrder.NCHW
+                else ChannelsLastTaggedReshapePass.is_nhwc_node(input_node)
+            )
+
     def input_to_nhwc(
         self,
         graph_module: torch.fx.GraphModule,
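input_dim_order folds the two classification rules into one helper: placeholders are judged by tensor contiguity, every other node by its tag. The contiguity side can be sanity-checked directly (standalone snippet, not from the diff):

```python
import torch

# For placeholders the pass infers dim order from contiguity: a channels_last
# tensor is not contiguous in the default sense (so it reads as NHWC), while
# a freshly created tensor is contiguous (NCHW).
nchw = torch.randn(1, 3, 8, 8)
nhwc = nchw.to(memory_format=torch.channels_last)
assert nchw.is_contiguous() and not nhwc.is_contiguous()
assert nhwc.is_contiguous(memory_format=torch.channels_last)
```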
@@ -282,7 +336,7 @@
         if is_param_node(self.exported_program, input_node):
             if (
                 ChannelsLastTaggedReshapePass.XNN_NHWC_NODE in input_node.meta
-                and self.is_nchw_node(input_node)
+                and ChannelsLastTaggedReshapePass.is_nchw_node(input_node)
             ):
                 # This constant data tensor has been used somewhere else
                 # in NCHW format so we can't use it here in NHWC format
@@ -296,7 +350,10 @@
         if input_node.op == "placeholder":
             if not input_node.meta["val"][0].is_contiguous():
                 return
-        elif self.is_nhwc_node(input_node):
+        elif ChannelsLastTaggedReshapePass.is_nhwc_node(input_node):
+            return
+
+        if self.input_dim_order(input_node, InputDimOrder.NHWC):
             return

         if not self.can_be_converted_to_nhwc(input_node):
@@ -326,6 +383,8 @@ def input_to_nhwc(
             args=(input_node,),
             memory_format=torch.channels_last,
         )
+        # Use static method for consistency
+        ChannelsLastTaggedReshapePass.mark_as_nhwc_node(input_node_nhwc)

         if is_dynamic_input:
             # Replace downstream input_nodes with NHWC node
@@ -348,7 +407,7 @@ def input_to_nchw(
         if is_param_node(self.exported_program, input_node):
             if (
                 ChannelsLastTaggedReshapePass.XNN_NHWC_NODE in input_node.meta
-                and self.is_nhwc_node(input_node)
+                and ChannelsLastTaggedReshapePass.is_nhwc_node(input_node)
             ):
                 # This constant data tensor has been used somewhere else
                 # in NHWC format so we can't use it here in NCHW format
@@ -363,7 +422,10 @@ def input_to_nchw(
         if input_node.op == "placeholder":
             if input_node.meta["val"].is_contiguous():
                 return
-        elif self.is_nchw_node(input_node):
+        elif ChannelsLastTaggedReshapePass.is_nchw_node(input_node):
+            return
+
+        if self.input_dim_order(input_node, InputDimOrder.NCHW):
             return

         if ChannelsLastTaggedReshapePass.PARTNER_NODE in input_node.meta:
@@ -380,6 +442,7 @@ def input_to_nchw(
             args=(input_node,),
             memory_format=torch.contiguous_format,
         )
+        ChannelsLastTaggedReshapePass.mark_as_nchw_node(input_node_nchw)

         self.insert_copy_and_assign_partner_nodes_quantization_sensitive(
             graph_module=graph_module,
@@ -393,7 +456,12 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
         original_nodes = list(graph.nodes)
         for node in original_nodes:
             if len(node.all_input_nodes) == 0:
-                # This node has no inputs so we don't need to change anything
+                # This node has no inputs so we don't need to change anything, but we still need to tag it
+                if "val" in node.meta and isinstance(node.meta["val"], torch.Tensor):
+                    if node.meta["val"].is_contiguous():
+                        self.mark_as_nchw_node(node)
+                    else:
+                        self.mark_as_nhwc_node(node)
                 continue

             # Need special case for the output node because it can have multiple output dim orders, as we can output a tuple of multiple nodes
@@ -407,10 +475,12 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
             elif self.requires_nhwc_input(node):
                 # Nodes which enter this branch are ones that require their
                 # first input to be nhwc. This makes this node's output nhwc too
-
                 self.input_to_nhwc(graph_module, node.args[0], node)
-                for input_node in node.all_input_nodes:
-                    if input_node.op == "placeholder" and self.is_nhwc_node(input_node):
+                for input_node in node.all_input_nodes[1:]:
+                    if (
+                        input_node.op == "placeholder"
+                        and ChannelsLastTaggedReshapePass.is_nhwc_node(input_node)
+                    ):
                         raise AssertionError(
                             f"Expected {input_node} to be NCHW in channels last reshape pass"
                         )
@@ -419,11 +489,14 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
                 # The node requires nchw inputs
                 for input_node in node.all_input_nodes:
                     self.input_to_nchw(graph_module, input_node, node)
+            elif node.target == exir_ops.edge.aten._to_copy.default:
+                self.tag_node(node)
             else:
                 # The node can have inputs in any format (but all must be the
                 # same format)
                 is_or_isnt_nhwc_node = [
-                    self.is_nhwc_node(input_node) for input_node in node.all_input_nodes
+                    ChannelsLastTaggedReshapePass.is_nhwc_node(input_node)
+                    for input_node in node.all_input_nodes
                 ]
                 if all(is_or_isnt_nhwc_node):
                     # All inputs are nhwc so this node's output is nhwc too
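The new tag_node branch classifies a _to_copy node by its memory_format kwarg; in miniature (a sketch with a plain function standing in for the node):

```python
import torch

# tag_node's decision rule: channels_last means NHWC; anything else,
# e.g. torch.contiguous_format, means NCHW.
def dim_order_tag(memory_format) -> str:
    return "NHWC" if memory_format == torch.channels_last else "NCHW"

assert dim_order_tag(torch.channels_last) == "NHWC"
assert dim_order_tag(torch.contiguous_format) == "NCHW"
```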
backends/xnnpack/_passes/remove_redundant_copy_pass.py

Lines changed: 158 additions & 0 deletions

@@ -0,0 +1,158 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.backends.xnnpack._passes.channels_last_tagged_reshape_pass import (
+    ChannelsLastTaggedReshapePass,
+)
+from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
+from executorch.backends.xnnpack.utils.quant_utils import is_dequant, is_quant
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import PassResult
+
+
+class RemoveRedundantCopyPass(XNNPACKPass):
+    def _safe_remove_node(self, node, graph):
+        if len(node.users) == 0:
+            graph.erase_node(node)
+
+    def _try_remove_regular_redundant_to_copy(self, node, graph):
+        """
+        Try to remove redundant to_copy operations with the pattern
+        to_copy1 -> to_copy2, where the two copies have opposite memory formats
+        """
+        input_node = node.args[0]
+
+        # Check if input is a to_copy with the opposite memory format
+        if (
+            input_node.target == exir_ops.edge.aten._to_copy.default
+            and ChannelsLastTaggedReshapePass.is_nchw_node(input_node)
+            != ChannelsLastTaggedReshapePass.is_nchw_node(node)
+            and len(input_node.users) == 1
+        ):  # Ensure the first copy has no other users
+            # Get the original input (before the first to_copy)
+            original_input = input_node.args[0]
+
+            # Replace all users of the second to_copy with the original input
+            for user in node.users.copy():
+                user.replace_input_with(node, original_input)
+
+            # Remove both to_copy nodes
+            self._safe_remove_node(node, graph)
+            self._safe_remove_node(input_node, graph)
+
+            return True
+        elif (
+            ChannelsLastTaggedReshapePass.is_nhwc_node(input_node)
+            and ChannelsLastTaggedReshapePass.is_nhwc_node(node)
+        ) or (
+            ChannelsLastTaggedReshapePass.is_nchw_node(input_node)
+            and ChannelsLastTaggedReshapePass.is_nchw_node(node)
+        ):
+            # Both copies produce the same dim order, so the second is a
+            # no-op: route its users to the first copy and remove it
+            for user in node.users.copy():
+                user.replace_input_with(node, input_node)
+            self._safe_remove_node(node, graph)
+            return True
+
+        return False
+
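The rewrite primitives this method relies on (replace_input_with plus erase_node) can be exercised on a toy FX graph. A standalone demo, not pass code (symbolic_trace records .to as a call_method node rather than the edge-dialect _to_copy op):

```python
import torch
import torch.fx

def f(x):
    y = x.to(memory_format=torch.channels_last)      # first copy
    z = y.to(memory_format=torch.contiguous_format)  # second, canceling copy
    return z + 1

gm = torch.fx.symbolic_trace(f)
first, second = (n for n in gm.graph.nodes if n.op == "call_method" and n.target == "to")

# Same steps as the branch above: reroute users of the second copy to the
# original input, then erase both now-unused copies.
original_input = first.args[0]
for user in list(second.users):
    user.replace_input_with(second, original_input)
gm.graph.erase_node(second)
gm.graph.erase_node(first)
gm.recompile()

x = torch.randn(1, 3, 4, 4)
assert torch.equal(gm(x), x + 1)
```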
+    def _try_remove_quantized_redundant_to_copy(self, node, graph):
+        """
+        Try to remove redundant to_copy operations in quantized graphs with the
+        pattern dq1 -> to_copy1 -> q1 -> dq2 -> to_copy2 -> q2
+        """
+        # Check if this to_copy is followed by a quantize node
+        if len(node.users) != 1:
+            return False
+        q_node = next(iter(node.users))
+        if not is_quant(q_node):
+            return False
+
+        # Check if this to_copy is preceded by a dequantize node
+        dq_node = node.args[0]
+        if not is_dequant(dq_node):
+            return False
+
+        # Get the input to the dequantize node
+        if len(dq_node.all_input_nodes) != 1:
+            return False
+
+        prev_q_node = dq_node.args[0]
+
+        # Check if there's another dequantize -> to_copy -> quantize chain
+        if not is_quant(prev_q_node) or len(prev_q_node.all_input_nodes) != 1:
+            return False
+
+        # Check if there's a to_copy before the previous quantize
+        prev_to_copy = prev_q_node.args[0]
+        if (
+            prev_to_copy.target == exir_ops.edge.aten._to_copy.default
+            and ChannelsLastTaggedReshapePass.is_nchw_node(prev_to_copy)
+            != ChannelsLastTaggedReshapePass.is_nchw_node(node)
+            and len(prev_to_copy.users) == 1
+        ):  # Ensure the first copy has no other users
+            prev_dq_node = prev_to_copy.args[0]
+            if not is_dequant(prev_dq_node) or len(prev_dq_node.all_input_nodes) != 1:
+                return False
+
+            # Get the original input (before the first to_copy)
+            original_input = prev_dq_node.args[0]
+
+            # Replace all users of the second quantize with the original input
+            for user in q_node.users.copy():
+                user.replace_input_with(q_node, original_input)
+
+            # Remove nodes safely (only if they have no other users)
+            self._safe_remove_node(q_node, graph)
+            self._safe_remove_node(node, graph)
+            self._safe_remove_node(dq_node, graph)
+            self._safe_remove_node(prev_q_node, graph)
+            self._safe_remove_node(prev_to_copy, graph)
+            self._safe_remove_node(prev_dq_node, graph)
+        elif (
+            ChannelsLastTaggedReshapePass.is_nhwc_node(prev_to_copy)
+            and ChannelsLastTaggedReshapePass.is_nhwc_node(node)
+        ) or (
+            ChannelsLastTaggedReshapePass.is_nchw_node(prev_to_copy)
+            and ChannelsLastTaggedReshapePass.is_nchw_node(node)
+        ):
+            # Remove only this to_copy and the q/dq pair around it
+            # Get the original quantized tensor (input to dq_node)
+            original_q_tensor = dq_node.args[0]
+
+            # Replace all users of q_node with the original quantized tensor
+            for user in q_node.users.copy():
+                user.replace_input_with(q_node, original_q_tensor)
+
+            self._safe_remove_node(q_node, graph)
+            self._safe_remove_node(node, graph)
+            self._safe_remove_node(dq_node, graph)
+
+        return True
+
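Numerically, the chain this method targets is an end-to-end no-op, which is why all six nodes can be dropped. A standalone sketch (made-up quantization parameters; 0.125 is a binary-exact scale, so the round trip is exact):

```python
import torch

x = torch.randn(1, 3, 8, 8)
q0 = torch.quantize_per_tensor(x, scale=0.125, zero_point=0, dtype=torch.qint8)

dq1 = q0.dequantize()                                      # dq1
c1 = dq1.to(memory_format=torch.channels_last)             # to_copy1 (NCHW -> NHWC)
q1 = torch.quantize_per_tensor(c1, 0.125, 0, torch.qint8)  # q1
dq2 = q1.dequantize()                                      # dq2
c2 = dq2.to(memory_format=torch.contiguous_format)         # to_copy2 (NHWC -> NCHW)
q2 = torch.quantize_per_tensor(c2, 0.125, 0, torch.qint8)  # q2

# With matching quantization parameters, the chain reproduces q0 exactly.
assert torch.equal(q0.int_repr(), q2.int_repr())
```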
+    def call(self, graph_module: torch.fx.GraphModule):
+        graph = graph_module.graph
+        original_nodes = list(graph.nodes)
+
+        for node in original_nodes:
+            if len(node.all_input_nodes) == 0:
+                continue
+
+            # Only process to_copy nodes
+            if node.target != exir_ops.edge.aten._to_copy.default:
+                continue
+
+            if is_dequant(node.args[0]):
+                self._try_remove_quantized_redundant_to_copy(node, graph)
+            else:
+                self._try_remove_regular_redundant_to_copy(node, graph)
+
+        graph_module.recompile()
+
+        # Since we are overriding "call", we need to call the parent's "call"
+        # to retrace the graph and regenerate metadata
+        graph_module = super().call(graph_module).graph_module
+
+        return PassResult(graph_module, True)
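In the spirit of the test plan, graphs with redundant conversions are easy to manufacture and inspect before lowering. A sketch assuming a recent PyTorch (the actual TestChannelsLastTaggedReshapePass file is not shown in this diff, and the op name can vary by export pipeline):

```python
import torch

class Roundtrip(torch.nn.Module):
    def forward(self, x):
        y = x.to(memory_format=torch.channels_last)
        return y.to(memory_format=torch.contiguous_format) + 1

ep = torch.export.export(Roundtrip(), (torch.randn(1, 3, 4, 4),))
ep = ep.run_decompositions()
copies = [
    n
    for n in ep.graph.nodes
    if n.op == "call_function" and "_to_copy" in str(n.target)
]
# Expect the canceling pair here; a test would assert it is gone after
# to_edge and the XNNPACK passes (including RemoveRedundantCopyPass) run.
print(len(copies))
```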
