diff --git a/exir/memory_planning.py b/exir/memory_planning.py
index 1fc1f0e02fd..df01b4a6419 100644
--- a/exir/memory_planning.py
+++ b/exir/memory_planning.py
@@ -24,7 +24,7 @@
 from executorch.exir.tensor import TensorSpec
 
 from torch import fx
-from torch.export.exported_program import ExportGraphSignature
+from torch.export.exported_program import ExportGraphSignature, InputKind
 from torch.fx import Node
 from torch.utils._pytree import tree_flatten
 
@@ -247,7 +247,19 @@ def verify_graph_input_output(self) -> None:
                 graph_output_allocated = allocated
                 has_dynamic_unbound_output |= has_dynamic_unbound_tensor
 
-        if "placeholder" in check_list:
+        # only check if inputs are allocated if there are user inputs:
+        user_inputs_exist = (
+            len(
+                list(
+                    filter(
+                        lambda input: input.kind == InputKind.USER_INPUT,
+                        self.graph_signature.input_specs,
+                    )
+                )
+            )
+        ) > 0
+
+        if "placeholder" in check_list and user_inputs_exist:
             assert graph_input_allocated is not None, "graph_input_allocated not set"
             if not has_dynamic_unbound_input:
                 assert (
diff --git a/exir/program/_program.py b/exir/program/_program.py
index 5a9c101a06a..cac4eb4be0b 100644
--- a/exir/program/_program.py
+++ b/exir/program/_program.py
@@ -321,6 +321,8 @@ def lift_constant_tensor_pass(ep):
             new_input_specs.extend(lifted_constants)
             lifted_constants.clear()
         new_input_specs.append(s)
+    if len(lifted_constants) > 0:
+        new_input_specs = lifted_constants + new_input_specs
     ep.graph_signature.input_specs = new_input_specs
     ep.graph_module.recompile()
     return ep
diff --git a/exir/tests/test_passes.py b/exir/tests/test_passes.py
index 8e40c49e33f..39dbd3f51d3 100644
--- a/exir/tests/test_passes.py
+++ b/exir/tests/test_passes.py
@@ -1057,6 +1057,36 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             new_ep.graph_module.code
         )
 
+    def test_pass_no_user_inputs(self) -> None:
+        class NoUserInputs(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.register_buffer("a", torch.ones(1))
+
+            def forward(self) -> torch.Tensor:
+                return 3 + self.a
+
+        mod = NoUserInputs()
+        exported_program = export(mod, (), strict=True)
+        edge = to_edge(
+            exported_program,
+            compile_config=EdgeCompileConfig(_skip_dim_order=False),
+        )
+        ep = edge.exported_program()
+        # because there is no user input, the lifted constant should be the first input.
+        FileCheck().check("_lifted_tensor_constant1").check(
+            "b_a"  # followed by the buffer input.
+        ).run(ep.graph_module.code)
+
+        # the graph signature should also be the same:
+        self.assertEqual(
+            ep.graph_signature.input_specs[0].arg.name, "_lifted_tensor_constant1"
+        )
+        self.assertEqual(ep.graph_signature.input_specs[1].arg.name, "b_a")
+
+        # Validate that the program successfully passes validation to executorch:
+        edge.to_executorch()
+
     def test_constant_prop_pass_for_parameter(self) -> None:
         def count_additions(gm: torch.fx.GraphModule) -> int:
             return sum(
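
Reviewer note (not part of the patch): the new user_inputs_exist check in verify_graph_input_output spells "does the graph signature contain at least one user input" as len(list(filter(...))) > 0. Below is a minimal, equivalent sketch using any(); the helper name has_user_inputs is hypothetical, while ExportGraphSignature.input_specs and InputKind.USER_INPUT are the same torch.export APIs the patch already imports.

from torch.export.exported_program import ExportGraphSignature, InputKind

def has_user_inputs(graph_signature: ExportGraphSignature) -> bool:
    # True when at least one input spec is a user input, as opposed to a
    # parameter, buffer, or lifted constant placeholder.
    return any(
        spec.kind == InputKind.USER_INPUT for spec in graph_signature.input_specs
    )

Either spelling preserves the patch's behavior: memory planning only asserts on graph-input allocation when the program actually has user inputs, which is what allows the buffer-only NoUserInputs test case above to pass through to_executorch().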