
Commit 67d882d

JacobSzwejbka authored and facebook-github-bot committed
Add option in memory planning to put shared state on same location across entry points (pytorch#14230)
Summary: Adds an API that lets you place the same state tensor at the same memory id and offset across entry points. This lets you get and set state more natively in the runtime when the underlying arenas are the same.

Reviewed By: GregoryComer

Differential Revision: D82250153
1 parent e31cef6 commit 67d882d
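A minimal usage sketch, mirroring the test added in this commit (the `forward_ep`, `get_state_ep`, and `set_state_ep` names are illustrative placeholders for the exported programs of each entry point):

```python
# Sketch only: assumes forward_ep, get_state_ep, and set_state_ep are
# torch.export ExportedPrograms for three entry points of one stateful module.
from executorch.exir import ExecutorchBackendConfig, to_edge
from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass

edge = to_edge(
    {"forward": forward_ep, "get_state": get_state_ep, "set_state": set_state_ep}
)
et = edge.to_executorch(
    ExecutorchBackendConfig(
        # Plan every shared mutable buffer at the same mem_id/offset in all entry points.
        memory_planning_pass=MemoryPlanningPass(share_mutable_buffers=True),
        emit_mutable_buffer_names=True,
    )
)
```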

File tree

5 files changed: +243, -8 lines changed


exir/emit/_emitter.py

Lines changed: 24 additions & 1 deletion
@@ -93,7 +93,8 @@
 from executorch.exir.types import LeafValueSpec, ValueSpec
 from torch._subclasses.fake_tensor import FakeTensor
 
-from torch.export.exported_program import ExportedProgram
+from torch.export.exported_program import ExportedProgram, ExportGraphSignature
+from torch.fx.node import Node
 from torch.utils import _pytree as pytree
 
 from typing_extensions import TypeAlias
@@ -1633,6 +1634,28 @@ def placeholder(  # noqa: C901
         if isinstance(target, str) and isinstance(spec, TensorSpec):
             fqn, is_mutable_buffer = self._find_fqn_for_placeholder(target, spec)
 
+            def _is_buffer(node: Node, graph_signature: ExportGraphSignature) -> bool:
+                """
+                Check if the node is a buffer according to the provided graph signature.
+                """
+                if node.op == "placeholder":
+                    if isinstance(node.target, str):
+                        if node.target in graph_signature.inputs_to_buffers:
+                            return True
+                return False
+
+            # If the spec does not appear in the mutable section of the graph signature it still might
+            # overall be considered a mutable buffer if it has already been memory planned. This would
+            # suggest that the same abstract buffer is mutable in another entry point so we should
+            # compel it to be considered mutable in all entry points at emission just as the user did with
+            # memory planning.
+            is_mutable_buffer |= (
+                _is_buffer(self.node, self.exported_program.graph_signature)
+                and spec.mem_id is not None
+                and spec.mem_offset is not None
+            )
+
             # If the placeholder has a constant_tag, it is external to the PTE file
             # and requires a fqn and location=TensorDataLocation.EXTERNAL
             if constant_tag is not None:
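Condensed, the emitter-side rule added above is: a placeholder that the graph signature marks as a buffer, and that memory planning has already given a mem_id and mem_offset, is emitted as a mutable buffer even if this entry point never mutates it. A rough restatement of that predicate, with `node`, `spec`, and `graph_signature` standing in for the values used in the surrounding method:

```python
# Restatement of the check above for readability, not a separate API.
def _planned_buffer_is_mutable(node, spec, graph_signature) -> bool:
    is_buffer = (
        node.op == "placeholder"
        and isinstance(node.target, str)
        and node.target in graph_signature.inputs_to_buffers
    )
    # A mem_id/mem_offset on the spec means the cross-entry-point pass already
    # planned this buffer, so treat it as mutable in this entry point too.
    return is_buffer and spec.mem_id is not None and spec.mem_offset is not None
```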

exir/memory_planning.py

Lines changed: 3 additions & 0 deletions
@@ -245,6 +245,8 @@ def verify_graph_input_output(self) -> None:
             assert len(specs) > 0, "Expect tensor specs"
             specs = list(filter(lambda spec: not spec.const, specs))
             if len(specs) == 0:
+                # all outputs are const so no need to allocate memory, just say we succeeded
+                graph_output_allocated = self.alloc_graph_output
                 continue
             allocated = any(
                 spec is None or spec.mem_offset is not None for spec in specs
@@ -408,6 +410,7 @@ def collect_specs_from_nodes(  # noqa: C901
     ignore_graph_input: bool = False,
     ignore_graph_output: bool = False,
     ignore_mutable_buffers: bool = False,
+    share_mutable_buffers: bool = False,
     ignore_const: bool = True,
     ignore_out_var_node: bool = True,
     dedup: bool = True,

exir/passes/memory_planning_pass.py

Lines changed: 153 additions & 3 deletions
@@ -4,10 +4,12 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import itertools
 import logging
 import warnings
+from dataclasses import dataclass, field
 from functools import partial
-from typing import Any, Callable, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple
 
 import torch
 from executorch.exir._warnings import deprecated
@@ -16,14 +18,19 @@
 from executorch.exir.memory_planning import (
     _is_out_var_node,
     apply_algo,
+    collect_specs_from_nodes,
+    filter_nodes,
+    get_node_tensor_specs,
     get_node_tensor_specs,
     MemoryPlanningAlgorithmSuite,
     Verifier,
 )
 from executorch.exir.operator.convert import get_out_args_from_opoverload
 from executorch.exir.pass_base import PassBase, PassResult
-from executorch.exir.tensor import ALIGNMENT
+from executorch.exir.tensor import ALIGNMENT, TensorSpec
+from torch import fx
 from torch.export.exported_program import ExportGraphSignature
+from torch.fx import Node
 
 
 # copied from https://stackoverflow.com/questions/75582932/python-how-can-i-print-the-function-name-of-a-partial-function
@@ -37,6 +44,106 @@ def _callable_name(any_callable: Callable[..., Any]) -> str:
     return str(any_callable)
 
 
+def _is_buffer(
+    node: Node, graph_signature: ExportGraphSignature
+) -> Tuple[bool, Optional[str]]:
+    """
+    Check if the node is a buffer according to the provided graph signature.
+    If it is one, return its fqn as well.
+    """
+    if node.op == "placeholder":
+        if isinstance(node.target, str):
+            if node.target in graph_signature.inputs_to_buffers:
+                fqn = graph_signature.inputs_to_buffers[node.target]
+                return (True, fqn)
+    return (False, None)
+
+
+def _is_mutable_buffer(
+    node: Node, graph_signature: ExportGraphSignature
+) -> Tuple[bool, Optional[str]]:
+    """
+    Check if the node is a mutable buffer according to the provided graph signature.
+    If it is one, return its fqn as well.
+    """
+    if node.op == "placeholder":
+        if isinstance(node.target, str):
+            if node.target in graph_signature.inputs_to_buffers:
+                fqn = graph_signature.inputs_to_buffers[node.target]
+                # if the buffer is mutated then record that
+                if fqn in graph_signature.buffers_to_mutate.values():
+                    return True, fqn
+    return False, None
+
+
+def _get_spec_from_node(node: fx.Node) -> TensorSpec:
+    specs = get_node_tensor_specs(node)
+    return specs[0]
+
+
+def _insert_mutable_buffer_specs(
+    state: "_MemoryPlanningState", gm: torch.fx.GraphModule, gs: ExportGraphSignature
+):
+    for node in gm.graph.nodes:
+        is_mutable, fqn = _is_mutable_buffer(node, gs)
+        if is_mutable:
+            assert fqn
+            spec = _get_spec_from_node(node)
+            if (
+                getattr(spec, "mem_id", None) is not None
+                or getattr(spec, "mem_offset", None) is not None
+            ):
+                raise ValueError(
+                    "Cannot share mutable buffers if they already have a mem_id or mem_offset assigned"
+                )
+            if fqn not in state.mutable_buffers.keys():
+                state.mutable_buffers[fqn] = set()
+            state.mutable_buffers[fqn].add(spec)
+            continue
+        is_buffer, fqn = _is_buffer(node, gs)
+        # If it is not a mutable buffer it might just appear to be a buffer in this entry point. Think model.get_state()
+        # So cache it and later double check that this buffer never appears mutable
+        if is_buffer:
+            assert fqn
+            spec = _get_spec_from_node(node)
+            if (
+                getattr(spec, "mem_id", None) is not None
+                or getattr(spec, "mem_offset", None) is not None
+            ):
+                raise ValueError(
+                    "Cannot share mutable buffers if they already have a mem_id or mem_offset assigned"
+                )
+            if fqn not in state.maybe_mutable_buffers.keys():
+                state.maybe_mutable_buffers[fqn] = set()
+            state.maybe_mutable_buffers[fqn].add(spec)
+
+
+def _check_default_mem_ids(gm: torch.fx.GraphModule):
+    for node in gm.graph.nodes:
+        for spec in collect_specs_from_nodes(
+            filter_nodes(itertools.chain([node], node.args, node.kwargs.values())),
+            None,
+            ignore_graph_input=False,
+            ignore_const=False,
+            ignore_out_var_node=False,
+            dedup=False,
+            do_assertion=False,
+            ignore_dynamic_unbound_tensor=False,
+        ):
+            mem_id = getattr(spec, "mem_id", None)
+            if mem_id is not None and mem_id != 1:
+                raise ValueError(
+                    "Cannot share mutable buffers if all other tensors are not on the default mem_id of 1"
+                )
+
+
+@dataclass
+class _MemoryPlanningState:
+    mutable_buffers: Dict[str, Set[TensorSpec]] = field(default_factory=dict)
+    maybe_mutable_buffers: Dict[str, Set[TensorSpec]] = field(default_factory=dict)
+    graph_modules: List[torch.fx.GraphModule] = field(default_factory=list)
+
+
 class MemoryPlanningPass(PassBase):
     def __init__(
         self,
@@ -45,6 +152,7 @@ def __init__(
         alloc_graph_input: bool = True,
         alloc_graph_output: bool = True,
         alloc_mutable_buffers: bool = True,
+        share_mutable_buffers: bool = False,
         alignment: int = ALIGNMENT,
     ) -> None:
         r"""
@@ -55,12 +163,18 @@ def __init__(
         """
         if memory_planning_algo is None:
             memory_planning_algo = MemoryPlanningAlgorithmSuite()
+        if share_mutable_buffers and not alloc_mutable_buffers:
+            raise ValueError(
+                "share_mutable_buffers is only meaningful when alloc_mutable_buffers is True"
+            )
         self.memory_planning_algo: Callable[..., List[int]] = memory_planning_algo
         self.allow_lifetime_and_storage_overlap = allow_lifetime_and_storage_overlap
         self.alloc_graph_input = alloc_graph_input
         self.alloc_graph_output = alloc_graph_output
         self.alloc_mutable_buffers = alloc_mutable_buffers
+        self.share_mutable_buffers = share_mutable_buffers
         self.alignment = alignment
+        self.state = _MemoryPlanningState()
 
     def _set_alloc_node_spec(self, graph_module: torch.fx.GraphModule) -> None:
         """
@@ -134,9 +248,17 @@ def run(
             graph_signature,
             self.alloc_graph_input,
             self.alloc_graph_output,
-            self.alloc_mutable_buffers,
+            # If we are sharing the mutable buffers then do not allocate them in the
+            # memory planning algo, instead collect all of the specs over all the entry
+            # points and then allocate them directly in the run_multimethod call
+            self.alloc_mutable_buffers and not self.share_mutable_buffers,
         )
 
+        if self.share_mutable_buffers and graph_signature is not None:
+            self.state.graph_modules.append(graph_module)
+            _check_default_mem_ids(graph_module)
+            _insert_mutable_buffer_specs(self.state, graph_module, graph_signature)
+
         # TODO: make the verifier do the work recursively to handle
         # control flow
         verifier = Verifier(
@@ -164,3 +286,31 @@ def run(
         # I dont know if that is a valid thing but if it is we should adjust verify_storage_reuse function
         verifier.verify_storage_reuse()
         return PassResult(graph_module, True)
+
+    def run_multimethod(self):
+        "Resolve any memory planning done across entry points"
+        if self.share_mutable_buffers:
+            arena: int = 0
+
+            # Every spec that shares an fqn is the same tensor! So we give it the same id and offset
+            # anywhere it appears.
+            for fqn, specs_set in self.state.mutable_buffers.items():
+                specs = list(specs_set)
+                # If the same buffer appears in mutable and maybe mutable then we know it is in fact mutable.
+                if fqn in self.state.maybe_mutable_buffers.keys():
+                    specs.extend(self.state.maybe_mutable_buffers[fqn])
+                for spec in specs:
+                    # Assume a default memory planning placed all activations on 1, place shared state on 2.
+                    spec.mem_id = 2
+                    spec.realign(self.alignment)
+                    # State is persistent, so the memory never overlaps.
+                    spec.mem_offset = arena
+                # They should all be the same size since they are the same tensor, so just bump off the first.
+                arena += specs[0].allocated_memory
+
+            for graph_module in self.state.graph_modules:
+                if len(graph_module.meta["non_const_buffer_sizes"]) != 2:
+                    raise ValueError(
+                        "Cannot share mutable state if not using default memory ids"
+                    )
+                graph_module.meta["non_const_buffer_sizes"].append(arena)
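Putting the pieces together, the intended flow is roughly the following sketch (assuming a single MemoryPlanningPass instance is reused for every entry point, as `to_executorch` does when one pass is configured; the `gm_*`/`sig_*` names are placeholders):

```python
# Sketch of the cross-entry-point planning flow under the assumptions above.
mp_pass = MemoryPlanningPass(share_mutable_buffers=True)

for gm, sig in [
    (gm_forward, sig_forward),
    (gm_get_state, sig_get_state),
    (gm_set_state, sig_set_state),
]:
    # run() plans activations as usual but skips shared mutable buffers,
    # accumulating their specs per buffer fqn in the pass's internal state.
    mp_pass.run(gm, sig)

# Once every entry point is planned, give every spec of the same buffer fqn
# mem_id 2 and a common offset, and append the resulting arena size to each
# graph module's "non_const_buffer_sizes" metadata.
mp_pass.run_multimethod()
```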

exir/program/_program.py

Lines changed: 11 additions & 4 deletions
@@ -1745,11 +1745,9 @@ def to_executorch(
             memory_planning_pass = config.memory_planning_pass
             # TODO(jakeszwe): Follow up with compiler on if the deepcopy is necessary and if so how to make it work
             if hasattr(memory_planning_pass, "run"):
-                new_gm_res = memory_planning_pass.run(  # pyre-ignore[16]
-                    new_gm, new_signature
-                )
+                new_gm_res = memory_planning_pass.run(new_gm, new_signature)
             else:
-                new_gm_res = memory_planning_pass(new_gm)  # pyre-ignore[29]
+                new_gm_res = memory_planning_pass(new_gm)
 
             # WARNING: DO NOT ADD ANY MORE PASSES AFTER MEMORY PLANNING PASS.
             # THERE ARE A LOT OF ASSUMPTIONS IN THE STACK THAT MEMORY PLANNING IS THE LAST PASS BEFORE THE EMITTER.
@@ -1758,6 +1756,15 @@ def to_executorch(
 
             _copy_module(program.graph_module, new_gm)
             execution_programs[name] = program
+        # After running memory planning on all entry points we can run the cross entry point memory planning
+        if isinstance(config.memory_planning_pass, dict):
+            for memory_planning_pass in config.memory_planning_pass.values():
+                if hasattr(memory_planning_pass, "run_multimethod"):
+                    memory_planning_pass.run_multimethod()
+        else:
+            memory_planning_pass = config.memory_planning_pass
+            if hasattr(memory_planning_pass, "run_multimethod"):
+                memory_planning_pass.run_multimethod()
 
         et_pm = ExecutorchProgramManager(
             execution_programs,

exir/tests/test_memory_planning.py

Lines changed: 52 additions & 0 deletions
@@ -14,6 +14,7 @@
 
 import torch
 from executorch.exir import ExecutorchBackendConfig, to_edge
+from executorch.exir.capture._capture import patch_forward
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.memory_planning import (
     _do_user_inputs_exist,
@@ -93,6 +94,24 @@ def get_random_inputs(self) -> Tuple[torch.Tensor, ...]:
         return (torch.randn(10), torch.randn(10))
 
 
+class MultiEntryPointStatefulModel(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.register_buffer("state", torch.zeros(2, 2))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.state.add_(x).view(-1) * 2
+
+    def set_state(self, state: torch.Tensor) -> None:
+        self.state.copy_(state)
+
+    def get_state(self) -> torch.Tensor:
+        return self.state
+
+    def get_example_inputs(self) -> Tuple[torch.Tensor, ...]:
+        return (torch.ones(1),)
+
+
 class ModelWithDifferentTensorSizes(torch.nn.Module):
     def __init__(self) -> None:
         super(ModelWithDifferentTensorSizes, self).__init__()
@@ -1081,3 +1100,36 @@ def test_multi_map(self) -> None:
                 verifier.storage_overlap(outer_spec, inner_spec),
                 f"Outer spec {outer_spec.shape=} {outer_spec.dtype=} {outer_spec.lifetime=} and inner spec {inner_spec} have storage overlap",
             )
+
+    def test_multi_state_plan(self) -> None:
+        eager_module = MultiEntryPointStatefulModel().eval()
+        forward = export(eager_module, eager_module.get_example_inputs())
+        with patch_forward(eager_module, eager_module.get_state):
+            get_state = export(eager_module, ())
+        with patch_forward(eager_module, eager_module.set_state):
+            set_state = export(eager_module, (torch.zeros(1),))
+        edge = to_edge(
+            {"forward": forward, "set_state": set_state, "get_state": get_state}
+        )
+        et = edge.to_executorch(
+            ExecutorchBackendConfig(
+                memory_planning_pass=MemoryPlanningPass(share_mutable_buffers=True),
+                emit_mutable_buffer_names=True,
+            )
+        )
+        et_prog = et.executorch_program
+        count = 0
+        for plan in et_prog.execution_plan:
+            for value in plan.values:
+                if (
+                    hasattr(value.val, "allocation_info")
+                    and value.val.allocation_info is not None
+                    and value.val.allocation_info.memory_id == 2
+                ):
+                    count += 1
+                    self.assertEqual(value.val.allocation_info.memory_offset_low, 0)
+                    self.assertTrue(value.val.extra_tensor_info is not None)
+                    self.assertEqual(
+                        value.val.extra_tensor_info.fully_qualified_name, "state"
+                    )
+        self.assertEqual(count, 3)
