Skip to content

Commit 00f33af

Browse files
committed
Add view_copy/static_reshape support to XNNPACK delegate
1 parent e78ed83 commit 00f33af

File tree

6 files changed

+301
-10
lines changed

6 files changed

+301
-10
lines changed

backends/xnnpack/operators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,4 +49,5 @@
4949
op_static_resize_bilinear_2d,
5050
op_sub,
5151
op_to_copy,
52+
op_view_copy,
5253
)

backends/xnnpack/operators/op_skip_ops.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -77,16 +77,6 @@ class OpTCopyDefault(OpSkipOps):
7777
target = "aten.t_copy.default"
7878

7979

80-
@register_node_visitor
81-
class OpViewCopyDefault(OpSkipOps):
82-
"""
83-
currently, do nothing if node is view_copy.default
84-
need to handle this later on, currently view it as one of skip ops
85-
"""
86-
87-
target = "aten.view_copy.default"
88-
89-
9080
@register_node_visitor
9181
class OpSymSizeInt(OpSkipOps):
9282
"""
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import cast, Dict
8+
9+
import torch
10+
from executorch.backends.transforms import get_shape
11+
from executorch.backends.xnnpack.operators.node_visitor import (
12+
NodeVisitor,
13+
register_node_visitor,
14+
)
15+
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
16+
XNNGraph,
17+
XNNStaticReshape,
18+
XNode,
19+
)
20+
from executorch.backends.xnnpack.utils.utils import (
21+
check_or_raise,
22+
get_input_node,
23+
PERM_NCHW_TO_NHWC,
24+
PERM_NHWC_TO_NCHW,
25+
)
26+
from torch.fx.experimental.symbolic_shapes import free_symbols
27+
28+
29+
@register_node_visitor
30+
class ViewCopyVisitor(NodeVisitor):
31+
target = "aten.view_copy.default"
32+
33+
def __init__(self, *args) -> None:
34+
super().__init__(*args)
35+
36+
def define_node(
37+
self,
38+
node: torch.fx.Node,
39+
xnn_graph: XNNGraph,
40+
vals_to_ids: Dict[torch.fx.Node, int],
41+
debug_handle: int,
42+
) -> None:
43+
self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)
44+
45+
input_node = get_input_node(node, 0)
46+
47+
# input
48+
input_id = vals_to_ids[input_node]
49+
50+
# output
51+
output_id = vals_to_ids[node]
52+
53+
# input shape
54+
check_or_raise(
55+
"val" in input_node.meta,
56+
"Missing val in tensor metadata for input when serializing XNNStaticReshape",
57+
)
58+
input_shape = input_node.meta["val"].shape
59+
60+
# output shape
61+
check_or_raise(
62+
"val" in node.meta,
63+
"Missing val in tensor metadata for input when serializing XNNStaticReshape",
64+
)
65+
66+
new_shape = node.args[1]
67+
check_or_raise(
68+
all(isinstance(d, int) for d in new_shape),
69+
"Symbol reshape parameter is not supported in XNNStaticReshape",
70+
)
71+
72+
# PyTorch uses -1 for inferred dims, whereas XNNPACK expects 0.
73+
new_shape = tuple(d if d != -1 else 0 for d in new_shape)
74+
75+
num_dynamic_dims = sum(1 for d in new_shape if d == 0)
76+
77+
check_or_raise(
78+
num_dynamic_dims <= 1,
79+
"XNNPACK reshape only supports 1 dynamic dimension.",
80+
)
81+
82+
ser_node = XNode(
83+
xnode_union=XNNStaticReshape(
84+
num_dims=len(new_shape),
85+
new_shape=new_shape,
86+
input_id=input_id,
87+
output_id=output_id,
88+
flags=0,
89+
),
90+
debug_handle=debug_handle,
91+
)
92+
xnn_graph.xnodes.append(ser_node)

backends/xnnpack/partition/config/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
SquareRootConfig,
4848
SubConfig,
4949
UpsampleBilinear2dConfig,
50+
ViewCopyConfig,
5051
)
5152
from executorch.backends.xnnpack.partition.config.node_configs import (
5253
BatchNormConfig,
@@ -100,6 +101,7 @@
100101
SquareRootConfig,
101102
SubConfig,
102103
UpsampleBilinear2dConfig,
104+
ViewCopyConfig,
103105
# Quant/Dequant Op Configs
104106
QuantizedPerTensorConfig,
105107
DeQuantizedPerTensorConfig,

backends/xnnpack/partition/config/generic_node_configs.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,32 @@ def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
336336
return torch.ops.aten.upsample_bilinear2d.vec
337337

338338

339+
class ViewCopyConfig(GenericNodePartitionerConfig):
340+
target_name = "view_copy.default"
341+
342+
def supported_precision_types(self) -> List[ConfigPrecisionType]:
343+
return [ConfigPrecisionType.FP32, ConfigPrecisionType.STATIC_QUANT]
344+
345+
def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
346+
"""
347+
XNNPACK's static_reshape only supports 1 dynamic dimension
348+
"""
349+
if not self.check_common_constraints(node, ep):
350+
return False
351+
352+
new_shape = node.args[1]
353+
if not all(isinstance(n, int) for n in new_shape):
354+
why(node, reason="symbolic reshape is not supported")
355+
return False
356+
357+
dynamic_dim_count = sum(1 for d in new_shape if d == -1)
358+
if dynamic_dim_count > 1:
359+
why(node, reason="only a single dynamic dimension is supported")
360+
return False
361+
362+
return True
363+
364+
339365
class FloorConfig(GenericNodePartitionerConfig):
340366
target_name = "floor.default"
341367

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
import torch
10+
from executorch.backends.xnnpack.test.tester import Export, Tester
11+
from executorch.exir.dialects._ops import ops as exir_ops
12+
from torch.export import Dim
13+
14+
15+
class TestViewCopy(unittest.TestCase):
16+
class View(torch.nn.Module):
17+
def __init__(self, new_shape):
18+
super().__init__()
19+
self.new_shape = new_shape
20+
21+
def forward(self, x):
22+
z = x.view(self.new_shape)
23+
return z
24+
25+
def test_fp16_view_copy(self):
26+
inputs = (torch.randn(4, 4).to(torch.float16),)
27+
(
28+
Tester(self.View((2, 8)), inputs)
29+
.export()
30+
.check_node_count({torch.ops.aten.view.default: 1})
31+
.to_edge_transform_and_lower()
32+
.check_node_count(
33+
{
34+
torch.ops.higher_order.executorch_call_delegate: 1,
35+
exir_ops.edge.aten.view_copy.default: 0,
36+
}
37+
)
38+
.to_executorch()
39+
.serialize()
40+
.run_method_and_compare_outputs()
41+
)
42+
43+
def test_fp32_view_copy(self):
44+
inputs = (torch.randn(4, 4),)
45+
(
46+
Tester(self.View((2, 8)), inputs)
47+
.export()
48+
.check_node_count({torch.ops.aten.view.default: 1})
49+
.to_edge_transform_and_lower()
50+
.check_node_count(
51+
{
52+
torch.ops.higher_order.executorch_call_delegate: 1,
53+
exir_ops.edge.aten.view_copy.default: 0,
54+
}
55+
)
56+
.to_executorch()
57+
.serialize()
58+
.run_method_and_compare_outputs()
59+
)
60+
61+
def test_fp32_view_copy_inferred_dim(self):
62+
inputs = (torch.randn(4, 4),)
63+
(
64+
Tester(self.View((-1, 8)), inputs)
65+
.export()
66+
.check_node_count({torch.ops.aten.view.default: 1})
67+
.to_edge_transform_and_lower()
68+
.check_node_count(
69+
{
70+
torch.ops.higher_order.executorch_call_delegate: 1,
71+
exir_ops.edge.aten.view_copy.default: 0,
72+
}
73+
)
74+
.to_executorch()
75+
.serialize()
76+
.run_method_and_compare_outputs()
77+
)
78+
79+
def test_fp32_view_copy_dynamic_shape_first_dim(self):
80+
inputs = (torch.randn(4, 4),)
81+
test_inputs = (
82+
(torch.randn(2, 4),),
83+
(torch.randn(10, 4),),
84+
)
85+
dynamic_shapes = {
86+
"x": {0: Dim("x", min=1, max=10)},
87+
}
88+
tester = (
89+
Tester(self.View((-1, 2)), inputs)
90+
.export(Export(dynamic_shapes=dynamic_shapes))
91+
.check_node_count({torch.ops.aten.view.default: 1})
92+
.to_edge_transform_and_lower()
93+
.check_node_count(
94+
{
95+
torch.ops.higher_order.executorch_call_delegate: 1,
96+
exir_ops.edge.aten.view_copy.default: 0,
97+
}
98+
)
99+
.to_executorch()
100+
.serialize()
101+
.run_method_and_compare_outputs()
102+
)
103+
104+
for test_input in test_inputs:
105+
tester.run_method_and_compare_outputs(inputs=test_input)
106+
107+
def test_fp32_view_copy_dynamic_shape_last_dim(self):
108+
inputs = (torch.randn(2, 4, 4),)
109+
test_inputs = (
110+
(torch.randn(2, 4, 2),),
111+
(torch.randn(2, 4, 10),),
112+
)
113+
dynamic_shapes = {
114+
"x": {2: Dim("x", min=1, max=10) * 2},
115+
}
116+
tester = (
117+
Tester(self.View((-1, 4, 2)), inputs)
118+
.export(Export(dynamic_shapes=dynamic_shapes))
119+
.check_node_count({torch.ops.aten.view.default: 1})
120+
.to_edge_transform_and_lower()
121+
.check_node_count(
122+
{
123+
torch.ops.higher_order.executorch_call_delegate: 1,
124+
exir_ops.edge.aten.view_copy.default: 0,
125+
}
126+
)
127+
.to_executorch()
128+
.serialize()
129+
.run_method_and_compare_outputs()
130+
)
131+
132+
for test_input in test_inputs:
133+
tester.run_method_and_compare_outputs(inputs=test_input)
134+
135+
def test_fp32_view_copy_unsupported_dynamism(self):
136+
class SymbolicView(torch.nn.Module):
137+
def forward(self, x):
138+
return x.view(x.shape[0] // 2, x.shape[1] * 2)
139+
140+
inputs = (torch.randn(4, 4),)
141+
dynamic_shapes = {
142+
"x": {1: Dim("x", min=1, max=10) * 2},
143+
}
144+
(
145+
Tester(SymbolicView(), inputs)
146+
.export(Export(dynamic_shapes=dynamic_shapes))
147+
.check_node_count({torch.ops.aten.view.default: 1})
148+
.to_edge_transform_and_lower()
149+
.check_node_count(
150+
{ # Expect no delegation.
151+
torch.ops.higher_order.executorch_call_delegate: 0,
152+
exir_ops.edge.aten.view_copy.default: 1,
153+
}
154+
)
155+
.to_executorch()
156+
.serialize()
157+
.run_method_and_compare_outputs()
158+
)
159+
160+
def test_fp32_view_copy_static_symbolic_arg(self):
161+
class SymbolicView(torch.nn.Module):
162+
def forward(self, x):
163+
return x.view(x.shape[0] // 2, x.shape[1] * 2)
164+
165+
inputs = (torch.randn(4, 4),)
166+
(
167+
Tester(SymbolicView(), inputs)
168+
.export()
169+
.check_node_count({torch.ops.aten.view.default: 1})
170+
.to_edge_transform_and_lower()
171+
.check_node_count(
172+
{
173+
torch.ops.higher_order.executorch_call_delegate: 1,
174+
exir_ops.edge.aten.view_copy.default: 0,
175+
}
176+
)
177+
.to_executorch()
178+
.serialize()
179+
.run_method_and_compare_outputs()
180+
)

0 commit comments

Comments
 (0)