|
20 | 20 | import torch |
21 | 21 | from executorch.backends.arm.test.common import parametrize |
22 | 22 | from executorch.backends.arm.test.tester.test_pipeline import ( |
| 23 | + EthosU55PipelineBI, |
23 | 24 | TosaPipelineBI, |
24 | 25 | TosaPipelineMI, |
25 | 26 | ) |
26 | 27 |
|
27 | 28 | example_input = torch.rand(1, 6, 16, 16) |
28 | 29 |
|
| 30 | + |
| 31 | +class SimpleWeightReuseModel(torch.nn.Module): |
| 32 | + def __init__(self): |
| 33 | + super().__init__() |
| 34 | + self.conv1 = torch.nn.Conv2d(1, 16, kernel_size=3, padding=1) |
| 35 | + self.relu = torch.nn.ReLU() |
| 36 | + self.conv2 = torch.nn.Conv2d(16, 1, kernel_size=3, padding=1) |
| 37 | + self.layer_scale = torch.nn.Parameter(torch.ones(1, 1, 1) * 1e-2) |
| 38 | + |
| 39 | + def _apply_inner(self, input_tensor: torch.Tensor) -> torch.Tensor: |
| 40 | + out = self.conv1(input_tensor) |
| 41 | + out = self.relu(out) |
| 42 | + out = self.conv2(out) |
| 43 | + out = self.layer_scale * out |
| 44 | + return out |
| 45 | + |
| 46 | + def forward( |
| 47 | + self, a: torch.Tensor, b: torch.Tensor |
| 48 | + ) -> tuple[torch.Tensor, torch.Tensor]: |
| 49 | + x = self._apply_inner(a) |
| 50 | + y = self._apply_inner(x + b) |
| 51 | + return (x, y) |
| 52 | + |
| 53 | + |
29 | 54 | module_tests = [ |
30 | 55 | (torch.nn.Embedding(10, 10), (torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]]),)), |
31 | 56 | (torch.nn.LeakyReLU(), (example_input,)), |
|
46 | 71 | ), |
47 | 72 | (torch.rand((10, 32, 64)), torch.rand((20, 32, 64))), |
48 | 73 | ), |
| 74 | + # Temporary! |
| 75 | + (SimpleWeightReuseModel(), (torch.rand(1, 1, 20, 48), torch.rand(1, 1, 20, 48))), |
49 | 76 | ] |
50 | 77 |
|
51 | 78 | input_t = tuple[torch.Tensor] |
@@ -100,3 +127,25 @@ def test_nn_Modules_BI(test_data): |
100 | 127 | not in str(e) |
101 | 128 | ): |
102 | 129 | raise e |
| 130 | + |
| 131 | + |
| 132 | +@parametrize( |
| 133 | + "test_data", |
| 134 | + {"SimpleWeightReuseModel": test_parameters["SimpleWeightReuseModel"]}, |
| 135 | + xfails={ |
| 136 | + "SimpleWeightReuseModel": "RuntimeError: Non-passthrough operation could not run on NPU.", |
| 137 | + }, |
| 138 | +) |
| 139 | +def test_nn_Modules_U55BI(test_data): |
| 140 | + module, inputs = test_data |
| 141 | + pipeline = EthosU55PipelineBI[input_t]( |
| 142 | + module, inputs, "", use_to_edge_transform_and_lower=True |
| 143 | + ) |
| 144 | + pipeline.pop_stage("check.aten") |
| 145 | + pipeline.pop_stage("check_count.exir") |
| 146 | + pipeline.pop_stage("check.quant_nodes") |
| 147 | + pipeline.pop_stage("check_not.quant_nodes") |
| 148 | + try: |
| 149 | + pipeline.run() |
| 150 | + except RuntimeError as e: |
| 151 | + raise e |
0 commit comments