|
16 | 16 | import executorch.exir.tests.models as models
|
17 | 17 | import torch
|
18 | 18 | from executorch.exir import CaptureConfig, EdgeCompileConfig, ExecutorchProgram
|
| 19 | +from executorch.exir.backend.backend_api import to_backend |
| 20 | +from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult |
19 | 21 | from executorch.exir.emit import emit_program # noqa
|
20 | 22 | from executorch.exir.error import InternalError
|
21 | 23 | from executorch.exir.passes.const_prop_pass import ConstPropPass
|
|
42 | 44 | _load_for_executorch_from_buffer,
|
43 | 45 | )
|
44 | 46 | from functorch.experimental import control_flow
|
| 47 | +from torch import nn |
45 | 48 |
|
46 | 49 |
|
47 | 50 | class TestEmit(unittest.TestCase):
|
@@ -1197,3 +1200,47 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
|
1197 | 1200 | idx,
|
1198 | 1201 | node.meta.get("debug_handle"),
|
1199 | 1202 | )
|
| 1203 | + |
| 1204 | + def test_delegate_with_input_list(self) -> None: |
| 1205 | + class BackendWithCompilerDemo(BackendDetails): |
| 1206 | + @staticmethod |
| 1207 | + def preprocess( |
| 1208 | + edge_program, |
| 1209 | + compile_specs, |
| 1210 | + ) -> bytes: |
| 1211 | + return PreprocessResult( |
| 1212 | + processed_bytes=bytes(str("test"), encoding="utf8"), |
| 1213 | + debug_handle_map=None, |
| 1214 | + ) |
| 1215 | + |
| 1216 | + class TestModel(nn.Module): |
| 1217 | + def __init__(self): |
| 1218 | + super(TestModel, self).__init__() |
| 1219 | + |
| 1220 | + def forward(self, x): |
| 1221 | + return torch.cat(x) |
| 1222 | + |
| 1223 | + inputs = ([torch.ones(2, 2), torch.ones(2, 2)],) |
| 1224 | + model = TestModel() |
| 1225 | + edgeir_m = exir.capture(model, inputs, exir.CaptureConfig()).to_edge( |
| 1226 | + exir.EdgeCompileConfig(_check_ir_validity=False) |
| 1227 | + ) |
| 1228 | + lowered_module = to_backend( |
| 1229 | + "BackendWithCompilerDemo", edgeir_m.exported_program, None |
| 1230 | + ) |
| 1231 | + |
| 1232 | + class CompositeModule(torch.nn.Module): |
| 1233 | + def __init__(self): |
| 1234 | + super().__init__() |
| 1235 | + self.lowered_module = lowered_module |
| 1236 | + |
| 1237 | + def forward(self, list_a): |
| 1238 | + return self.lowered_module(list_a) |
| 1239 | + |
| 1240 | + composite_model = CompositeModule() |
| 1241 | + exec_prog = ( |
| 1242 | + exir.capture(composite_model, inputs, exir.CaptureConfig()) |
| 1243 | + .to_edge() |
| 1244 | + .to_executorch() |
| 1245 | + ) |
| 1246 | + exec_prog.buffer |
0 commit comments