Skip to content

Commit 375e4fe

Browse files
authored
Update MemoryPlanning Verifier to only assume model has user input if it has at least one tensor input (#10617)
Fixes #10522 ### Summary MemoryPlanning verifier currently blows up if all the user inputs are prims. #10522 suggested: > We need to improve the logic of the _do_user_inputs_exist to probably just return false if all the inputs are prims. This PR implements this suggestion, with accompanying unit tests. `_do_user_inputs_exist` now returns True if it has at least one tensor input, and False otherwise. ### Test plan Added unit tests to `test_memory_planning.py` and ran them with ```bash pytest exir/tests/test_memory_planning.py::test_memory_planning.py ``` Please note you must comment out the line 69 from `/pytest.ini` for this to work: ``` --ignore=exir/tests/test_memory_planning.py ``` On my machine, I also had to comment out line 60 from `test_memory_planning.py` for the test to run without errors. ``` torch.ops.load_library("//executorch/kernels/portable:custom_ops_generated_lib") ``` [The tests I wrote aren't dependent on this library.] Co-authored-by: jhels <[email protected]>
1 parent 2e56274 commit 375e4fe

File tree

2 files changed

+72
-10
lines changed

2 files changed

+72
-10
lines changed

exir/memory_planning.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,11 @@
2525
from executorch.exir.tensor import TensorSpec
2626

2727
from torch import fx
28-
from torch.export.exported_program import ExportGraphSignature, InputKind
28+
from torch.export.exported_program import (
29+
ConstantArgument,
30+
ExportGraphSignature,
31+
InputKind,
32+
)
2933
from torch.fx import Node
3034
from torch.utils._pytree import tree_flatten
3135

@@ -338,16 +342,23 @@ def _do_user_inputs_exist(graph_signature: Optional[ExportGraphSignature]) -> bo
338342
if graph_signature is None:
339343
return False
340344

341-
return (
342-
len(
343-
list(
344-
filter(
345-
lambda input: input.kind == InputKind.USER_INPUT,
346-
graph_signature.input_specs,
347-
)
348-
)
345+
user_inputs = list(
346+
filter(
347+
lambda input: input.kind == InputKind.USER_INPUT,
348+
graph_signature.input_specs,
349349
)
350-
) > 0
350+
)
351+
352+
# Return false if:
353+
# - there are no inputs.
354+
# - if user inputs are all prims (as this currently
355+
# causes the memory planning verifier to blow up).
356+
# Otherwise, return true.
357+
return any(
358+
not isinstance(input.arg, ConstantArgument)
359+
or not isinstance(input.arg.value, (int, float, bool, str))
360+
for input in user_inputs
361+
)
351362

352363

353364
def get_graph_input_tensors(

exir/tests/test_memory_planning.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from executorch.exir import ExecutorchBackendConfig, to_edge
1717
from executorch.exir.dialects._ops import ops as exir_ops
1818
from executorch.exir.memory_planning import (
19+
_do_user_inputs_exist,
1920
filter_nodes,
2021
get_node_tensor_specs,
2122
greedy,
@@ -307,6 +308,56 @@ def wrapper(self: "TestMemoryPlanning") -> None:
307308
return wrapper
308309

309310

311+
class TestMemoryPlanningUserInputs(unittest.TestCase):
312+
"""
313+
Ensure that MemoryPlanning Verifer only assumes a model
314+
has a user input if it has at least one tensor input.
315+
"""
316+
317+
def test_tensor_only_inputs(self):
318+
class TensorModel(torch.nn.Module):
319+
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
320+
return x + y
321+
322+
model = TensorModel()
323+
inputs = (torch.randn(2), torch.randn(2))
324+
ep = export(model, inputs, strict=True)
325+
result = _do_user_inputs_exist(graph_signature=ep.graph_signature)
326+
self.assertTrue(result)
327+
328+
def test_mixed_inputs(self):
329+
class MixedModel(torch.nn.Module):
330+
def forward(self, x: torch.Tensor, y: int) -> torch.Tensor:
331+
return x * y
332+
333+
model = MixedModel()
334+
inputs = (torch.randn(2), 3)
335+
ep = export(model, inputs, strict=True)
336+
result = _do_user_inputs_exist(graph_signature=ep.graph_signature)
337+
self.assertTrue(result)
338+
339+
def test_primitive_only_inputs(self):
340+
class PrimModel(torch.nn.Module):
341+
def forward(self, x: int, y: float) -> float:
342+
return x * y
343+
344+
model = PrimModel()
345+
inputs = (2, 3.0)
346+
ep = export(model, inputs, strict=True)
347+
result = _do_user_inputs_exist(graph_signature=ep.graph_signature)
348+
self.assertFalse(result)
349+
350+
def test_no_inputs(self):
351+
class NoInputModel(torch.nn.Module):
352+
def forward(self) -> torch.Tensor:
353+
return torch.tensor(1.0)
354+
355+
model = NoInputModel()
356+
ep = export(model, (), strict=True)
357+
result = _do_user_inputs_exist(graph_signature=ep.graph_signature)
358+
self.assertFalse(result)
359+
360+
310361
class TestMemoryPlanning(unittest.TestCase):
311362
def verify_reuse(
312363
self,

0 commit comments

Comments
 (0)