Skip to content

Commit ad348db

Browse files
Zonglin Pengfacebook-github-bot
authored andcommitted
add simplify pass testing
Differential Revision: D66078183
1 parent c16c8b8 commit ad348db

File tree

2 files changed

+126
-0
lines changed

2 files changed

+126
-0
lines changed

backends/cadence/aot/TARGETS

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,3 +317,21 @@ python_unittest(
317317
"//executorch/exir/dialects:lib",
318318
],
319319
)
320+
321+
python_unittest(
322+
name = "test_simplify_ops_passes",
323+
srcs = [
324+
"tests/test_simplify_ops_passes.py",
325+
],
326+
supports_static_listing = False,
327+
typing = True,
328+
deps = [
329+
"fbsource//third-party/pypi/parameterized:parameterized",
330+
"//caffe2:torch",
331+
"//executorch/backends/cadence/aot:compiler",
332+
"//executorch/backends/cadence/aot:ops_registrations",
333+
"//executorch/backends/cadence/aot:pass_utils",
334+
"//executorch/backends/cadence/aot:simplify_ops",
335+
"//executorch/exir/dialects:lib",
336+
],
337+
)
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
3+
4+
import unittest
5+
from typing import cast, Optional, Tuple
6+
7+
import executorch.backends.cadence.aot.ops_registrations # noqa
8+
import torch
9+
from executorch.backends.cadence.aot.compiler import export_to_edge
10+
from executorch.backends.cadence.aot.pass_utils import count_node
11+
from executorch.backends.cadence.aot.simplify_ops import SimplifySliceOpPass
12+
from executorch.exir.dialects._ops import ops as exir_ops
13+
from parameterized.parameterized import parameterized
14+
from torch.fx.passes.infra.pass_base import PassResult
15+
16+
17+
class TestSimplifyOpsPasses(unittest.TestCase):
18+
@parameterized.expand(
19+
[
20+
[(3, 16, 5), (3, 0, 5), 1, 15, 3, 3],
21+
]
22+
)
23+
@torch.no_grad()
24+
def test_simplify_slice_scatter_op(
25+
self,
26+
in_shape: Tuple[int],
27+
src_shape: Tuple[int],
28+
dim: int,
29+
start: Optional[int] = None,
30+
end: Optional[int] = None,
31+
step: int = 1,
32+
):
33+
class SliceScatter(torch.nn.Module):
34+
def __init__(
35+
self, dim: int, start: Optional[int], end: Optional[int], step: int
36+
):
37+
super().__init__()
38+
self.dim = dim
39+
self.start = start
40+
self.end = end
41+
self.step = step
42+
43+
def forward(self, x: torch.Tensor, y: torch.Tensor):
44+
return torch.slice_scatter(
45+
x, y, self.dim, self.start, self.end, self.step
46+
)
47+
48+
model = SliceScatter(dim, start, end, step)
49+
x = torch.randn(in_shape)
50+
y = torch.randn(src_shape)
51+
graph_module = export_to_edge(model, (x, y)).exported_program().graph_module
52+
53+
p = SimplifySliceOpPass()
54+
55+
graph_after_passes = cast(PassResult, p(graph_module)).graph_module
56+
57+
self.assertEqual(
58+
count_node(graph_after_passes, exir_ops.edge.aten.slice_scatter.default), 0
59+
)
60+
61+
@parameterized.expand(
62+
[
63+
[(3, 16, 5), (3, 0, 5), 1, 15, 3, 3],
64+
]
65+
)
66+
@torch.no_grad()
67+
def test_simplify_slice_op(
68+
self,
69+
in_shape: Tuple[int],
70+
src_shape: Tuple[int],
71+
dim: int,
72+
start: Optional[int] = None,
73+
end: Optional[int] = None,
74+
step: int = 1,
75+
):
76+
class SliceCopy(torch.nn.Module):
77+
def __init__(
78+
self, dim: int, start: Optional[int], end: Optional[int], step: int
79+
):
80+
super().__init__()
81+
self.dim = dim
82+
self.start = start
83+
self.end = end
84+
self.step = step
85+
86+
def forward(self, x: torch.Tensor) -> torch.Tensor:
87+
return torch.slice_copy(
88+
x, dim=self.dim, start=self.start, end=self.end, step=self.step
89+
)
90+
91+
# Create a model with single slice copy op.
92+
model = SliceCopy(dim, start, end, step)
93+
x = torch.randn(in_shape)
94+
graph_module = export_to_edge(model, (x,)).exported_program().graph_module
95+
self.assertEqual(
96+
count_node(graph_module, exir_ops.edge.aten.slice_copy.Tensor), 1
97+
)
98+
99+
p = SimplifySliceOpPass()
100+
101+
graph_after_passes = cast(PassResult, p(graph_module)).graph_module
102+
103+
self.assertEqual(
104+
count_node(graph_after_passes, exir_ops.edge.aten.slice_copy.Tensor), 0
105+
)
106+
self.assertEqual(
107+
count_node(graph_after_passes, exir_ops.edge.aten.full.default), 1
108+
)

0 commit comments

Comments
 (0)