Skip to content

Commit 144b550

Browse files
committed
Conditionally support expand_copy in XNNPACK delegate
1 parent d886373 commit 144b550

File tree

9 files changed

+271
-0
lines changed

9 files changed

+271
-0
lines changed

backends/xnnpack/operators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
op_dynamic_quantize_ops,
2424
op_elu,
2525
op_exp,
26+
op_expand_copy,
2627
op_floor,
2728
op_gelu,
2829
op_hardswish,
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import Dict
8+
9+
import torch
10+
from executorch.backends.xnnpack.operators.node_visitor import (
11+
NodeVisitor,
12+
register_node_visitor,
13+
)
14+
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
15+
XNNExpandDims,
16+
XNNGraph,
17+
XNode,
18+
)
19+
from executorch.backends.xnnpack.utils.utils import get_input_node
20+
21+
22+
def check_expand_copy_constraints(node: torch.fx.Node) -> bool:
23+
"""
24+
Checks whether the given expand_copy node is delegatable to XNNPACK.
25+
XNNPACK only allows insertion of size-1 dimensions, not expanding existing
26+
dims.
27+
"""
28+
in_shape = get_input_node(node, 0).meta["val"].shape
29+
new_shape = list(node.args[1])
30+
31+
assert len(new_shape) >= len(
32+
in_shape
33+
), "Expanded shape must have rank >= input rank."
34+
35+
# Check new leading dims (if any). They must be of size 1.
36+
new_leading_dims_count = len(new_shape) - len(in_shape)
37+
for i in range(new_leading_dims_count):
38+
if new_shape[i] != 1:
39+
return False
40+
41+
# Check existing dims. PyTorch expand semantics don't allow for dim insertion other
42+
# than at the front, so we just need to make sure none of the dims are expanded.
43+
for i in range(len(new_shape) - new_leading_dims_count):
44+
new_shape_at_dim = new_shape[new_leading_dims_count + i]
45+
# -1 means preserve dim.
46+
if new_shape_at_dim != -1 and new_shape_at_dim != in_shape[i]:
47+
return False
48+
49+
return True
50+
51+
52+
def get_inserted_dim_indices(
53+
node: torch.fx.Node,
54+
) -> list[int]:
55+
"""
56+
Returns the indices of the inserted dimensions in the expanded shape. Assumes that
57+
the node meets the conditions checked in check_expand_copy_constraints.
58+
"""
59+
in_shape = get_input_node(node, 0).meta["val"].shape
60+
new_shape = list(node.args[1])
61+
new_dim_indices = []
62+
63+
assert len(new_shape) >= len(
64+
in_shape
65+
), "Expanded shape must have rank >= input rank."
66+
67+
# PyTorch expand semantics enforce new dim insertion only at the front.
68+
new_leading_dims_count = len(new_shape) - len(in_shape)
69+
for i in range(new_leading_dims_count):
70+
if new_shape[i] != 1:
71+
return False
72+
else:
73+
new_dim_indices.append(i)
74+
75+
return new_dim_indices
76+
77+
78+
@register_node_visitor
79+
class ExpandCopyVisitor(NodeVisitor):
80+
target = "aten.expand_copy.default"
81+
82+
def __init__(self, *args) -> None:
83+
super().__init__(*args)
84+
85+
def define_node(
86+
self,
87+
node: torch.fx.Node,
88+
xnn_graph: XNNGraph,
89+
vals_to_ids: Dict[torch.fx.Node, int],
90+
debug_handle: int,
91+
) -> None:
92+
self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)
93+
94+
# input
95+
input_id = vals_to_ids[get_input_node(node, 0)]
96+
97+
# output
98+
output_id = vals_to_ids[node]
99+
100+
new_dim_indices = get_inserted_dim_indices(node)
101+
102+
ser_node = XNode(
103+
xnode_union=XNNExpandDims(
104+
num_new_dims=len(new_dim_indices),
105+
new_dim_indices=new_dim_indices,
106+
input_id=input_id,
107+
output_id=output_id,
108+
flags=0,
109+
),
110+
debug_handle=debug_handle,
111+
)
112+
xnn_graph.xnodes.append(ser_node)

backends/xnnpack/partition/config/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
ConstantPadConfig,
2828
DeQuantizedPerTensorConfig,
2929
DivConfig,
30+
ExpandCopyConfig,
3031
# EluConfig,
3132
ExpConfig,
3233
FloorConfig,
@@ -87,6 +88,7 @@
8788
DivConfig,
8889
# EluConfig, # Waiting for PyTorch Pin Update
8990
ExpConfig,
91+
ExpandCopyConfig,
9092
FloorConfig,
9193
GeluConfig,
9294
HardtanhConfig,

backends/xnnpack/partition/config/generic_node_configs.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
import logging
1010
from typing import cast, List, Optional
1111

12+
import executorch.backends.xnnpack.operators.op_expand_copy as op_expand_copy
13+
1214
import numpy as np
1315
import torch
1416
from executorch.backends.xnnpack.partition.config.xnnpack_config import (
@@ -262,6 +264,30 @@ def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
262264
return torch.ops.aten.elu.default
263265

264266

267+
class ExpandCopyConfig(GenericNodePartitionerConfig):
268+
target_name = "expand_copy.default"
269+
270+
def supported_precision_types(self) -> List[ConfigPrecisionType]:
271+
return [ConfigPrecisionType.FP32]
272+
273+
def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
274+
return torch.ops.aten.expand_copy.default
275+
276+
def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
277+
"""
278+
Only partition expand_copy nodes that can be converted to view_copy (insertion of
279+
singleton dims).
280+
"""
281+
if not self.check_common_constraints(node, ep):
282+
return False
283+
284+
# Explicit false check here avoids non partitioning identity expand_copy.
285+
if not op_expand_copy.check_expand_copy_constraints(node):
286+
why(node, reason="only insertion of singleton dims is supported")
287+
return False
288+
return True
289+
290+
265291
class SoftmaxConfig(GenericNodePartitionerConfig):
266292
target_name = "_softmax.default"
267293

backends/xnnpack/runtime/XNNCompiler.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1130,6 +1130,35 @@ Error defineStaticReshapeNode(
11301130
return Error::Ok;
11311131
}
11321132

1133+
Error defineExpandDimsNode(
1134+
xnn_subgraph_t subgraph_ptr,
1135+
const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1136+
const NodePtr node,
1137+
const fb_xnnpack::XNNGraph* graph) noexcept {
1138+
MAYBE_UNUSED(graph);
1139+
1140+
auto graph_node = node->xnode_union_as_XNNExpandDims();
1141+
1142+
// Get tensor dims, we need to convert the uint32_t* to size_t*
1143+
std::vector<size_t> dims_data =
1144+
flatbufferDimsToVector(graph_node->new_dim_indices());
1145+
xnn_status status = xnn_define_static_expand_dims(
1146+
subgraph_ptr,
1147+
graph_node->num_new_dims(),
1148+
dims_data.data(),
1149+
remapped_ids.at(graph_node->input_id()),
1150+
remapped_ids.at(graph_node->output_id()),
1151+
graph_node->flags());
1152+
ET_CHECK_OR_RETURN_ERROR(
1153+
status == xnn_status_success,
1154+
Internal,
1155+
"Failed to create expand_dims node %i with code: %s",
1156+
node->debug_handle(),
1157+
xnn_status_to_string(status));
1158+
1159+
return Error::Ok;
1160+
}
1161+
11331162
/*
11341163
Define serialized maxpool2d node into the subgraph, using the remapped ids
11351164
to map the serialized ids, to the new ids generated when defining the
@@ -1784,6 +1813,7 @@ DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) {
17841813
_DEFINE(Convert)
17851814
_DEFINE(GlobalAvgPooling2d)
17861815
_DEFINE(StaticReshape)
1816+
_DEFINE(ExpandDims)
17871817
_DEFINE(ArgMaxPooling2d)
17881818
_DEFINE(Concatenate2)
17891819
_DEFINE(Concatenate3)

backends/xnnpack/serialization/runtime_schema.fbs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ union XNodeUnion {
158158
XNNExp: _XNNNode1x1,
159159
XNNSin: _XNNNode1x1,
160160
XNNCopy: _XNNNode1x1,
161+
XNNExpandDims,
161162
}
162163

163164
union XValueUnion {
@@ -296,6 +297,14 @@ table XNNStaticReshape {
296297
flags: uint;
297298
}
298299

300+
table XNNExpandDims {
301+
num_new_dims:uint;
302+
new_dim_indices:[uint];
303+
input_id: uint;
304+
output_id: uint;
305+
flags: uint;
306+
}
307+
299308
table XNNStaticSlice {
300309
num_dims:uint;
301310
offsets:[uint];

backends/xnnpack/serialization/schema.fbs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ union XNodeUnion {
154154
XNNExp: _XNNNode1x1,
155155
XNNSin: _XNNNode1x1,
156156
XNNCopy: _XNNNode1x1,
157+
XNNExpandDims,
157158
}
158159

159160
union XValueUnion {
@@ -292,6 +293,14 @@ table XNNStaticReshape {
292293
flags: uint;
293294
}
294295

296+
table XNNExpandDims {
297+
num_new_dims:uint;
298+
new_dim_indices:[uint];
299+
input_id: uint;
300+
output_id: uint;
301+
flags: uint;
302+
}
303+
295304
table XNNStaticSlice {
296305
num_dims:uint;
297306
offsets:[uint];

backends/xnnpack/serialization/xnnpack_graph_schema.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,15 @@ class XNNScaledDotProductAttention:
368368
flags: int
369369

370370

371+
@dataclass
372+
class XNNExpandDims:
373+
num_new_dims: int
374+
new_dim_indices: List[int]
375+
input_id: int
376+
output_id: int
377+
flags: int
378+
379+
371380
XNodeUnion = Union[
372381
XNNAdd,
373382
XNNFullyConnected,
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
import torch
10+
from executorch.backends.xnnpack.test.tester import Tester
11+
from executorch.exir.dialects._ops import ops as exir_ops
12+
13+
14+
class TestExpand(unittest.TestCase):
15+
class Expand(torch.nn.Module):
16+
def __init__(self, out_shape):
17+
super().__init__()
18+
self.out_shape = out_shape
19+
20+
def forward(self, x):
21+
return x.expand(self.out_shape)
22+
23+
def test_fp32_insert_dim(self):
24+
inputs = (torch.randn(8, 12),)
25+
new_shapes = (
26+
(1, 8, 12),
27+
(1, 1, 8, 12),
28+
(8, -1),
29+
(-1, 12),
30+
(1, -1, -1),
31+
(1, 1, 8, -1),
32+
)
33+
34+
for new_shape in new_shapes:
35+
(
36+
Tester(self.Expand(new_shape), tuple(inputs))
37+
.export()
38+
.check_node_count({torch.ops.aten.expand.default: 1})
39+
.to_edge_transform_and_lower()
40+
.check_node_count(
41+
{
42+
exir_ops.edge.aten.expand_copy.default: 0,
43+
exir_ops.edge.aten.view_copy.default: 0,
44+
torch.ops.higher_order.executorch_call_delegate: 1,
45+
}
46+
)
47+
.to_executorch()
48+
.run_method_and_compare_outputs()
49+
)
50+
51+
def test_fp32_unsupported_expand(self):
52+
inputs = (torch.randn(1, 8, 12),)
53+
new_shapes = (
54+
(2, 8, 12),
55+
(1, 2, 8, 12),
56+
(2, 1, 8, 12),
57+
)
58+
59+
for new_shape in new_shapes:
60+
(
61+
Tester(self.Expand(new_shape), tuple(inputs))
62+
.export()
63+
.check_node_count({torch.ops.aten.expand.default: 1})
64+
.to_edge_transform_and_lower()
65+
.check_node_count(
66+
{
67+
exir_ops.edge.aten.expand_copy.default: 1,
68+
exir_ops.edge.aten.view_copy.default: 0,
69+
}
70+
)
71+
.to_executorch()
72+
.run_method_and_compare_outputs()
73+
)

0 commit comments

Comments
 (0)