diff --git a/backends/arm/test/models/test_nn_modules.py b/backends/arm/test/models/test_nn_modules.py index 43fe1f4b3f9..1372c3aabf4 100644 --- a/backends/arm/test/models/test_nn_modules.py +++ b/backends/arm/test/models/test_nn_modules.py @@ -20,12 +20,37 @@ import torch from executorch.backends.arm.test.common import parametrize from executorch.backends.arm.test.tester.test_pipeline import ( + EthosU55PipelineBI, TosaPipelineBI, TosaPipelineMI, ) example_input = torch.rand(1, 6, 16, 16) + +class SimpleWeightReuseModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1 = torch.nn.Conv2d(1, 16, kernel_size=3, padding=1) + self.relu = torch.nn.ReLU() + self.conv2 = torch.nn.Conv2d(16, 1, kernel_size=3, padding=1) + self.layer_scale = torch.nn.Parameter(torch.ones(1, 1, 1) * 1e-2) + + def _apply_inner(self, input_tensor: torch.Tensor) -> torch.Tensor: + out = self.conv1(input_tensor) + out = self.relu(out) + out = self.conv2(out) + out = self.layer_scale * out + return out + + def forward( + self, a: torch.Tensor, b: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + x = self._apply_inner(a) + y = self._apply_inner(x + b) + return (x, y) + + module_tests = [ (torch.nn.Embedding(10, 10), (torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]]),)), (torch.nn.LeakyReLU(), (example_input,)), @@ -46,6 +71,8 @@ ), (torch.rand((10, 32, 64)), torch.rand((20, 32, 64))), ), + # Temporary! + (SimpleWeightReuseModel(), (torch.rand(1, 1, 20, 48), torch.rand(1, 1, 20, 48))), ] input_t = tuple[torch.Tensor] @@ -100,3 +127,25 @@ def test_nn_Modules_BI(test_data): not in str(e) ): raise e + + +@parametrize( + "test_data", + {"SimpleWeightReuseModel": test_parameters["SimpleWeightReuseModel"]}, + xfails={ + "SimpleWeightReuseModel": "RuntimeError: Non-passthrough operation could not run on NPU.", + }, +) +def test_nn_Modules_U55BI(test_data): + module, inputs = test_data + pipeline = EthosU55PipelineBI[input_t]( + module, inputs, "", use_to_edge_transform_and_lower=True + ) + pipeline.pop_stage("check.aten") + pipeline.pop_stage("check_count.exir") + pipeline.pop_stage("check.quant_nodes") + pipeline.pop_stage("check_not.quant_nodes") + try: + pipeline.run() + except RuntimeError as e: + raise e