Skip to content

Commit 9bb1868

Browse files
committed
Arm backend: Add a simple weight sharing test
1 parent da0c80a commit 9bb1868

File tree

1 file changed

+49
-0
lines changed

1 file changed

+49
-0
lines changed

backends/arm/test/models/test_nn_modules.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,37 @@
2020
import torch
2121
from executorch.backends.arm.test.common import parametrize
2222
from executorch.backends.arm.test.tester.test_pipeline import (
23+
EthosU55PipelineBI,
2324
TosaPipelineBI,
2425
TosaPipelineMI,
2526
)
2627

2728
example_input = torch.rand(1, 6, 16, 16)
2829

30+
31+
class SimpleWeightReuseModel(torch.nn.Module):
32+
def __init__(self):
33+
super().__init__()
34+
self.conv1 = torch.nn.Conv2d(1, 16, kernel_size=3, padding=1)
35+
self.relu = torch.nn.ReLU()
36+
self.conv2 = torch.nn.Conv2d(16, 1, kernel_size=3, padding=1)
37+
self.layer_scale = torch.nn.Parameter(torch.ones(1, 1, 1) * 1e-2)
38+
39+
def _apply_inner(self, input_tensor: torch.Tensor) -> torch.Tensor:
40+
out = self.conv1(input_tensor)
41+
out = self.relu(out)
42+
out = self.conv2(out)
43+
out = self.layer_scale * out
44+
return out
45+
46+
def forward(
47+
self, a: torch.Tensor, b: torch.Tensor
48+
) -> tuple[torch.Tensor, torch.Tensor]:
49+
x = self._apply_inner(a)
50+
y = self._apply_inner(x + b)
51+
return (x, y)
52+
53+
2954
module_tests = [
3055
(torch.nn.Embedding(10, 10), (torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]]),)),
3156
(torch.nn.LeakyReLU(), (example_input,)),
@@ -46,6 +71,8 @@
4671
),
4772
(torch.rand((10, 32, 64)), torch.rand((20, 32, 64))),
4873
),
74+
# Temporary!
75+
(SimpleWeightReuseModel(), (torch.rand(1, 1, 20, 48), torch.rand(1, 1, 20, 48))),
4976
]
5077

5178
input_t = tuple[torch.Tensor]
@@ -100,3 +127,25 @@ def test_nn_Modules_BI(test_data):
100127
not in str(e)
101128
):
102129
raise e
130+
131+
132+
@parametrize(
133+
"test_data",
134+
{"SimpleWeightReuseModel": test_parameters["SimpleWeightReuseModel"]},
135+
xfails={
136+
"SimpleWeightReuseModel": "RuntimeError: Non-passthrough operation could not run on NPU.",
137+
},
138+
)
139+
def test_nn_Modules_U55BI(test_data):
140+
module, inputs = test_data
141+
pipeline = EthosU55PipelineBI[input_t](
142+
module, inputs, "", use_to_edge_transform_and_lower=True
143+
)
144+
pipeline.pop_stage("check.aten")
145+
pipeline.pop_stage("check_count.exir")
146+
pipeline.pop_stage("check.quant_nodes")
147+
pipeline.pop_stage("check_not.quant_nodes")
148+
try:
149+
pipeline.run()
150+
except RuntimeError as e:
151+
raise e

0 commit comments

Comments
 (0)