Skip to content

Commit 7ce711f

Browse files
committed
Conditionally support expand_copy in XNNPACK delegate
1 parent cf00810 commit 7ce711f

File tree

6 files changed

+244
-0
lines changed

6 files changed

+244
-0
lines changed

backends/xnnpack/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
ConvertToUpsampleBilinear2d,
1919
)
2020
from executorch.backends.xnnpack._passes.decompose_cat import DecomposeConcatenate
21+
from executorch.backends.xnnpack._passes.expand_to_view_pass import ExpandToViewPass
2122
from executorch.backends.xnnpack._passes.fuse_activation_pass import FuseActivationPass
2223
from executorch.backends.xnnpack._passes.fuse_batch_norm_with_conv import (
2324
FuseBatchNormWithConvPass,
@@ -62,6 +63,7 @@ def __init__(
6263
ConvertToLinearPass,
6364
ConvertToSDPAPass,
6465
ConstPropPass,
66+
ExpandToViewPass,
6567
FuseBatchNormWithConvPass,
6668
FuseActivationPass,
6769
DecomposeConcatenate,
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import logging
8+
9+
import torch
10+
from executorch.backends.xnnpack.utils.utils import get_input_node
11+
from executorch.exir.dialects._ops import ops as exir_ops
12+
13+
from executorch.exir.pass_base import ExportPass, PassResult
14+
15+
logger = logging.getLogger(__name__)
16+
logger.setLevel(logging.WARNING)
17+
18+
19+
class ExpandToViewPass(ExportPass):
    """
    Torch expand_copy can be used as an alternative to unsqueeze. This pass replaces expand_copy
    nodes that only add one or more singleton dimensions with an equivalent view_copy node.

    In expand, a size of -1 means "keep this dimension", whereas in view it means "infer this
    dimension" and may appear at most once. Any -1 entry whose output size is statically known
    is therefore rewritten to that concrete size when the view_copy node is created.

    Example (x: "f32[8, 12]"):
        Before Pass:
            expand: "f32" = torch.ops.aten.expand_copy.default(x, (1, -1, -1));

        After Pass:
            view: "f32" = torch.ops.aten.view_copy.default(x, (1, 8, 12));
    """

    @staticmethod
    def can_transform_expand_node(node: torch.fx.Node) -> bool:
        """
        Return True if the expand_copy `node` can be expressed as a view_copy.

        The node can be converted to a view if the expand only inserts singleton dimensions
        and does not modify any existing dimensions. Shapes are read from the "val" metadata
        of the node and of its first input.
        """
        in_shape = get_input_node(node, 0).meta["val"].shape
        out_shape = node.meta["val"].shape

        i = 0  # in-shape index
        for out_dim in out_shape:
            if i < len(in_shape) and in_shape[i] == out_dim:  # Dims match
                i += 1
            elif out_dim == 1:
                # Inserted singleton dim. This is also valid once every input dim has
                # been matched (e.g. (1,) -> (1, 1)), since the element count is unchanged.
                continue
            else:  # Dim mismatch (in_shape[i] != out_dim) or input dims exhausted
                return False

        # Every input dim must have been matched, otherwise the shapes are not
        # related by pure singleton insertion.
        return i == len(in_shape)

    @staticmethod
    def _resolve_view_shape(node: torch.fx.Node):
        """
        Build the size argument for the replacement view_copy node.

        Entries of the expand size list equal to -1 are replaced by the matching output
        dimension from the node's "val" metadata when that dimension is a plain int. All
        other entries (explicit sizes, non-int/symbolic dims) are passed through unchanged.
        """
        requested = node.args[1]
        out_shape = node.meta["val"].shape
        if len(requested) != len(out_shape):
            # Unexpected metadata; fall back to forwarding the sizes untouched.
            return requested

        return [
            actual
            if (isinstance(size, int) and size == -1 and isinstance(actual, int))
            else size
            for size, actual in zip(requested, out_shape)
        ]

    def call(self, graph_module: torch.fx.GraphModule):
        """
        Replace every convertible expand_copy node in `graph_module` with a view_copy node.

        Returns a PassResult wrapping the graph module produced by the parent class's
        `call`; the result is always reported as modified.
        """
        gm = graph_module
        for node in gm.graph.nodes:
            if (
                node.op == "call_function"
                and node.target == exir_ops.edge.aten.expand_copy.default
                and ExpandToViewPass.can_transform_expand_node(node)
            ):
                with gm.graph.inserting_after(node):
                    view_node = gm.graph.create_node(
                        "call_function",
                        target=exir_ops.edge.aten.view_copy.default,
                        args=(
                            node.args[0],
                            ExpandToViewPass._resolve_view_shape(node),
                        ),
                        kwargs=node.kwargs,
                    )

                # Redirect consumers to the view, then drop the now-unused expand.
                node.replace_all_uses_with(view_node)
                gm.graph.erase_node(node)

        gm.recompile()
        # NOTE(review): the new view nodes carry no metadata here; presumably the parent
        # ExportPass.call re-traces the graph and regenerates it - confirm.
        new_gm = super().call(gm).graph_module
        return PassResult(new_gm, True)

backends/xnnpack/partition/config/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
ConstantPadConfig,
2626
DeQuantizedPerTensorConfig,
2727
DivConfig,
28+
ExpandCopyConfig,
2829
FloorConfig,
2930
HardswishConfig,
3031
# EluConfig,
@@ -77,6 +78,7 @@
7778
ClampConfig,
7879
DivConfig,
7980
# EluConfig, # Waiting for PyTorch Pin Update
81+
ExpandCopyConfig,
8082
FloorConfig,
8183
HardtanhConfig,
8284
HardswishConfig,

backends/xnnpack/partition/config/generic_node_configs.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from typing import cast, List, Optional
1111

1212
import torch
13+
from executorch.backends.xnnpack._passes.expand_to_view_pass import ExpandToViewPass
1314
from executorch.backends.xnnpack.partition.config.xnnpack_config import (
1415
ConfigPrecisionType,
1516
XNNPartitionerConfig,
@@ -225,6 +226,29 @@ def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
225226
return torch.ops.aten.elu.default
226227

227228

229+
class ExpandCopyConfig(GenericNodePartitionerConfig):
    """
    Partitioner config for expand_copy.default, restricted to FP32 and to expands
    that ExpandToViewPass reports as convertible to view_copy.
    """

    target_name = "expand_copy.default"

    def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
        return torch.ops.aten.expand_copy.default

    def supported_precision_types(self) -> List[ConfigPrecisionType]:
        return [ConfigPrecisionType.FP32]

    def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
        """
        Accept a node only when the common constraints hold and the expand merely
        inserts singleton dims (i.e. it can be rewritten as a view_copy).
        """
        if not self.check_common_constraints(node, ep):
            return False

        convertible = ExpandToViewPass.can_transform_expand_node(node)
        if not convertible:
            # Record the rejection reason for this node.
            why(node, reason="only insertion of singleton dims is supported")
        return convertible
250+
251+
228252
class SoftmaxConfig(GenericNodePartitionerConfig):
229253
target_name = "_softmax.default"
230254

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
import torch
10+
from executorch.backends.xnnpack.test.tester import Tester
11+
from executorch.exir.dialects._ops import ops as exir_ops
12+
13+
14+
class TestExpand(unittest.TestCase):
    """Lowering tests for torch expand through the Tester pipeline."""

    class Expand(torch.nn.Module):
        """Module whose forward expands its input to a fixed target shape."""

        def __init__(self, out_shape):
            super().__init__()
            self.out_shape = out_shape

        def forward(self, x):
            return x.expand(self.out_shape)

    def test_fp32_insert_dim(self):
        """
        Expands that only prepend singleton dims: after lowering, no expand_copy or
        view_copy node remains and exactly one delegate call is present.
        """
        inputs = (torch.randn(8, 12),)
        target_shapes = (
            (1, 8, 12),
            (1, 1, 8, 12),
        )

        for target_shape in target_shapes:
            pipeline = Tester(self.Expand(target_shape), tuple(inputs))
            pipeline = pipeline.export().check_node_count(
                {torch.ops.aten.expand.default: 1}
            )
            pipeline = pipeline.to_edge_transform_and_lower()
            lowered_counts = {
                exir_ops.edge.aten.expand_copy.default: 0,
                exir_ops.edge.aten.view_copy.default: 0,
                torch.ops.higher_order.executorch_call_delegate: 1,
            }
            pipeline = pipeline.check_node_count(lowered_counts)
            pipeline.to_executorch().run_method_and_compare_outputs()

    def test_fp32_unsupported_expand(self):
        """
        Expands that change an existing dim: after lowering, the expand_copy node is
        still present and no view_copy node was created.
        """
        inputs = (torch.randn(1, 8, 12),)
        target_shapes = (
            (2, 8, 12),
            (1, 2, 8, 12),
            (2, 1, 8, 12),
        )

        for target_shape in target_shapes:
            pipeline = Tester(self.Expand(target_shape), tuple(inputs))
            pipeline = pipeline.export().check_node_count(
                {torch.ops.aten.expand.default: 1}
            )
            pipeline = pipeline.to_edge_transform_and_lower()
            lowered_counts = {
                exir_ops.edge.aten.expand_copy.default: 1,
                exir_ops.edge.aten.view_copy.default: 0,
            }
            pipeline = pipeline.check_node_count(lowered_counts)
            pipeline.to_executorch().run_method_and_compare_outputs()
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
import torch
10+
from executorch.backends.xnnpack._passes.expand_to_view_pass import ExpandToViewPass
11+
from executorch.backends.xnnpack.test.tester import RunPasses, Tester
12+
from executorch.exir.dialects._ops import ops as exir_ops
13+
14+
15+
class TestExpandToViewPass(unittest.TestCase):
    """Tests that run ExpandToViewPass on its own over an edge-dialect graph."""

    PassStage = RunPasses([ExpandToViewPass])

    class Expand(torch.nn.Module):
        """Module whose forward expands its input to a fixed target shape."""

        def __init__(self, out_shape):
            super().__init__()
            self.out_shape = out_shape

        def forward(self, x):
            return x.expand(self.out_shape)

    def test_fp32_insert_dim(self):
        """
        Expands that only prepend singleton dims: the pass replaces the single
        expand_copy node with a single view_copy node.
        """
        inputs = (torch.randn(8, 12),)
        target_shapes = (
            (1, 8, 12),
            (1, 1, 8, 12),
        )

        for target_shape in target_shapes:
            pipeline = Tester(self.Expand(target_shape), tuple(inputs))
            pipeline = pipeline.export().to_edge()
            pipeline = pipeline.check_node_count(
                {exir_ops.edge.aten.expand_copy.default: 1}
            )
            pipeline = pipeline.run_passes(self.PassStage)
            counts_after_pass = {
                exir_ops.edge.aten.expand_copy.default: 0,
                exir_ops.edge.aten.view_copy.default: 1,
            }
            pipeline = pipeline.check_node_count(counts_after_pass)
            pipeline.run_method_and_compare_outputs()

    def test_fp32_unsupported_expand(self):
        """
        Expands that change an existing dim: the pass leaves the expand_copy node in
        place and creates no view_copy node.
        """
        inputs = (torch.randn(1, 8, 12),)
        target_shapes = (
            (2, 8, 12),
            (1, 2, 8, 12),
            (2, 1, 8, 12),
        )

        for target_shape in target_shapes:
            pipeline = Tester(self.Expand(target_shape), tuple(inputs))
            pipeline = pipeline.export().to_edge()
            pipeline = pipeline.check_node_count(
                {exir_ops.edge.aten.expand_copy.default: 1}
            )
            pipeline = pipeline.run_passes(self.PassStage)
            counts_after_pass = {
                exir_ops.edge.aten.expand_copy.default: 1,
                exir_ops.edge.aten.view_copy.default: 0,
            }
            pipeline = pipeline.check_node_count(counts_after_pass)
            pipeline.run_method_and_compare_outputs()

0 commit comments

Comments
 (0)