Commit f31c5d6

Update base for Update on "[ET-VK] Adding push constant and UBO versions of select and slice ops to improve memory and performance."

Adding push constant and UBO versions of the select and slice ops to improve memory usage and performance.

* Updated `transfer_buffer.yaml` and `transfer_texture.yaml` to include a `UBO_PARAMS` parameter and generate variants of the `select` and `slice` ops that take their parameters from a UBO.
* Updated `transfer.glsl` to generate both UBO and push-constant versions of the `select` and `slice` ops.

Differential Revision: [D78095262](https://our.internmc.facebook.com/intern/diff/D78095262/)

[ghstack-poisoned]

2 parents 89906b6 + 1540659

15 files changed: +1084 −48

backends/xnnpack/_passes/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -23,6 +23,9 @@
 from executorch.backends.xnnpack._passes.fuse_activation_pass import FuseActivationPass
 from executorch.backends.xnnpack._passes.fuse_batch_norm import FuseBatchNormPass
 from executorch.backends.xnnpack._passes.prelu_reshape_pass import PReLUReshapePass
+from executorch.backends.xnnpack._passes.remove_redundant_copy_pass import (
+    RemoveRedundantCopyPass,
+)
 from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass

 from executorch.exir.pass_base import ExportPass
@@ -65,6 +68,7 @@ def __init__(
                 Conv1dUnsqueezePass,
                 PReLUReshapePass,
                 ChannelsLastTaggedReshapePass,
+                RemoveRedundantCopyPass,
             ]
         else:
             self.passes = passes
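
For reference, here is a minimal sketch of how the extended pass list might be driven end to end. The `XNNPACKPassManager` name matches this file, but its constructor signature and the `transform()` method are assumptions for illustration, not confirmed API:

```python
from executorch.backends.xnnpack._passes import XNNPACKPassManager
from executorch.backends.xnnpack._passes.channels_last_tagged_reshape_pass import (
    ChannelsLastTaggedReshapePass,
)
from executorch.backends.xnnpack._passes.remove_redundant_copy_pass import (
    RemoveRedundantCopyPass,
)

# Hypothetical driver code. Order matters: RemoveRedundantCopyPass cleans up
# back-to-back _to_copy nodes that ChannelsLastTaggedReshapePass may insert,
# so it must run after the tagging pass.
manager = XNNPACKPassManager(
    exported_program,  # an ExportedProgram from torch.export / to_edge
    passes=[ChannelsLastTaggedReshapePass, RemoveRedundantCopyPass],
)
exported_program = manager.transform()
```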

backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py

Lines changed: 88 additions & 15 deletions
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+from enum import Enum
 from typing import Optional, Tuple

 import torch
@@ -19,6 +20,11 @@
 from executorch.exir.pass_base import PassResult


+class InputDimOrder(Enum):
+    NCHW = 1
+    NHWC = 2
+
+
 # TODO(T151254305) use subgraph_rewriter
 class ChannelsLastTaggedReshapePass(XNNPACKPass):
     """
@@ -83,17 +89,49 @@ class ChannelsLastTaggedReshapePass(XNNPACKPass):
     # is done
     PARTNER_NODE = "XNN_CHANNELS_LAST_TAGGED_RESHAPE_PARTNER_NODE"

-    def mark_as_nhwc_node(self, node: torch.fx.Node) -> None:
+    @staticmethod
+    def mark_as_nhwc_node(node: torch.fx.Node) -> None:
         node.meta[ChannelsLastTaggedReshapePass.XNN_NHWC_NODE] = True

-    def mark_as_nchw_node(self, node: torch.fx.Node) -> None:
+    @staticmethod
+    def mark_as_nchw_node(node: torch.fx.Node) -> None:
         node.meta[ChannelsLastTaggedReshapePass.XNN_NHWC_NODE] = False

-    def is_nhwc_node(self, node: torch.fx.Node) -> bool:
+    def tag_node(self, node: torch.fx.Node) -> None:
+        if node.kwargs["memory_format"] == torch.channels_last:
+            self.mark_as_nhwc_node(node)
+        else:
+            self.mark_as_nchw_node(node)
+
+    @staticmethod
+    def is_nhwc_node(node: torch.fx.Node) -> bool:
+        if is_dequant(node) and len(node.all_input_nodes) > 0:
+            quantize_node = node.args[0]
+            if len(quantize_node.all_input_nodes) > 0:
+                actual_node = quantize_node.args[0]
+                if actual_node.op == "placeholder":
+                    return not actual_node.meta["val"][0].is_contiguous()
+                else:
+                    return actual_node.meta.get(
+                        ChannelsLastTaggedReshapePass.XNN_NHWC_NODE, False
+                    )
+
         return node.meta.get(ChannelsLastTaggedReshapePass.XNN_NHWC_NODE, False)

-    def is_nchw_node(self, node: torch.fx.Node) -> bool:
-        return not self.is_nhwc_node(node)
+    @staticmethod
+    def is_nchw_node(node: torch.fx.Node) -> bool:
+        if is_dequant(node) and len(node.all_input_nodes) > 0:
+            quantize_node = node.args[0]
+            if len(quantize_node.all_input_nodes) > 0:
+                actual_node = quantize_node.args[0]
+                if actual_node.op == "placeholder":
+                    return actual_node.meta["val"][0].is_contiguous()
+                else:
+                    return not actual_node.meta.get(
+                        ChannelsLastTaggedReshapePass.XNN_NHWC_NODE, False
+                    )
+
+        return not ChannelsLastTaggedReshapePass.is_nhwc_node(node)

     def requires_nhwc_input(self, node: torch.fx.Node) -> bool:
         return node.target in self.memory_sensitive_ops_nhwc
@@ -111,7 +149,7 @@ def can_be_converted_to_nhwc(self, node: torch.fx.Node) -> bool:
         is_nchw_constant = (
             is_param_node(self.exported_program, node)
             and (ChannelsLastTaggedReshapePass.XNN_NHWC_NODE in node.meta)
-            and (self.is_nchw_node(node))
+            and (ChannelsLastTaggedReshapePass.is_nchw_node(node))
         )
         return is_4d and not is_nchw_constant

@@ -273,6 +311,22 @@ def insert_copy_and_assign_partner_nodes_quantization_sensitive(
         # in that case
         self.make_partners(original_input, copy_node)

+    def input_dim_order(
+        self, input_node: torch.fx.Node, input_order: InputDimOrder
+    ) -> bool:
+        if input_node.op == "placeholder":
+            return (
+                input_node.meta["val"].is_contiguous()
+                if input_order == InputDimOrder.NCHW
+                else not input_node.meta["val"].is_contiguous()
+            )
+        else:
+            return (
+                ChannelsLastTaggedReshapePass.is_nchw_node(input_node)
+                if input_order == InputDimOrder.NCHW
+                else ChannelsLastTaggedReshapePass.is_nhwc_node(input_node)
+            )
+
     def input_to_nhwc(
         self,
         graph_module: torch.fx.GraphModule,
@@ -282,7 +336,7 @@
         if is_param_node(self.exported_program, input_node):
             if (
                 ChannelsLastTaggedReshapePass.XNN_NHWC_NODE in input_node.meta
-                and self.is_nchw_node(input_node)
+                and ChannelsLastTaggedReshapePass.is_nchw_node(input_node)
             ):
                 # This constant data tensor has been used somewhere else
                 # in NCHW format so we can't use it here in NHWC format
@@ -296,7 +350,10 @@
         if input_node.op == "placeholder":
             if not input_node.meta["val"][0].is_contiguous():
                 return
-        elif self.is_nhwc_node(input_node):
+        elif ChannelsLastTaggedReshapePass.is_nhwc_node(input_node):
+            return
+
+        if self.input_dim_order(input_node, InputDimOrder.NHWC):
             return

         if not self.can_be_converted_to_nhwc(input_node):
@@ -326,6 +383,8 @@
             args=(input_node,),
             memory_format=torch.channels_last,
         )
+        # Use static method for consistency
+        ChannelsLastTaggedReshapePass.mark_as_nhwc_node(input_node_nhwc)

         if is_dynamic_input:
             # Replace downstream input_nodes with NHWC node
@@ -348,7 +407,7 @@ def input_to_nchw(
         if is_param_node(self.exported_program, input_node):
             if (
                 ChannelsLastTaggedReshapePass.XNN_NHWC_NODE in input_node.meta
-                and self.is_nhwc_node(input_node)
+                and ChannelsLastTaggedReshapePass.is_nhwc_node(input_node)
             ):
                 # This constant data tensor has been used somewhere else
                 # in NHWC format so we can't use it here in NCHW format
@@ -363,7 +422,10 @@
         if input_node.op == "placeholder":
             if input_node.meta["val"].is_contiguous():
                 return
-        elif self.is_nchw_node(input_node):
+        elif ChannelsLastTaggedReshapePass.is_nchw_node(input_node):
+            return
+
+        if self.input_dim_order(input_node, InputDimOrder.NCHW):
             return

         if ChannelsLastTaggedReshapePass.PARTNER_NODE in input_node.meta:
@@ -380,6 +442,7 @@
             args=(input_node,),
             memory_format=torch.contiguous_format,
         )
+        ChannelsLastTaggedReshapePass.mark_as_nchw_node(input_node_nchw)

         self.insert_copy_and_assign_partner_nodes_quantization_sensitive(
             graph_module=graph_module,
@@ -393,7 +456,12 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
         original_nodes = list(graph.nodes)
         for node in original_nodes:
             if len(node.all_input_nodes) == 0:
-                # This node has no inputs so we don't need to change anything
+                # This node has no inputs so we don't need to change anything, but still need to tag input nodes
+                if "val" in node.meta and isinstance(node.meta["val"], torch.Tensor):
+                    if node.meta["val"].is_contiguous():
+                        self.mark_as_nchw_node(node)
+                    else:
+                        self.mark_as_nhwc_node(node)
                 continue

             # Need special case for output node because it can have multiple output dim orders as we can output a tuple multiple nodes
@@ -407,10 +475,12 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
             elif self.requires_nhwc_input(node):
                 # Nodes which enter this branch are ones that require their
                 # first input to be nhwc. This makes this node's output nhwc too
-
                 self.input_to_nhwc(graph_module, node.args[0], node)
-                for input_node in node.all_input_nodes:
-                    if input_node.op == "placeholder" and self.is_nhwc_node(input_node):
+                for input_node in node.all_input_nodes[1:]:
+                    if (
+                        input_node.op == "placeholder"
+                        and ChannelsLastTaggedReshapePass.is_nhwc_node(input_node)
+                    ):
                         raise AssertionError(
                             f"Expected {input_node} to be NCHW in channels last reshape pass"
                         )
@@ -419,11 +489,14 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
                 # The node requires nchw inputs
                 for input_node in node.all_input_nodes:
                     self.input_to_nchw(graph_module, input_node, node)
+            elif node.target == exir_ops.edge.aten._to_copy.default:
+                self.tag_node(node)
             else:
                 # The node can have inputs in any format (but all must be the
                 # same format)
                 is_or_isnt_nhwc_node = [
-                    self.is_nhwc_node(input_node) for input_node in node.all_input_nodes
+                    ChannelsLastTaggedReshapePass.is_nhwc_node(input_node)
+                    for input_node in node.all_input_nodes
                 ]
                 if all(is_or_isnt_nhwc_node):
                     # All inputs are nhwc so this node's output is nhwc too
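
The NHWC/NCHW tagging above ultimately reduces to a contiguity check on the tensor stored in `node.meta["val"]`: placeholders are classified by memory format directly, while other nodes use the `XNN_NHWC_NODE` meta flag. A small, self-contained illustration of the contiguity side of that check, using only stock PyTorch:

```python
import torch

x = torch.randn(1, 3, 8, 8)  # default NCHW (contiguous) layout
assert x.is_contiguous()  # contiguous -> treated as NCHW

y = x.to(memory_format=torch.channels_last)
assert not y.is_contiguous()  # not contiguous -> treated as NHWC
assert y.is_contiguous(memory_format=torch.channels_last)

# This mirrors input_dim_order() for placeholder nodes: a placeholder is
# NCHW iff its underlying tensor is contiguous, and NHWC otherwise.
```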

backends/xnnpack/_passes/remove_redundant_copy_pass.py

Lines changed: 158 additions & 0 deletions
@@ -0,0 +1,158 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.backends.xnnpack._passes.channels_last_tagged_reshape_pass import (
+    ChannelsLastTaggedReshapePass,
+)
+from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
+from executorch.backends.xnnpack.utils.quant_utils import is_dequant, is_quant
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import PassResult
+
+
+class RemoveRedundantCopyPass(XNNPACKPass):
+    def _safe_remove_node(self, node, graph):
+        if len(node.users) == 0:
+            graph.erase_node(node)
+
+    def _try_remove_regular_redundant_to_copy(self, node, graph):
+        """
+        Try to remove redundant regular to_copy operations with pattern to_copy1 -> to_copy2 with opposite memory formats
+        """
+        input_node = node.args[0]
+
+        # Check if input is a to_copy with opposite memory format
+        if (
+            input_node.target == exir_ops.edge.aten._to_copy.default
+            and ChannelsLastTaggedReshapePass.is_nchw_node(input_node)
+            != ChannelsLastTaggedReshapePass.is_nchw_node(node)
+            and len(input_node.users) == 1
+        ):  # Ensure the first copy has no other users
+
+            # Get the original input (before the first to_copy)
+            original_input = input_node.args[0]
+
+            # Replace all users of the second to_copy with the original input
+            for user in node.users.copy():
+                user.replace_input_with(node, original_input)
+
+            # Remove both to_copy nodes
+            self._safe_remove_node(node, graph)
+            self._safe_remove_node(input_node, graph)
+
+            return True
+        elif (
+            ChannelsLastTaggedReshapePass.is_nhwc_node(input_node)
+            and ChannelsLastTaggedReshapePass.is_nhwc_node(node)
+        ) or (
+            ChannelsLastTaggedReshapePass.is_nchw_node(input_node)
+            and ChannelsLastTaggedReshapePass.is_nchw_node(node)
+        ):
+            # Replace all users of the second to_copy with the original input
+            for user in node.users.copy():
+                user.replace_input_with(node, input_node)
+            self._safe_remove_node(node, graph)
+            return True
+
+        return False
+
+    def _try_remove_quantized_redundant_to_copy(self, node, graph):
+        """
+        Try to remove redundant to_copy operations in quantized graphs with pattern dq1 -> to_copy1 -> q1 -> dq2 -> to_copy2 -> q2
+        """
+        # Check if this to_copy is followed by a quantize node
+        if len(node.users) != 1:
+            return False
+        q_node = next(iter(node.users))
+        if not is_quant(q_node):
+            return False
+
+        # Check if this to_copy is preceded by a dequantize node
+        dq_node = node.args[0]
+        if not is_dequant(dq_node):
+            return False
+
+        # Get the input to the dequantize node
+        if len(dq_node.all_input_nodes) != 1:
+            return False
+
+        prev_q_node = dq_node.args[0]
+
+        # Check if there's another dequantize -> to_copy -> quantize chain
+        if not is_quant(prev_q_node) or len(prev_q_node.all_input_nodes) != 1:
+            return False
+
+        # Check if there's a to_copy before the previous quantize
+        prev_to_copy = prev_q_node.args[0]
+        if (
+            prev_to_copy.target == exir_ops.edge.aten._to_copy.default
+            and ChannelsLastTaggedReshapePass.is_nchw_node(prev_to_copy)
+            != ChannelsLastTaggedReshapePass.is_nchw_node(node)
+            and len(prev_to_copy.users) == 1
+        ):  # Ensure the first copy has no other users
+            prev_dq_node = prev_to_copy.args[0]
+            if not is_dequant(prev_dq_node) or len(prev_dq_node.all_input_nodes) != 1:
+                return False
+
+            # Get the original input (before the first to_copy)
+            original_input = prev_dq_node.args[0]
+
+            # Replace all users of the second to_copy with the original input
+            for user in q_node.users.copy():
+                user.replace_input_with(q_node, original_input)
+
+            # Remove nodes safely (only if they have no other users)
+            self._safe_remove_node(q_node, graph)
+            self._safe_remove_node(node, graph)
+            self._safe_remove_node(dq_node, graph)
+            self._safe_remove_node(prev_q_node, graph)
+            self._safe_remove_node(prev_to_copy, graph)
+            self._safe_remove_node(prev_dq_node, graph)
+        elif (
+            ChannelsLastTaggedReshapePass.is_nhwc_node(prev_to_copy)
+            and ChannelsLastTaggedReshapePass.is_nhwc_node(node)
+        ) or (
+            ChannelsLastTaggedReshapePass.is_nchw_node(prev_to_copy)
+            and ChannelsLastTaggedReshapePass.is_nchw_node(node)
+        ):
+            # Remove node and the q/dq around it only
+            # Get the original quantized tensor (input to dq_node)
+            original_q_tensor = dq_node.args[0]
+
+            # Replace all users of q_node with the original quantized tensor
+            for user in q_node.users.copy():
+                user.replace_input_with(q_node, original_q_tensor)
+
+            self._safe_remove_node(q_node, graph)
+            self._safe_remove_node(node, graph)
+            self._safe_remove_node(dq_node, graph)
+        return True
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        graph = graph_module.graph
+        original_nodes = list(graph.nodes)
+
+        for node in original_nodes:
+            if len(node.all_input_nodes) == 0:
+                continue
+
+            # Only process to_copy nodes
+            if node.target != exir_ops.edge.aten._to_copy.default:
+                continue
+
+            if is_dequant(node.args[0]):
+                self._try_remove_quantized_redundant_to_copy(node, graph)
+            else:
+                self._try_remove_regular_redundant_to_copy(node, graph)
+
+        graph_module.recompile()
+
+        # Since we are overriding "call", we need to call the parent's "call"
+        # to retrace the graph and regenerate metadata
+        graph_module = super().call(graph_module).graph_module
+
+        return PassResult(graph_module, True)
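
The tensor-level identity this pass exploits: converting to channels-last and straight back is a value-preserving round trip, so an opposite-format `_to_copy` pair (and, in the quantized case, the q/dq nodes between them) can be dropped and its users rewired to the original input. A minimal sketch of that invariant, assuming nothing beyond stock PyTorch:

```python
import torch

x = torch.randn(1, 3, 4, 4)

# to_copy1 (NCHW -> NHWC) followed by to_copy2 (NHWC -> NCHW):
y = x.to(memory_format=torch.channels_last).to(
    memory_format=torch.contiguous_format
)

# Values and layout both round-trip exactly, so downstream users of the
# second copy can be rewired to the original input and both copies erased.
assert torch.equal(x, y) and y.is_contiguous()
```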
