|
16 | 16 | from executorch.exir import ExecutorchBackendConfig, to_edge
|
17 | 17 | from executorch.exir.dialects._ops import ops as exir_ops
|
18 | 18 | from executorch.exir.memory_planning import (
|
| 19 | + _do_user_inputs_exist, |
19 | 20 | filter_nodes,
|
20 | 21 | get_node_tensor_specs,
|
21 | 22 | greedy,
|
@@ -307,6 +308,56 @@ def wrapper(self: "TestMemoryPlanning") -> None:
|
307 | 308 | return wrapper
|
308 | 309 |
|
309 | 310 |
|
| 311 | +class TestMemoryPlanningUserInputs(unittest.TestCase): |
| 312 | + """ |
| 313 | + Ensure that MemoryPlanning Verifer only assumes a model |
| 314 | + has a user input if it has at least one tensor input. |
| 315 | + """ |
| 316 | + |
| 317 | + def test_tensor_only_inputs(self): |
| 318 | + class TensorModel(torch.nn.Module): |
| 319 | + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: |
| 320 | + return x + y |
| 321 | + |
| 322 | + model = TensorModel() |
| 323 | + inputs = (torch.randn(2), torch.randn(2)) |
| 324 | + ep = export(model, inputs, strict=True) |
| 325 | + result = _do_user_inputs_exist(graph_signature=ep.graph_signature) |
| 326 | + self.assertTrue(result) |
| 327 | + |
| 328 | + def test_mixed_inputs(self): |
| 329 | + class MixedModel(torch.nn.Module): |
| 330 | + def forward(self, x: torch.Tensor, y: int) -> torch.Tensor: |
| 331 | + return x * y |
| 332 | + |
| 333 | + model = MixedModel() |
| 334 | + inputs = (torch.randn(2), 3) |
| 335 | + ep = export(model, inputs, strict=True) |
| 336 | + result = _do_user_inputs_exist(graph_signature=ep.graph_signature) |
| 337 | + self.assertTrue(result) |
| 338 | + |
| 339 | + def test_primitive_only_inputs(self): |
| 340 | + class PrimModel(torch.nn.Module): |
| 341 | + def forward(self, x: int, y: float) -> float: |
| 342 | + return x * y |
| 343 | + |
| 344 | + model = PrimModel() |
| 345 | + inputs = (2, 3.0) |
| 346 | + ep = export(model, inputs, strict=True) |
| 347 | + result = _do_user_inputs_exist(graph_signature=ep.graph_signature) |
| 348 | + self.assertFalse(result) |
| 349 | + |
| 350 | + def test_no_inputs(self): |
| 351 | + class NoInputModel(torch.nn.Module): |
| 352 | + def forward(self) -> torch.Tensor: |
| 353 | + return torch.tensor(1.0) |
| 354 | + |
| 355 | + model = NoInputModel() |
| 356 | + ep = export(model, (), strict=True) |
| 357 | + result = _do_user_inputs_exist(graph_signature=ep.graph_signature) |
| 358 | + self.assertFalse(result) |
| 359 | + |
| 360 | + |
310 | 361 | class TestMemoryPlanning(unittest.TestCase):
|
311 | 362 | def verify_reuse(
|
312 | 363 | self,
|
|
0 commit comments