Commit a1ba74c

Conditionally support expand_copy in XNNPACK delegate
1 parent 36bdc16 commit a1ba74c

File tree

6 files changed: +241 -0 lines changed

backends/xnnpack/_passes/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -18,6 +18,7 @@
     ConvertToUpsampleBilinear2d,
 )
 from executorch.backends.xnnpack._passes.decompose_cat import DecomposeConcatenate
+from executorch.backends.xnnpack._passes.expand_to_view_pass import ExpandToViewPass
 from executorch.backends.xnnpack._passes.fuse_activation_pass import FuseActivationPass
 from executorch.backends.xnnpack._passes.fuse_batch_norm_with_conv import (
     FuseBatchNormWithConvPass,
@@ -62,6 +63,7 @@ def __init__(
             ConvertToLinearPass,
             ConvertToSDPAPass,
             ConstPropPass,
+            ExpandToViewPass,
             FuseBatchNormWithConvPass,
             FuseActivationPass,
             DecomposeConcatenate,
backends/xnnpack/_passes/expand_to_view_pass.py (new file)

Lines changed: 77 additions & 0 deletions

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging

import torch
from executorch.backends.xnnpack.utils.utils import get_input_node
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)


class ExpandToViewPass(ExportPass):
    """
    Torch expand_copy can be used as an alternative to unsqueeze. This pass replaces
    expand_copy nodes that only add one or more singleton dimensions with equivalent
    view_copy nodes.

    Example:
    Before Pass:
        expand: "f32" = torch.ops.aten.expand_copy.default(x, (1, -1));

    After Pass:
        view: "f32" = torch.ops.aten.view_copy.default(x, (1, -1));
    """

    @staticmethod
    def can_transform_expand_node(node: torch.fx.Node) -> bool:
        # The node can be converted to a view if the expand only inserts singleton
        # dimensions and does not modify any existing dimensions.
        in_shape = get_input_node(node, 0).meta["val"].shape
        out_shape = node.meta["val"].shape

        i = 0  # in-shape index
        j = 0  # out-shape index
        while j < len(out_shape):
            if i >= len(in_shape):  # Shape mismatch
                return False
            elif in_shape[i] == out_shape[j]:  # Dims match
                i += 1
                j += 1
            elif out_shape[j] == 1:  # Inserted singleton dim
                j += 1
            else:  # Dim mismatch (in_shape[i] != out_shape[j])
                return False

        return True

    def call(self, graph_module: torch.fx.GraphModule):
        gm = graph_module
        for node in gm.graph.nodes:
            if (
                node.op == "call_function"
                and node.target == exir_ops.edge.aten.expand_copy.default
                and ExpandToViewPass.can_transform_expand_node(node)
            ):
                with gm.graph.inserting_after(node):
                    view_node = gm.graph.create_node(
                        "call_function",
                        target=exir_ops.edge.aten.view_copy.default,
                        args=(node.args[0], node.args[1]),
                        kwargs=node.kwargs,
                    )

                node.replace_all_uses_with(view_node)
                gm.graph.erase_node(node)

        gm.recompile()
        new_gm = super().call(gm).graph_module
        return PassResult(new_gm, True)
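
The rewrite is sound because an expand_copy that only inserts singleton dimensions never replicates any data: the output holds exactly the input's elements, so it is expressible as a view. A minimal standalone check of this equivalence in plain PyTorch (illustrative, not part of the commit):

import torch

x = torch.randn(8, 12)

# Prepending singleton dims via expand yields the same elements as a reshape.
expanded = x.expand(1, 1, 8, 12)
viewed = x.reshape(1, 1, 8, 12)
assert torch.equal(expanded, viewed)

# Broadcasting a real dimension (1 -> 2) replicates data, so no equivalent
# view exists; the pass deliberately leaves such nodes untouched.
y = torch.randn(1, 8, 12)
broadcast = y.expand(2, 8, 12)
assert torch.equal(broadcast[0], broadcast[1])  # rows are duplicates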

backends/xnnpack/partition/config/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -25,6 +25,7 @@
     ConstantPadConfig,
     DeQuantizedPerTensorConfig,
     DivConfig,
+    ExpandCopyConfig,
     FloorConfig,
     HardswishConfig,
     # EluConfig,
@@ -77,6 +78,7 @@
     ClampConfig,
     DivConfig,
     # EluConfig,  # Waiting for PyTorch Pin Update
+    ExpandCopyConfig,
     FloorConfig,
     HardtanhConfig,
     HardswishConfig,

backends/xnnpack/partition/config/generic_node_configs.py

Lines changed: 24 additions & 0 deletions

@@ -10,6 +10,7 @@
 from typing import cast, List, Optional
 
 import torch
+from executorch.backends.xnnpack._passes.expand_to_view_pass import ExpandToViewPass
 from executorch.backends.xnnpack.partition.config.xnnpack_config import (
     ConfigPrecisionType,
     XNNPartitionerConfig,
@@ -225,6 +226,29 @@ def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
         return torch.ops.aten.elu.default
 
 
+class ExpandCopyConfig(GenericNodePartitionerConfig):
+    target_name = "expand_copy.default"
+
+    def supported_precision_types(self) -> List[ConfigPrecisionType]:
+        return [ConfigPrecisionType.FP32]
+
+    def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
+        return torch.ops.aten.expand_copy.default
+
+    def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
+        """
+        Only partition expand_copy nodes that can be converted to view_copy
+        (insertion of singleton dims).
+        """
+        if not self.check_common_constraints(node, ep):
+            return False
+
+        if not ExpandToViewPass.can_transform_expand_node(node):
+            why(node, reason="only insertion of singleton dims is supported")
+            return False
+        return True
+
+
 class SoftmaxConfig(GenericNodePartitionerConfig):
     target_name = "_softmax.default"
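
To make the partitioning constraint concrete, here is the same two-pointer shape walk as a self-contained sketch over plain shape tuples (the standalone function name is illustrative, not from the commit), with shape pairs the config accepts and rejects:

def singleton_insert_only(in_shape, out_shape):
    # Mirrors ExpandToViewPass.can_transform_expand_node, but on raw tuples.
    i = j = 0
    while j < len(out_shape):
        if i >= len(in_shape):  # Ran out of input dims to match
            return False
        elif in_shape[i] == out_shape[j]:  # Existing dim carried through unchanged
            i += 1
            j += 1
        elif out_shape[j] == 1:  # Inserted singleton dim
            j += 1
        else:  # A true broadcast: data would be replicated
            return False
    return True

assert singleton_insert_only((8, 12), (1, 8, 12))          # partitioned
assert singleton_insert_only((8, 12), (1, 1, 8, 12))       # partitioned
assert not singleton_insert_only((1, 8, 12), (2, 8, 12))   # rejected: broadcasts dim 0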

Lines changed: 66 additions & 0 deletions (new test file)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack.test.tester import Tester
from executorch.exir.dialects._ops import ops as exir_ops


class TestExpand(unittest.TestCase):
    class Expand(torch.nn.Module):
        def __init__(self, out_shape):
            super().__init__()
            self.out_shape = out_shape

        def forward(self, x):
            return x.expand(self.out_shape)

    def test_fp32_insert_dim(self):
        inputs = (torch.randn(8, 12),)
        new_shapes = (
            (1, 8, 12),
            (1, 1, 8, 12),
        )

        for new_shape in new_shapes:
            (
                Tester(self.Expand(new_shape), tuple(inputs))
                .export()
                .check_node_count({torch.ops.aten.expand.default: 1})
                .to_edge_transform_and_lower()
                .check_node_count(
                    {
                        exir_ops.edge.aten.expand_copy.default: 0,
                        exir_ops.edge.aten.view_copy.default: 0,
                        torch.ops.higher_order.executorch_call_delegate: 1,
                    }
                )
                .to_executorch()
                .run_method_and_compare_outputs()
            )

    def test_fp32_unsupported_expand(self):
        inputs = (torch.randn(1, 8, 12),)
        new_shapes = (
            (2, 8, 12),
            (1, 2, 8, 12),
            (2, 1, 8, 12),
        )

        for new_shape in new_shapes:
            (
                Tester(self.Expand(new_shape), tuple(inputs))
                .export()
                .check_node_count({torch.ops.aten.expand.default: 1})
                .to_edge_transform_and_lower()
                .check_node_count(
                    {
                        exir_ops.edge.aten.expand_copy.default: 1,
                        exir_ops.edge.aten.view_copy.default: 0,
                    }
                )
                .to_executorch()
                .run_method_and_compare_outputs()
            )
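
Outside the Tester harness, the same flow can be exercised through the public export APIs. A hedged end-to-end sketch (the module name and shapes are illustrative):

import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower

class AddBatchDims(torch.nn.Module):
    def forward(self, x):
        # Only inserts singleton dims, so it should be rewritten and delegated.
        return x.expand(1, 1, 8, 12)

ep = torch.export.export(AddBatchDims(), (torch.randn(8, 12),))
lowered = to_edge_transform_and_lower(ep, partitioner=[XnnpackPartitioner()])
# The resulting graph should contain an executorch_call_delegate node and no expand_copy.
print(lowered.exported_program().graph)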

Lines changed: 70 additions & 0 deletions (new test file)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack._passes.expand_to_view_pass import ExpandToViewPass
from executorch.backends.xnnpack.test.tester import RunPasses, Tester
from executorch.exir.dialects._ops import ops as exir_ops


class TestExpandToViewPass(unittest.TestCase):
    PassStage = RunPasses([ExpandToViewPass])

    class Expand(torch.nn.Module):
        def __init__(self, out_shape):
            super().__init__()
            self.out_shape = out_shape

        def forward(self, x):
            return x.expand(self.out_shape)

    def test_fp32_insert_dim(self):
        inputs = (torch.randn(8, 12),)
        new_shapes = (
            (1, 8, 12),
            (1, 1, 8, 12),
        )

        for new_shape in new_shapes:
            (
                Tester(self.Expand(new_shape), tuple(inputs))
                .export()
                .to_edge()
                .check_node_count({exir_ops.edge.aten.expand_copy.default: 1})
                .run_passes(self.PassStage)
                .check_node_count(
                    {
                        exir_ops.edge.aten.expand_copy.default: 0,
                        exir_ops.edge.aten.view_copy.default: 1,
                    }
                )
                .run_method_and_compare_outputs()
            )

    def test_fp32_unsupported_expand(self):
        inputs = (torch.randn(1, 8, 12),)
        new_shapes = (
            (2, 8, 12),
            (1, 2, 8, 12),
            (2, 1, 8, 12),
        )

        for new_shape in new_shapes:
            (
                Tester(self.Expand(new_shape), tuple(inputs))
                .export()
                .to_edge()
                .check_node_count({exir_ops.edge.aten.expand_copy.default: 1})
                .run_passes(self.PassStage)
                .check_node_count(
                    {
                        exir_ops.edge.aten.expand_copy.default: 1,
                        exir_ops.edge.aten.view_copy.default: 0,
                    }
                )
                .run_method_and_compare_outputs()
            )
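
The pass can also be applied by hand to an edge program, without the RunPasses test stage. A hedged sketch, assuming EdgeProgramManager.transform accepts ExportPass instances as it does in current ExecuTorch:

import torch
from executorch.backends.xnnpack._passes.expand_to_view_pass import ExpandToViewPass
from executorch.exir import to_edge

class M(torch.nn.Module):
    def forward(self, x):
        return x.expand(1, 8, 12)  # singleton insertion only

edge = to_edge(torch.export.export(M(), (torch.randn(8, 12),)))
edge = edge.transform([ExpandToViewPass()])
# expand_copy should now appear as view_copy in the edge graph.
print(edge.exported_program().graph)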
