From 19371e3a5de312315fa21ac197a93af8b1a44592 Mon Sep 17 00:00:00 2001 From: jhels <11036537+jhels@users.noreply.github.com> Date: Fri, 9 May 2025 20:48:23 +0200 Subject: [PATCH] Update MemoryPlanning Verifier to only assume model has user input if it has at least one tensor input MemoryPlanning verifier currently blows up if all the user inputs are prims. This change means its helper function _do_user_inputs_exist returns false if all its inputs are prims. It only returns true if at least one input is a tensor. --- exir/memory_planning.py | 31 ++++++++++++------ exir/tests/test_memory_planning.py | 51 ++++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+), 10 deletions(-) diff --git a/exir/memory_planning.py b/exir/memory_planning.py index 83598940882..030ade687a8 100644 --- a/exir/memory_planning.py +++ b/exir/memory_planning.py @@ -25,7 +25,11 @@ from executorch.exir.tensor import TensorSpec from torch import fx -from torch.export.exported_program import ExportGraphSignature, InputKind +from torch.export.exported_program import ( + ConstantArgument, + ExportGraphSignature, + InputKind, +) from torch.fx import Node from torch.utils._pytree import tree_flatten @@ -338,16 +342,23 @@ def _do_user_inputs_exist(graph_signature: Optional[ExportGraphSignature]) -> bo if graph_signature is None: return False - return ( - len( - list( - filter( - lambda input: input.kind == InputKind.USER_INPUT, - graph_signature.input_specs, - ) - ) + user_inputs = list( + filter( + lambda input: input.kind == InputKind.USER_INPUT, + graph_signature.input_specs, ) - ) > 0 + ) + + # Return false if: + # - there are no inputs. + # - if user inputs are all prims (as this currently + # causes the memory planning verifier to blow up). + # Otherwise, return true. + return any( + not isinstance(input.arg, ConstantArgument) + or not isinstance(input.arg.value, (int, float, bool, str)) + for input in user_inputs + ) def get_graph_input_tensors( diff --git a/exir/tests/test_memory_planning.py b/exir/tests/test_memory_planning.py index b87ae2dfb58..6b895f27922 100644 --- a/exir/tests/test_memory_planning.py +++ b/exir/tests/test_memory_planning.py @@ -16,6 +16,7 @@ from executorch.exir import ExecutorchBackendConfig, to_edge from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.memory_planning import ( + _do_user_inputs_exist, filter_nodes, get_node_tensor_specs, greedy, @@ -307,6 +308,56 @@ def wrapper(self: "TestMemoryPlanning") -> None: return wrapper +class TestMemoryPlanningUserInputs(unittest.TestCase): + """ + Ensure that MemoryPlanning Verifer only assumes a model + has a user input if it has at least one tensor input. + """ + + def test_tensor_only_inputs(self): + class TensorModel(torch.nn.Module): + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return x + y + + model = TensorModel() + inputs = (torch.randn(2), torch.randn(2)) + ep = export(model, inputs, strict=True) + result = _do_user_inputs_exist(graph_signature=ep.graph_signature) + self.assertTrue(result) + + def test_mixed_inputs(self): + class MixedModel(torch.nn.Module): + def forward(self, x: torch.Tensor, y: int) -> torch.Tensor: + return x * y + + model = MixedModel() + inputs = (torch.randn(2), 3) + ep = export(model, inputs, strict=True) + result = _do_user_inputs_exist(graph_signature=ep.graph_signature) + self.assertTrue(result) + + def test_primitive_only_inputs(self): + class PrimModel(torch.nn.Module): + def forward(self, x: int, y: float) -> float: + return x * y + + model = PrimModel() + inputs = (2, 3.0) + ep = export(model, inputs, strict=True) + result = _do_user_inputs_exist(graph_signature=ep.graph_signature) + self.assertFalse(result) + + def test_no_inputs(self): + class NoInputModel(torch.nn.Module): + def forward(self) -> torch.Tensor: + return torch.tensor(1.0) + + model = NoInputModel() + ep = export(model, (), strict=True) + result = _do_user_inputs_exist(graph_signature=ep.graph_signature) + self.assertFalse(result) + + class TestMemoryPlanning(unittest.TestCase): def verify_reuse( self,