Commit 065f3ca

JacobSzwejbka authored and facebook-github-bot committed
Properly set mutable buffer lifespans (pytorch#12182)
Summary:
Earlier iterations of mutable buffer memory planning relied on the insert copy_ pass to inject the placeholder node as the output. That is pretty hacky and doesn't compose well with the reinplacing pass. Since lifetimes are already computed in the memory planning pass, we can manually set a mutable buffer's lifespan to be infinite there instead.

Reviewed By: nitish2112, larryliu0820

Differential Revision: D77618047
1 parent 59e0476 commit 065f3ca
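
In effect, the commit pins the end of a mutable buffer's lifetime to the last node index in the graph instead of the buffer's last observed use, so the planner never frees or reuses its storage. A minimal standalone sketch of that rule (the function and argument names below are illustrative, not the ExecuTorch API):

# Illustrative sketch only, not ExecuTorch code: the lifetime rule this commit
# adds to update_tensor_lifetime, expressed over plain node indices.
from typing import List


def compute_lifetime(
    use_indices: List[int], is_mutable_buffer_placeholder: bool, max_node_idx: int
) -> List[int]:
    """Return [start, end] covering every use of a tensor.

    A mutable-buffer placeholder is pinned to the last node index so its
    planned storage is never freed or reused before the program finishes.
    """
    start = min(use_indices)
    if is_mutable_buffer_placeholder:
        end = max_node_idx  # mutable buffers are never freed
    else:
        end = max(use_indices)
    return [start, end]


# A buffer read at node 1 and mutated at node 3 of a 6-node graph stays live
# through node 5; an ordinary activation with the same uses does not.
assert compute_lifetime([1, 3], True, max_node_idx=5) == [1, 5]
assert compute_lifetime([1, 3], False, max_node_idx=5) == [1, 3]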

File tree

2 files changed (+45, -5 lines)

exir/memory_planning.py

Lines changed: 18 additions & 5 deletions
@@ -299,14 +299,22 @@ def _is_inplace_node(node: torch.fx.Node) -> bool:
         )
     )
 
+def _is_mutable_buffer(node: torch.fx.Node, gs: Optional[ExportGraphSignature]) -> bool:
+    if gs is None:
+        return False
+    if node.target not in gs.inputs_to_buffers:
+        return False
+    buf = gs.inputs_to_buffers[node.target]
+    return buf in gs.buffers_to_mutate.values()
+
 
 def update_tensor_lifetime(
-    node: torch.fx.Node, spec: TensorSpec, node_idx: int
+    node: torch.fx.Node, spec: TensorSpec, node_idx: int, max_node_idx: int, gs: Optional[ExportGraphSignature] = None
 ) -> None:
     r"""
     Update the lifetime of the tensor to cover node_idx. A tensor's lifetime
     are represented by the index of the first and last node referring
-    that tensor in its inputs/outputs.
+    that tensor in its inputs/outputs.
 
     Arguments:
         spec: the TensorSpec for the tensor
@@ -317,7 +325,12 @@ def update_tensor_lifetime(
         start = 0
     else:
         start = node_idx if start is None or start > node_idx else start
-    end = node_idx if end is None or end < node_idx else end
+
+    if node.op == "placeholder" and _is_mutable_buffer(node, gs):
+        # mutable buffers are never freed
+        end = max_node_idx
+    else:
+        end = node_idx if end is None or end < node_idx else end
     spec.lifetime = [start, end]
 
 
@@ -497,7 +510,7 @@ def update_all_tensors_lifetime(
     Set the lifetime for all the tensors encountered in the Fx graph.
     """
     specs = set()
-
+    max_node_idx = len(graph_module.graph.nodes) - 1
     for node_idx, node in enumerate(graph_module.graph.nodes):
         for spec in collect_specs_from_nodes(
             filter_nodes(itertools.chain([node], node.args, node.kwargs.values())),
@@ -509,7 +522,7 @@ def update_all_tensors_lifetime(
             do_assertion=False,
             ignore_dynamic_unbound_tensor=False,
         ):
-            update_tensor_lifetime(node, spec, node_idx)
+            update_tensor_lifetime(node, spec, node_idx, max_node_idx, graph_signature)
             specs.add(spec)
     return specs
 
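
For context on the new helper: _is_mutable_buffer resolves a placeholder's target through the export graph signature in two steps, placeholder name -> buffer fully-qualified name, then membership in the set of mutated buffers. A standalone sketch of that lookup, using hypothetical stand-ins (FakeSignature, FakeNode) for ExportGraphSignature and torch.fx.Node, with illustrative node names:

# Standalone illustration of the two lookups _is_mutable_buffer performs.
# FakeSignature/FakeNode are hypothetical stand-ins, modeling only the two
# mappings the helper reads from ExportGraphSignature.
from dataclasses import dataclass, field
from typing import Dict, Optional


@dataclass
class FakeSignature:
    # placeholder node name -> buffer fully-qualified name
    inputs_to_buffers: Dict[str, str] = field(default_factory=dict)
    # output node name -> fully-qualified name of the buffer it mutates
    buffers_to_mutate: Dict[str, str] = field(default_factory=dict)


@dataclass
class FakeNode:
    target: str
    op: str = "placeholder"


def is_mutable_buffer(node: FakeNode, gs: Optional[FakeSignature]) -> bool:
    if gs is None:
        return False
    if node.target not in gs.inputs_to_buffers:
        return False
    buf = gs.inputs_to_buffers[node.target]
    return buf in gs.buffers_to_mutate.values()


# Illustrative names: the "state" buffer appears both as an input placeholder
# and as a mutated output, so its placeholder is treated as a mutable buffer.
gs = FakeSignature(
    inputs_to_buffers={"b_state": "state"},
    buffers_to_mutate={"index_put_out": "state"},
)
assert is_mutable_buffer(FakeNode("b_state"), gs)
assert not is_mutable_buffer(FakeNode("x"), gs)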

exir/tests/test_memory_planning.py

Lines changed: 27 additions & 0 deletions
@@ -664,6 +664,33 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             .val.allocation_info.memory_offset_high,
         )
 
+    def test_mutable_buffers_infinite_lifespan(self) -> None:
+        class Simple(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.register_buffer("state", torch.zeros(1))
+
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                self.state.index_put_([torch.tensor([0]),], x)
+                y = x + self.state
+                z = x * y
+                return z
+
+        model = Simple()
+        inputs = (torch.ones(1),)
+
+        et = to_edge(export(model, inputs, strict=True)).to_executorch(ExecutorchBackendConfig(emit_mutable_buffer_names=True, run_reinplace_pass=True))
+
+        serialized_state = et.executorch_program.execution_plan[0].values[0].val
+        self.assertEqual(serialized_state.extra_tensor_info.fully_qualified_name, "state")
+        memory_base = serialized_state.allocation_info.memory_offset_low
+        memory_size = memory_base + 4  # 4 bytes for a single float
+        for value in et.executorch_program.execution_plan[0].values[1:]:
+            val = value.val
+            if hasattr(val, "allocation_info") and val.allocation_info is not None:
+                not_overlapping = val.allocation_info.memory_offset_low < memory_base or val.allocation_info.memory_offset_low >= memory_size
+                self.assertTrue(not_overlapping)
+
     def test_constants_not_memory_planned(self) -> None:
         class Simple(torch.nn.Module):
             def __init__(self) -> None:
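
Outside of the test harness, the behaviour is exercised the same way the new test does: export a module that mutates a registered buffer and lower it with the reinplacing pass enabled. A minimal sketch; the import paths are the usual ExecuTorch ones and are assumed here rather than taken from this diff:

# Minimal end-to-end sketch mirroring the new test. Import paths are assumed.
import torch
from executorch.exir import ExecutorchBackendConfig, to_edge
from torch.export import export


class Stateful(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.register_buffer("state", torch.zeros(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.state.index_put_([torch.tensor([0])], x)  # mutate the buffer
        return x * (x + self.state)


et = to_edge(export(Stateful(), (torch.ones(1),), strict=True)).to_executorch(
    ExecutorchBackendConfig(emit_mutable_buffer_names=True, run_reinplace_pass=True)
)
# With this commit, the planned "state" buffer keeps its memory for the whole
# program, so no other planned tensor is allowed to overlap its allocation.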
