From 07e72509134d19fb51cb84e60df38964ea9b771f Mon Sep 17 00:00:00 2001
From: Jacob Szwejbka
Date: Mon, 27 Jan 2025 12:23:55 -0800
Subject: [PATCH] Mark call as deprecated (#7968)

Summary:
call() is deprecated since it can't handle mutation. This is a no-op for
people using the default memory planning today, but we want to call it out
loudly to people implementing their own memory planning: do not use call().

Reviewed By: hsharma35

Differential Revision: D68726718
---
 exir/passes/memory_planning_pass.py |  6 ++++++
 exir/tests/test_memory_planning.py  | 33 +++++++++++++++++++++++++++++
 2 files changed, 39 insertions(+)

diff --git a/exir/passes/memory_planning_pass.py b/exir/passes/memory_planning_pass.py
index 112b8f5fc52..83f17fa141b 100644
--- a/exir/passes/memory_planning_pass.py
+++ b/exir/passes/memory_planning_pass.py
@@ -9,6 +9,7 @@
 from typing import Callable, List, Optional

 import torch
+from executorch.exir._warnings import deprecated
 from executorch.exir.error import internal_assert
 from executorch.exir.memory import alloc
 from executorch.exir.memory_planning import (
@@ -83,6 +84,11 @@ def _set_alloc_node_spec(self, graph_module: torch.fx.GraphModule) -> None:
                 )
                 out_alloc_node.meta["spec"] = specs[i]

+    @deprecated(
+        "MemoryPlanningPass.call() is deprecated as it does not handle graphs \
+        with mutation, please use MemoryPlanningPass.run() instead",
+        category=FutureWarning,
+    )
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         return self.run(graph_module)

diff --git a/exir/tests/test_memory_planning.py b/exir/tests/test_memory_planning.py
index 1f94f0341f1..e09cd233c05 100644
--- a/exir/tests/test_memory_planning.py
+++ b/exir/tests/test_memory_planning.py
@@ -519,6 +519,39 @@ def test_multiple_pools(
             idx += 1
         self.assertEqual(graph_module.meta["non_const_buffer_sizes"], expected_bufsizes)

+    def test_mutation_not_double_allocated(self) -> None:
+        class Simple(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.register_buffer("constant", torch.ones(5, 5))
+
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                self.constant.add_(1)
+                return x - self.constant
+
+        model = Simple()
+        inputs = (torch.ones(5, 5),)
+
+        et = to_edge(export(model, inputs, strict=True)).to_executorch()
+
+        # 0 and 11 should refer to the same tensor. 0 is the input, 11 is the output of copy_
+        self.assertEqual(
+            et.executorch_program.execution_plan[0]
+            .values[0]
+            .val.allocation_info.memory_offset_low,
+            et.executorch_program.execution_plan[0]
+            .values[11]
+            .val.allocation_info.memory_offset_low,
+        )
+        self.assertEqual(
+            et.executorch_program.execution_plan[0]
+            .values[0]
+            .val.allocation_info.memory_offset_high,
+            et.executorch_program.execution_plan[0]
+            .values[11]
+            .val.allocation_info.memory_offset_high,
+        )
+
     def test_constants_not_memory_planned(self) -> None:
         class Simple(torch.nn.Module):
             def __init__(self) -> None:
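
For anyone invoking the pass entry point directly (users on the default
to_executorch() flow are unaffected), the migration is a one-line change.
A minimal sketch, assuming MemoryPlanningPass is constructed with its
defaults and that graph_module is a torch.fx.GraphModule produced by the
usual export pipeline:

    from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass

    mp_pass = MemoryPlanningPass()

    # Deprecated: emits a FutureWarning and does not handle graphs with mutation.
    # result = mp_pass.call(graph_module)

    # Preferred entry point per this patch; call() simply forwards here.
    result = mp_pass.run(graph_module)

The new test exercises the mutation case this deprecation guards against: a
module that mutates a registered buffer in forward() must have the buffer
input and the copy_ output planned to the same memory offsets rather than
being double-allocated.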