
Commit 5922f58

GregoryComer authored and facebook-github-bot committed
Add view_copy/static_reshape support to XNNPACK delegate (#7959)
Summary: Add support for delegating view_copy in the XNNPACK delegate via the XNN static_reshape operator. This includes support for up to one dynamic dimension. It also includes conditional support for NHWC, so long as batch and channel are fixed (if batch or channel is modified, the reshape has to be done in the native dim order).

Test Plan: I've added test coverage for view_copy, including dynamic shape support and partitioner constraints, to backends/xnnpack/test/ops/test_view_copy.py. I've also added tests to cover the new view logic in channels_last_tagged_reshape_pass.

Differential Revision: D68691788

Pulled By: GregoryComer
1 parent 8af8252 commit 5922f58
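
To exercise the new path end to end, here is a minimal sketch of lowering a model containing view_copy through the XNNPACK delegate. It follows the standard ExecuTorch export flow; helper names and import paths may vary by release.

import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower

class ViewModel(torch.nn.Module):
    def forward(self, x):
        # One inferred dim: now delegated as XNN static_reshape, not skipped.
        return x.view(x.shape[0], -1)

exported = torch.export.export(ViewModel(), (torch.randn(2, 3, 4),))
lowered = to_edge_transform_and_lower(exported, partitioner=[XnnpackPartitioner()])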

File tree

8 files changed: +574, -11 lines


backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py

Lines changed: 6 additions & 1 deletion

@@ -15,7 +15,7 @@
     is_tagged_as_implicit_q_dq,
     tag_as_implicit_q_dq,
 )
-from executorch.backends.xnnpack.utils.utils import is_param_node
+from executorch.backends.xnnpack.utils.utils import get_input_node, is_param_node
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import PassResult
@@ -161,6 +161,11 @@ def requires_nhwc_input(self, node: torch.fx.Node) -> bool:
         return node.target in self.memory_sensitive_ops_nhwc
 
     def requires_nchw_inputs(self, node: torch.fx.Node) -> bool:
+        if (
+            node.target == exir_ops.edge.aten.view_copy.default
+        ):
+            return True
+
         return node.target in self.memory_sensitive_ops_nchw
 
     def can_be_converted_to_nhwc(self, node: torch.fx.Node) -> bool:
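
Why the pass forces NCHW inputs for view_copy: a reshape is dim-order-sensitive, so applying the same view shape to NHWC-ordered data produces a different element ordering. A standalone illustration (my addition, not part of the diff):

import torch

x = torch.arange(24).reshape(1, 2, 3, 4)       # logical NCHW
nchw_view = x.reshape(1, 2, 12)                # reshape in native dim order

nhwc = x.permute(0, 2, 3, 1).contiguous()      # same data, NHWC memory order
nhwc_view = nhwc.reshape(1, 2, 12)             # the "same" reshape over NHWC

print(torch.equal(nchw_view, nhwc_view))       # False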

backends/xnnpack/operators/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -56,4 +56,5 @@
     op_sub,
     op_tanh,
     op_to_copy,
+    op_view_copy,
 )

backends/xnnpack/operators/op_skip_ops.py

Lines changed: 0 additions & 10 deletions

@@ -59,16 +59,6 @@ class OpTCopyDefault(OpSkipOps):
     target = "aten.t_copy.default"
 
 
-@register_node_visitor
-class OpViewCopyDefault(OpSkipOps):
-    """
-    currently, do nothing if node is view_copy.default
-    need to handle this later on, currently view it as one of skip ops
-    """
-
-    target = "aten.view_copy.default"
-
-
 @register_node_visitor
 class OpSymSizeInt(OpSkipOps):
     """

backends/xnnpack/operators/op_view_copy.py

Lines changed: 96 additions & 0 deletions

@@ -0,0 +1,96 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+from typing import Dict
+
+import torch
+from executorch.backends.xnnpack.operators.node_visitor import (
+    NodeVisitor,
+    register_node_visitor,
+)
+from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
+    XNNGraph,
+    XNNStaticReshape,
+    XNode,
+)
+from executorch.backends.xnnpack.utils.utils import (
+    check_or_raise,
+    get_input_node,
+    PERM_NCHW_TO_NHWC,
+)
+
+
+@register_node_visitor
+class ViewCopyVisitor(NodeVisitor):
+    target = "aten.view_copy.default"
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        xnn_graph: XNNGraph,
+        vals_to_ids: Dict[torch.fx.Node, int],
+        debug_handle: int,
+    ) -> None:
+        self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)
+
+        input_node = get_input_node(node, 0)
+
+        # input
+        input_id = vals_to_ids[input_node]
+
+        # output
+        output_id = vals_to_ids[node]
+
+        # input shape
+        check_or_raise(
+            "val" in input_node.meta,
+            "Missing val in tensor metadata for input when serializing XNNStaticReshape",
+        )
+
+        # output shape
+        check_or_raise(
+            "val" in node.meta,
+            "Missing val in tensor metadata for output when serializing XNNStaticReshape",
+        )
+
+        new_shape = node.args[1]
+        check_or_raise(
+            all(isinstance(d, int) for d in new_shape),
+            "Symbolic reshape parameter is not supported in XNNStaticReshape",
+        )
+
+        # PyTorch uses -1 for inferred dims, whereas XNNPACK expects 0.
+        new_shape = tuple(d if d != -1 else 0 for d in new_shape)
+
+        # Handle dim order - if this node is tagged NHWC, the view shape
+        # (written in NCHW terms) must be permuted correspondingly.
+        if "XNN_NHWC_NODE" in node.meta:
+            check_or_raise(len(new_shape) == 4, "Invalid NCHW shape")
+            new_shape = [new_shape[PERM_NCHW_TO_NHWC[n]] for n in range(4)]
+
+        num_dynamic_dims = sum(1 for d in new_shape if d == 0)
+
+        check_or_raise(
+            num_dynamic_dims <= 1,
+            "XNNPACK reshape only supports 1 dynamic dimension.",
+        )
+
+        ser_node = XNode(
+            xnode_union=XNNStaticReshape(
+                num_dims=len(new_shape),
+                new_shape=new_shape,
+                input_id=input_id,
+                output_id=output_id,
+                flags=0,
+            ),
+            debug_handle=debug_handle,
+        )
+        xnn_graph.xnodes.append(ser_node)
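
The core shape translation above is small enough to show standalone. A sketch with a hypothetical helper (my naming, not part of the commit; PERM_NCHW_TO_NHWC is reproduced for self-containment):

PERM_NCHW_TO_NHWC = [0, 2, 3, 1]

def to_xnn_reshape_shape(new_shape, is_nhwc_tagged):
    # PyTorch marks the inferred dim with -1; XNNPACK's static_reshape uses 0.
    shape = [0 if d == -1 else d for d in new_shape]
    if is_nhwc_tagged:
        # The view shape is written in NCHW terms; permute it into NHWC.
        assert len(shape) == 4, "NHWC handling assumes a 4-D shape"
        shape = [shape[PERM_NCHW_TO_NHWC[i]] for i in range(4)]
    return shape

print(to_xnn_reshape_shape((2, -1, 4, 8), False))  # [2, 0, 4, 8]
print(to_xnn_reshape_shape((2, -1, 4, 8), True))   # [2, 4, 8, 0]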

backends/xnnpack/partition/config/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -55,6 +55,7 @@
     TanhConfig,
     ToDimOrderCopyConfig,
     UpsampleBilinear2dConfig,
+    ViewCopyConfig,
 )
 from executorch.backends.xnnpack.partition.config.node_configs import (
     BatchNormConfig,
@@ -115,6 +116,7 @@
     SquareRootConfig,
     SubConfig,
     UpsampleBilinear2dConfig,
+    ViewCopyConfig,
     # Quant/Dequant Op Configs
     QuantizedPerTensorConfig,
     DeQuantizedPerTensorConfig,

backends/xnnpack/partition/config/generic_node_configs.py

Lines changed: 32 additions & 0 deletions

@@ -379,6 +379,38 @@ class ExpConfig(GenericNodePartitionerConfig):
     def supported_precision_types(self) -> List[ConfigPrecisionType]:
         return [ConfigPrecisionType.FP32]
 
+class ViewCopyConfig(GenericNodePartitionerConfig):
+    target_name = "view_copy.default"
+
+    def supported_precision_types(self) -> List[ConfigPrecisionType]:
+        return [ConfigPrecisionType.FP32]
+
+    def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
+        """
+        XNNPACK's static_reshape only supports 1 dynamic dimension.
+        """
+        if not self.check_common_constraints(node, ep):
+            return False
+
+        new_shape = node.args[1]
+
+        # Check for symbolic dims. They aren't lowerable to XNNPACK currently.
+        symbolic_dim_indices = [
+            i for i, d in enumerate(new_shape) if not isinstance(d, int)
+        ]
+        if symbolic_dim_indices:
+            why(node, reason=f"Symbolic reshape is not supported. Output shape is {new_shape} and dims at {symbolic_dim_indices} are symbolic.")
+            return False
+
+        dynamic_dim_indices = [i for i, d in enumerate(new_shape) if d == -1]
+        if len(dynamic_dim_indices) > 1:
+            why(node, reason=f"Only a single inferred dimension is supported. Output shape is {new_shape} and dims {dynamic_dim_indices} are inferred.")
+            return False
+
+        return True
+
 
 class FloorConfig(GenericNodePartitionerConfig):
     target_name = "floor.default"
