Skip to content

Commit 30b033c

Browse files
JacobSzwejbkafacebook-github-bot
authored andcommitted
Properly set mutable buffer lifespans (#12182)
Summary: Earlier iterations of mutable buffer memory planning relied on the insert copy_ pass to inject the placeholder node as the output. That is pretty hacky and doesn't compose well with the reinplacing pass. Fortunately we already have this pass so we can manually set the lifespan here to be infinite. Reviewed By: nitish2112, larryliu0820 Differential Revision: D77618047
1 parent 9905026 commit 30b033c

File tree

2 files changed

+54
-4
lines changed

2 files changed

+54
-4
lines changed

exir/memory_planning.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,11 @@ def _is_inplace_node(node: torch.fx.Node) -> bool:
301301

302302

303303
def update_tensor_lifetime(
304-
node: torch.fx.Node, spec: TensorSpec, node_idx: int
304+
node: torch.fx.Node,
305+
spec: TensorSpec,
306+
node_idx: int,
307+
max_node_idx: int,
308+
gs: Optional[ExportGraphSignature] = None,
305309
) -> None:
306310
r"""
307311
Update the lifetime of the tensor to cover node_idx. A tensor's lifetime
@@ -317,7 +321,12 @@ def update_tensor_lifetime(
317321
start = 0
318322
else:
319323
start = node_idx if start is None or start > node_idx else start
320-
end = node_idx if end is None or end < node_idx else end
324+
325+
if node.op == "placeholder" and _is_mutable_buffer(node, gs):
326+
# mutable buffers are never freed
327+
end = max_node_idx
328+
else:
329+
end = node_idx if end is None or end < node_idx else end
321330
spec.lifetime = [start, end]
322331

323332

@@ -497,7 +506,7 @@ def update_all_tensors_lifetime(
497506
Set the lifetime for all the tensors encountered in the Fx graph.
498507
"""
499508
specs = set()
500-
509+
max_node_idx = len(graph_module.graph.nodes) - 1
501510
for node_idx, node in enumerate(graph_module.graph.nodes):
502511
for spec in collect_specs_from_nodes(
503512
filter_nodes(itertools.chain([node], node.args, node.kwargs.values())),
@@ -509,7 +518,7 @@ def update_all_tensors_lifetime(
509518
do_assertion=False,
510519
ignore_dynamic_unbound_tensor=False,
511520
):
512-
update_tensor_lifetime(node, spec, node_idx)
521+
update_tensor_lifetime(node, spec, node_idx, max_node_idx, graph_signature)
513522
specs.add(spec)
514523
return specs
515524

exir/tests/test_memory_planning.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -664,6 +664,47 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
664664
.val.allocation_info.memory_offset_high,
665665
)
666666

667+
def test_mutable_buffers_infinite_lifespan(self) -> None:
668+
class Simple(torch.nn.Module):
669+
def __init__(self) -> None:
670+
super().__init__()
671+
self.register_buffer("state", torch.zeros(1))
672+
673+
def forward(self, x: torch.Tensor) -> torch.Tensor:
674+
self.state.index_put_(
675+
[
676+
torch.tensor([0]),
677+
],
678+
x,
679+
)
680+
y = x + self.state
681+
z = x * y
682+
return z
683+
684+
model = Simple()
685+
inputs = (torch.ones(1),)
686+
687+
et = to_edge(export(model, inputs, strict=True)).to_executorch(
688+
ExecutorchBackendConfig(
689+
emit_mutable_buffer_names=True, run_reinplace_pass=True
690+
)
691+
)
692+
693+
serialized_state = et.executorch_program.execution_plan[0].values[0].val
694+
self.assertEqual(
695+
serialized_state.extra_tensor_info.fully_qualified_name, "state"
696+
)
697+
memory_base = serialized_state.allocation_info.memory_offset_low
698+
memory_size = memory_base + 4 # 4 bytes for a single float
699+
for value in et.executorch_program.execution_plan[0].values[1:]:
700+
val = value.val
701+
if hasattr(val, "allocation_info") and val.allocation_info is not None:
702+
not_overlapping = (
703+
val.allocation_info.memory_offset_low < memory_base
704+
or val.allocation_info.memory_offset_low >= memory_size
705+
)
706+
self.assertTrue(not_overlapping)
707+
667708
def test_constants_not_memory_planned(self) -> None:
668709
class Simple(torch.nn.Module):
669710
def __init__(self) -> None:

0 commit comments

Comments
 (0)