Commit 85729d2

GregoryComer authored and facebook-github-bot committed
Add view_copy/static_reshape support to XNNPACK delegate (#7959)
Summary: Add support for delegating view_copy in the XNNPACK delegate via the XNN static_reshape operator. This includes support for up to one dynamic dimension. It also includes conditional support for NHWC, so long as batch and channel are fixed (if batch or channel is modified, the reshape has to be done in the native dim order).

Test Plan: I've added test coverage for view_copy, including dynamic shape support and partitioner constraints, to backends/xnnpack/test/ops/test_view_copy.py. I've also added tests to cover the new view logic in channels_last_tagged_reshape_pass.

Differential Revision: D68691788

Pulled By: GregoryComer
1 parent ad5a00a commit 85729d2
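For context, a minimal sketch of how a model containing such a view might now be lowered through this path, assuming the standard ExecuTorch export flow (the `Flatten` module and input shapes below are illustrative, not from this commit):

```python
import torch

from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower


# Hypothetical module: the view keeps batch fixed and infers one dimension
# (-1), which is within the single-dynamic-dimension limit of static_reshape.
class Flatten(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.view(x.shape[0], -1)


ep = torch.export.export(Flatten(), (torch.randn(2, 8, 4, 4),))
# With this commit, the view_copy node can be claimed by the XNNPACK
# partitioner instead of being skipped.
lowered = to_edge_transform_and_lower(ep, partitioner=[XnnpackPartitioner()])
```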

File tree

9 files changed: +585 −11 lines changed

backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py

Lines changed: 24 additions & 1 deletion
@@ -15,7 +15,7 @@
     is_tagged_as_implicit_q_dq,
     tag_as_implicit_q_dq,
 )
-from executorch.backends.xnnpack.utils.utils import is_param_node
+from executorch.backends.xnnpack.utils.utils import get_input_node, is_param_node
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import PassResult

@@ -89,6 +89,22 @@ class ChannelsLastTaggedReshapePass(XNNPACKPass):
     # is done
     PARTNER_NODE = "XNN_CHANNELS_LAST_TAGGED_RESHAPE_PARTNER_NODE"

+    @staticmethod
+    def is_view_dim_order_invariant(node: torch.fx.Node) -> bool:
+        # View must be done in NCHW dim order if channel or batch is changed,
+        # or if rank is not 4.
+        in_shape = get_input_node(node, 0).meta["val"].shape
+        out_shape = node.meta["val"].shape
+
+        if len(in_shape) != 4 or len(out_shape) != 4:
+            return False
+
+        # Are batch and channel modified? If so, return False.
+        if in_shape[0] != out_shape[0] or in_shape[1] != out_shape[1]:
+            return False
+
+        return True
+
     @staticmethod
     def mark_as_nhwc_node(node: torch.fx.Node) -> None:
         node.meta[ChannelsLastTaggedReshapePass.XNN_NHWC_NODE] = True

@@ -161,6 +177,13 @@ def requires_nhwc_input(self, node: torch.fx.Node) -> bool:
         return node.target in self.memory_sensitive_ops_nhwc

     def requires_nchw_inputs(self, node: torch.fx.Node) -> bool:
+        # Views depend on whether batch or channel are modified.
+        if (
+            node.target == exir_ops.edge.aten.view_copy.default
+            and not self.is_view_dim_order_invariant(node)
+        ):
+            return True
+
         return node.target in self.memory_sensitive_ops_nchw

     def can_be_converted_to_nhwc(self, node: torch.fx.Node) -> bool:
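To make the new predicate concrete, here is a standalone sketch of the same rule applied to a few shapes (the helper name is mine, not part of the commit):

```python
# Mirrors ChannelsLastTaggedReshapePass.is_view_dim_order_invariant above,
# but operates on raw shapes instead of torch.fx nodes.
def view_is_dim_order_invariant(in_shape, out_shape) -> bool:
    # Rank must stay 4, and batch (dim 0) / channel (dim 1) must be
    # unchanged; otherwise the view has to run in NCHW dim order.
    if len(in_shape) != 4 or len(out_shape) != 4:
        return False
    return in_shape[0] == out_shape[0] and in_shape[1] == out_shape[1]


assert view_is_dim_order_invariant((1, 8, 4, 4), (1, 8, 2, 8))       # only H/W change
assert not view_is_dim_order_invariant((1, 8, 4, 4), (1, 16, 2, 4))  # channel changes
assert not view_is_dim_order_invariant((1, 8, 4, 4), (1, 8, 16))     # rank changes
```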

backends/xnnpack/operators/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -53,4 +53,5 @@
     op_sub,
     op_tanh,
     op_to_copy,
+    op_view_copy,
 )

backends/xnnpack/operators/op_skip_ops.py

Lines changed: 0 additions & 10 deletions
@@ -59,16 +59,6 @@ class OpTCopyDefault(OpSkipOps):
     target = "aten.t_copy.default"


-@register_node_visitor
-class OpViewCopyDefault(OpSkipOps):
-    """
-    currently, do nothing if node is view_copy.default
-    need to handle this later on, currently view it as one of skip ops
-    """
-
-    target = "aten.view_copy.default"
-
-
 @register_node_visitor
 class OpSymSizeInt(OpSkipOps):
     """
backends/xnnpack/operators/op_view_copy.py

Lines changed: 96 additions & 0 deletions

@@ -0,0 +1,96 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+from typing import Dict
+
+import torch
+from executorch.backends.xnnpack.operators.node_visitor import (
+    NodeVisitor,
+    register_node_visitor,
+)
+from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
+    XNNGraph,
+    XNNStaticReshape,
+    XNode,
+)
+from executorch.backends.xnnpack.utils.utils import (
+    check_or_raise,
+    get_input_node,
+    PERM_NCHW_TO_NHWC,
+)
+
+
+@register_node_visitor
+class ViewCopyVisitor(NodeVisitor):
+    target = "aten.view_copy.default"
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        xnn_graph: XNNGraph,
+        vals_to_ids: Dict[torch.fx.Node, int],
+        debug_handle: int,
+    ) -> None:
+        self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)
+
+        input_node = get_input_node(node, 0)
+
+        # input
+        input_id = vals_to_ids[input_node]
+
+        # output
+        output_id = vals_to_ids[node]
+
+        # input shape
+        check_or_raise(
+            "val" in input_node.meta,
+            "Missing val in tensor metadata for input when serializing XNNStaticReshape",
+        )
+
+        # output shape
+        check_or_raise(
+            "val" in node.meta,
+            "Missing val in tensor metadata for output when serializing XNNStaticReshape",
+        )
+
+        new_shape = node.args[1]
+        check_or_raise(
+            all(isinstance(d, int) for d in new_shape),
+            "Symbolic reshape parameter is not supported in XNNStaticReshape",
+        )
+
+        # PyTorch uses -1 for inferred dims, whereas XNNPACK expects 0.
+        new_shape = tuple(d if d != -1 else 0 for d in new_shape)
+
+        # Handle dim order - if this op runs on NHWC data, the NCHW view
+        # shape needs to be permuted correspondingly.
+        if "XNN_NHWC_NODE" in node.meta:
+            check_or_raise(len(new_shape) == 4, "Invalid NCHW shape")
+            new_shape = [new_shape[PERM_NCHW_TO_NHWC[n]] for n in range(4)]
+
+        num_dynamic_dims = sum(1 for d in new_shape if d == 0)
+
+        check_or_raise(
+            num_dynamic_dims <= 1,
+            "XNNPACK reshape only supports 1 dynamic dimension.",
+        )
+
+        ser_node = XNode(
+            xnode_union=XNNStaticReshape(
+                num_dims=len(new_shape),
+                new_shape=new_shape,
+                input_id=input_id,
+                output_id=output_id,
+                flags=0,
+            ),
+            debug_handle=debug_handle,
+        )
+        xnn_graph.xnodes.append(ser_node)
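The two shape transformations in define_node are easy to check by hand. A small sketch, assuming PERM_NCHW_TO_NHWC is the permutation [0, 2, 3, 1] (the shape values below are illustrative):

```python
PERM_NCHW_TO_NHWC = [0, 2, 3, 1]  # assumed value of the imported constant

# A PyTorch view shape with one inferred dimension, written in NCHW order.
new_shape = (2, -1, 4, 4)

# PyTorch marks the inferred dim with -1; XNNPACK expects 0.
new_shape = tuple(d if d != -1 else 0 for d in new_shape)
assert new_shape == (2, 0, 4, 4)

# If the node is tagged NHWC, permute the NCHW shape into NHWC order.
nhwc_shape = [new_shape[PERM_NCHW_TO_NHWC[n]] for n in range(4)]
assert nhwc_shape == [2, 4, 4, 0]
```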

backends/xnnpack/partition/config/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -53,6 +53,7 @@
     TanhConfig,
     ToDimOrderCopyConfig,
     UpsampleBilinear2dConfig,
+    ViewCopyConfig,
 )
 from executorch.backends.xnnpack.partition.config.node_configs import (
     BatchNormConfig,

@@ -112,6 +113,7 @@
     SquareRootConfig,
     SubConfig,
     UpsampleBilinear2dConfig,
+    ViewCopyConfig,
     # Quant/Dequant Op Configs
     QuantizedPerTensorConfig,
     DeQuantizedPerTensorConfig,

backends/xnnpack/partition/config/generic_node_configs.py

Lines changed: 25 additions & 0 deletions
@@ -379,6 +379,31 @@ class ExpConfig(GenericNodePartitionerConfig):
     def supported_precision_types(self) -> List[ConfigPrecisionType]:
         return [ConfigPrecisionType.FP32]

+
+class ViewCopyConfig(GenericNodePartitionerConfig):
+    target_name = "view_copy.default"
+
+    def supported_precision_types(self) -> List[ConfigPrecisionType]:
+        return [ConfigPrecisionType.FP32]
+
+    def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
+        """
+        XNNPACK's static_reshape only supports 1 dynamic dimension.
+        """
+        if not self.check_common_constraints(node, ep):
+            return False
+
+        new_shape = node.args[1]
+        if not all(isinstance(n, int) for n in new_shape):
+            why(node, reason="symbolic reshape is not supported")
+            return False
+
+        dynamic_dim_count = sum(1 for d in new_shape if d == -1)
+        if dynamic_dim_count > 1:
+            why(node, reason="only a single dynamic dimension is supported")
+            return False
+
+        return True

 class FloorConfig(GenericNodePartitionerConfig):
     target_name = "floor.default"
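The same constraints, exercised on a couple of candidate view arguments — a minimal mirror of check_constraints on raw shapes (the helper name is mine, not part of the commit):

```python
def reshape_args_supported(new_shape) -> bool:
    # Symbolic (non-int) dims can't be serialized into XNNStaticReshape.
    if not all(isinstance(n, int) for n in new_shape):
        return False
    # XNNPACK's static_reshape allows at most one inferred (-1) dimension.
    return sum(1 for d in new_shape if d == -1) <= 1


assert reshape_args_supported([2, -1, 16])       # one inferred dim: delegated
assert not reshape_args_supported([-1, -1, 16])  # two inferred dims: rejected
```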
