
Commit 3089788

JacobSzwejbka authored and facebook-github-bot committed
Add option in memory planning to put shared state on same location across entry points (pytorch#14230)
Summary: Adds an API that lets you place the same state tensor at the same memory id and offset across entry points. This lets you get and set state more natively in the runtime when the underlying arenas are the same.

Reviewed By: GregoryComer

Differential Revision: D82250153
1 parent e31cef6 commit 3089788

File tree (5 files changed: +242 −8 lines)

  exir/emit/_emitter.py (+24 −1)
  exir/memory_planning.py (+3 −0)
  exir/passes/memory_planning_pass.py (+152 −3)
  exir/program/_program.py (+11 −4)
  exir/tests/test_memory_planning.py (+52 −0)
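
A minimal usage sketch of the new share_mutable_buffers option, pieced together from the test added in this commit. Import paths follow the files touched here, export is assumed to be torch.export.export, and the module mirrors the test's MultiEntryPointStatefulModel below; treat it as an illustration, not official API documentation.

import torch
from executorch.exir import ExecutorchBackendConfig, to_edge
from executorch.exir.capture._capture import patch_forward
from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass
from torch.export import export


class StatefulModel(torch.nn.Module):
    """Mirrors MultiEntryPointStatefulModel from the test diff below."""

    def __init__(self) -> None:
        super().__init__()
        self.register_buffer("state", torch.zeros(2, 2))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.state.add_(x).view(-1) * 2

    def get_state(self) -> torch.Tensor:
        return self.state

    def set_state(self, state: torch.Tensor) -> None:
        self.state.copy_(state)


m = StatefulModel().eval()
forward = export(m, (torch.ones(1),))
with patch_forward(m, m.get_state):
    get_state = export(m, ())
with patch_forward(m, m.set_state):
    set_state = export(m, (torch.zeros(1),))

edge = to_edge({"forward": forward, "get_state": get_state, "set_state": set_state})

# share_mutable_buffers=True pins the shared "state" buffer to the same memory
# id and offset in every entry point's plan, so the runtime can get and set the
# state in place when the methods share the same arenas.
et = edge.to_executorch(
    ExecutorchBackendConfig(
        memory_planning_pass=MemoryPlanningPass(share_mutable_buffers=True),
        emit_mutable_buffer_names=True,
    )
)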

exir/emit/_emitter.py

Lines changed: 24 additions & 1 deletion
@@ -93,7 +93,8 @@
 from executorch.exir.types import LeafValueSpec, ValueSpec
 from torch._subclasses.fake_tensor import FakeTensor

-from torch.export.exported_program import ExportedProgram
+from torch.export.exported_program import ExportedProgram, ExportGraphSignature
+from torch.fx.node import Node
 from torch.utils import _pytree as pytree

 from typing_extensions import TypeAlias

@@ -1633,6 +1634,28 @@ def placeholder( # noqa: C901
         if isinstance(target, str) and isinstance(spec, TensorSpec):
             fqn, is_mutable_buffer = self._find_fqn_for_placeholder(target, spec)

+            def _is_buffer(node: Node, graph_signature: ExportGraphSignature) -> bool:
+                """
+                Check whether the node is a buffer according to the provided graph signature.
+                """
+                if node.op == "placeholder":
+                    if isinstance(node.target, str):
+                        if node.target in graph_signature.inputs_to_buffers:
+                            return True
+                return False
+
+            # If the spec does not appear in the mutable section of the graph signature, it may
+            # still be considered a mutable buffer if it has already been memory planned. That
+            # suggests the same abstract buffer is mutable in another entry point, so we treat it
+            # as mutable in all entry points at emission, just as the user did in memory planning.
+            is_mutable_buffer |= (
+                _is_buffer(self.node, self.exported_program.graph_signature)
+                and spec.mem_id is not None
+                and spec.mem_offset is not None
+            )
+
             # If the placeholder has a constant_tag, it is external to the PTE file
             # and requires a fqn and location=TensorDataLocation.EXTERNAL
             if constant_tag is not None:
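
Put differently, the hunk above promotes a buffer that looks read-only in this entry point to "mutable" whenever memory planning has already pinned it to an id and offset. A tiny stand-alone sketch of that rule, using a hypothetical spec stand-in rather than exir's real TensorSpec:

from dataclasses import dataclass
from typing import Optional


@dataclass
class SpecStandIn:
    """Hypothetical stand-in carrying only the two fields the rule inspects."""

    mem_id: Optional[int] = None
    mem_offset: Optional[int] = None


def promote_to_mutable(is_buffer: bool, is_mutable_buffer: bool, spec: SpecStandIn) -> bool:
    # Mirrors the `is_mutable_buffer |= ...` update: an already-planned buffer is
    # treated as mutable even if this entry point only reads it.
    return is_mutable_buffer or (
        is_buffer and spec.mem_id is not None and spec.mem_offset is not None
    )


# A get_state-style entry point: the buffer is read-only locally, but it was
# placed by the shared planner, so emission treats it as mutable state.
assert promote_to_mutable(True, False, SpecStandIn(mem_id=2, mem_offset=0))
assert not promote_to_mutable(True, False, SpecStandIn())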

exir/memory_planning.py

Lines changed: 3 additions & 0 deletions
@@ -245,6 +245,8 @@ def verify_graph_input_output(self) -> None:
             assert len(specs) > 0, "Expect tensor specs"
             specs = list(filter(lambda spec: not spec.const, specs))
             if len(specs) == 0:
+                # All outputs are const, so there is no need to allocate memory; just report success.
+                graph_output_allocated = self.alloc_graph_output
                 continue
             allocated = any(
                 spec is None or spec.mem_offset is not None for spec in specs

@@ -408,6 +410,7 @@ def collect_specs_from_nodes( # noqa: C901
     ignore_graph_input: bool = False,
     ignore_graph_output: bool = False,
     ignore_mutable_buffers: bool = False,
+    share_mutable_buffers: bool = False,
     ignore_const: bool = True,
     ignore_out_var_node: bool = True,
     dedup: bool = True,

exir/passes/memory_planning_pass.py

Lines changed: 152 additions & 3 deletions
@@ -4,10 +4,12 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+import itertools
 import logging
 import warnings
+from dataclasses import dataclass, field
 from functools import partial
-from typing import Any, Callable, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple

 import torch
 from executorch.exir._warnings import deprecated

@@ -16,14 +18,18 @@
 from executorch.exir.memory_planning import (
     _is_out_var_node,
     apply_algo,
+    collect_specs_from_nodes,
+    filter_nodes,
     get_node_tensor_specs,
     MemoryPlanningAlgorithmSuite,
     Verifier,
 )
 from executorch.exir.operator.convert import get_out_args_from_opoverload
 from executorch.exir.pass_base import PassBase, PassResult
-from executorch.exir.tensor import ALIGNMENT
+from executorch.exir.tensor import ALIGNMENT, TensorSpec
+from torch import fx
 from torch.export.exported_program import ExportGraphSignature
+from torch.fx import Node


 # copied from https://stackoverflow.com/questions/75582932/python-how-can-i-print-the-function-name-of-a-partial-function

@@ -37,6 +43,106 @@ def _callable_name(any_callable: Callable[..., Any]) -> str:
     return str(any_callable)


+def _is_buffer(
+    node: Node, graph_signature: ExportGraphSignature
+) -> Tuple[bool, Optional[str]]:
+    """
+    Check whether the node is a buffer according to the provided graph signature.
+    If it is, return its fqn as well.
+    """
+    if node.op == "placeholder":
+        if isinstance(node.target, str):
+            if node.target in graph_signature.inputs_to_buffers:
+                fqn = graph_signature.inputs_to_buffers[node.target]
+                return (True, fqn)
+    return (False, None)
+
+
+def _is_mutable_buffer(
+    node: Node, graph_signature: ExportGraphSignature
+) -> Tuple[bool, Optional[str]]:
+    """
+    Check whether the node is a mutable buffer according to the provided graph signature.
+    If it is, return its fqn as well.
+    """
+    if node.op == "placeholder":
+        if isinstance(node.target, str):
+            if node.target in graph_signature.inputs_to_buffers:
+                fqn = graph_signature.inputs_to_buffers[node.target]
+                # If the buffer is mutated then record that.
+                if fqn in graph_signature.buffers_to_mutate.values():
+                    return True, fqn
+    return False, None
+
+
+def _get_spec_from_node(node: fx.Node) -> TensorSpec:
+    specs = get_node_tensor_specs(node)
+    return specs[0]
+
+
+def _insert_mutable_buffer_specs(
+    state: "_MemoryPlanningState", gm: torch.fx.GraphModule, gs: ExportGraphSignature
+):
+    for node in gm.graph.nodes:
+        is_mutable, fqn = _is_mutable_buffer(node, gs)
+        if is_mutable:
+            assert fqn
+            spec = _get_spec_from_node(node)
+            if (
+                getattr(spec, "mem_id", None) is not None
+                or getattr(spec, "mem_offset", None) is not None
+            ):
+                raise ValueError(
+                    "Cannot share mutable buffers if they already have a mem_id or mem_offset assigned"
+                )
+            if fqn not in state.mutable_buffers.keys():
+                state.mutable_buffers[fqn] = set()
+            state.mutable_buffers[fqn].add(spec)
+            continue
+        is_buffer, fqn = _is_buffer(node, gs)
+        # If it is not a mutable buffer it might still appear to be a plain buffer in this
+        # entry point (think model.get_state()), so cache it and later double-check that
+        # this buffer never appears mutable elsewhere.
+        if is_buffer:
+            assert fqn
+            spec = _get_spec_from_node(node)
+            if (
+                getattr(spec, "mem_id", None) is not None
+                or getattr(spec, "mem_offset", None) is not None
+            ):
+                raise ValueError(
+                    "Cannot share mutable buffers if they already have a mem_id or mem_offset assigned"
+                )
+            if fqn not in state.maybe_mutable_buffers.keys():
+                state.maybe_mutable_buffers[fqn] = set()
+            state.maybe_mutable_buffers[fqn].add(spec)
+
+
+def _check_default_mem_ids(gm: torch.fx.GraphModule):
+    for node in gm.graph.nodes:
+        for spec in collect_specs_from_nodes(
+            filter_nodes(itertools.chain([node], node.args, node.kwargs.values())),
+            None,
+            ignore_graph_input=False,
+            ignore_const=False,
+            ignore_out_var_node=False,
+            dedup=False,
+            do_assertion=False,
+            ignore_dynamic_unbound_tensor=False,
+        ):
+            mem_id = getattr(spec, "mem_id", None)
+            if mem_id is not None and mem_id != 1:
+                raise ValueError(
+                    "Cannot share mutable buffers if all other tensors are not on the default mem_id of 1"
+                )
+
+
+@dataclass
+class _MemoryPlanningState:
+    mutable_buffers: Dict[str, Set[TensorSpec]] = field(default_factory=dict)
+    maybe_mutable_buffers: Dict[str, Set[TensorSpec]] = field(default_factory=dict)
+    graph_modules: List[torch.fx.GraphModule] = field(default_factory=list)
+
+
 class MemoryPlanningPass(PassBase):
     def __init__(
         self,

@@ -45,6 +151,7 @@ def __init__(
         alloc_graph_input: bool = True,
         alloc_graph_output: bool = True,
         alloc_mutable_buffers: bool = True,
+        share_mutable_buffers: bool = False,
         alignment: int = ALIGNMENT,
     ) -> None:
         r"""

@@ -55,12 +162,18 @@ def __init__(
         """
         if memory_planning_algo is None:
             memory_planning_algo = MemoryPlanningAlgorithmSuite()
+        if share_mutable_buffers and not alloc_mutable_buffers:
+            raise ValueError(
+                "share_mutable_buffers is only meaningful when alloc_mutable_buffers is True"
+            )
         self.memory_planning_algo: Callable[..., List[int]] = memory_planning_algo
         self.allow_lifetime_and_storage_overlap = allow_lifetime_and_storage_overlap
         self.alloc_graph_input = alloc_graph_input
         self.alloc_graph_output = alloc_graph_output
         self.alloc_mutable_buffers = alloc_mutable_buffers
+        self.share_mutable_buffers = share_mutable_buffers
         self.alignment = alignment
+        self.state = _MemoryPlanningState()

     def _set_alloc_node_spec(self, graph_module: torch.fx.GraphModule) -> None:
         """

@@ -134,9 +247,17 @@ def run(
             graph_signature,
             self.alloc_graph_input,
             self.alloc_graph_output,
-            self.alloc_mutable_buffers,
+            # If we are sharing the mutable buffers then do not allocate them in the
+            # memory planning algo; instead, collect all of the specs across all of the
+            # entry points and allocate them directly in the run_multimethod call.
+            self.alloc_mutable_buffers and not self.share_mutable_buffers,
         )

+        if self.share_mutable_buffers and graph_signature is not None:
+            self.state.graph_modules.append(graph_module)
+            _check_default_mem_ids(graph_module)
+            _insert_mutable_buffer_specs(self.state, graph_module, graph_signature)
+
         # TODO: make the verifier do the work recursively to handle
         # control flow
         verifier = Verifier(

@@ -164,3 +285,31 @@ def run(
         # I dont know if that is a valid thing but if it is we should adjust verify_storage_reuse function
         verifier.verify_storage_reuse()
         return PassResult(graph_module, True)
+
+    def run_multimethod(self):
+        "Resolve any memory planning done across entry points"
+        if self.share_mutable_buffers:
+            arena: int = 0
+
+            # Every spec that shares an fqn is the same tensor, so give it the same id and
+            # offset everywhere it appears.
+            for fqn, specs_set in self.state.mutable_buffers.items():
+                specs = list(specs_set)
+                # If the same buffer appears in both mutable and maybe-mutable then we know
+                # it is in fact mutable.
+                if fqn in self.state.maybe_mutable_buffers.keys():
+                    specs.extend(self.state.maybe_mutable_buffers[fqn])
+                for spec in specs:
+                    # Assume a default memory planning placed all activations on 1; place
+                    # shared state on 2.
+                    spec.mem_id = 2
+                    spec.realign(self.alignment)
+                    # State is persistent, so the memory never overlaps.
+                    spec.mem_offset = arena
+                # They should all be the same size since they are the same tensor, so just
+                # bump the arena by the first one.
+                arena += specs[0].allocated_memory
+
+            for graph_module in self.state.graph_modules:
+                if len(graph_module.meta["non_const_buffer_sizes"]) != 2:
+                    raise ValueError(
+                        "Cannot share mutable state if not using default memory ids"
+                    )
+                graph_module.meta["non_const_buffer_sizes"].append(arena)
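
To make the arena bookkeeping in run_multimethod concrete, here is a small self-contained sketch of the same accumulation. The buffer names and sizes are invented, and realign is approximated by rounding each size up to the alignment, which is assumed to match what TensorSpec.realign does to the planned slot.

# A hypothetical illustration of the shared-state arena layout produced by
# run_multimethod: every occurrence of a buffer (keyed by fqn) gets mem_id 2
# and the same offset, and offsets are stacked because state is persistent.

ALIGNMENT = 16  # illustrative; the real pass uses exir.tensor.ALIGNMENT


def align(size: int, alignment: int = ALIGNMENT) -> int:
    return (size + alignment - 1) // alignment * alignment


# fqn -> allocated bytes for that state tensor (same in every entry point)
shared_state_sizes = {"state_a": 16, "state_b": 24}

arena = 0
placements = {}
for fqn, size in shared_state_sizes.items():
    placements[fqn] = {"mem_id": 2, "mem_offset": arena}
    arena += align(size)

# placements == {"state_a": {"mem_id": 2, "mem_offset": 0},
#                "state_b": {"mem_id": 2, "mem_offset": 16}}
# `arena` (48 here) is what gets appended to each entry point's
# non_const_buffer_sizes.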

exir/program/_program.py

Lines changed: 11 additions & 4 deletions
@@ -1745,11 +1745,9 @@ def to_executorch(
             memory_planning_pass = config.memory_planning_pass
             # TODO(jakeszwe): Follow up with compiler on if the deepcopy is necessary and if so how to make it work
             if hasattr(memory_planning_pass, "run"):
-                new_gm_res = memory_planning_pass.run(  # pyre-ignore[16]
-                    new_gm, new_signature
-                )
+                new_gm_res = memory_planning_pass.run(new_gm, new_signature)
             else:
-                new_gm_res = memory_planning_pass(new_gm)  # pyre-ignore[29]
+                new_gm_res = memory_planning_pass(new_gm)

             # WARNING: DO NOT ADD ANY MORE PASSES AFTER MEMORY PLANNING PASS.
             # THERE ARE A LOT OF ASSUMPTIONS IN THE STACK THAT MEMORY PLANNING IS THE LAST PASS BEFORE THE EMITTER.

@@ -1758,6 +1756,15 @@ def to_executorch(

             _copy_module(program.graph_module, new_gm)
             execution_programs[name] = program
+        # After running memory planning on all entry points we can run the cross entry point memory planning
+        if isinstance(config.memory_planning_pass, dict):
+            for memory_planning_pass in config.memory_planning_pass.values():
+                if hasattr(memory_planning_pass, "run_multimethod"):
+                    memory_planning_pass.run_multimethod()
+        else:
+            memory_planning_pass = config.memory_planning_pass
+            if hasattr(memory_planning_pass, "run_multimethod"):
+                memory_planning_pass.run_multimethod()

         et_pm = ExecutorchProgramManager(
             execution_programs,
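
The dispatch added here is duck-typed: after every entry point has been planned, to_executorch calls run_multimethod() on whichever pass objects the config carries, whether that is a single pass or a per-method dict. A hypothetical stand-alone illustration of that contract (MyPlanningPass and finalize are invented for this sketch, not part of exir):

from typing import Dict, Union


class MyPlanningPass:
    """Hypothetical pass: per-method run() plus an optional cross-method hook."""

    def __init__(self) -> None:
        self.methods_planned = 0

    def run(self, graph_module, graph_signature=None) -> None:
        self.methods_planned += 1  # per-entry-point planning would go here

    def run_multimethod(self) -> None:
        print(f"resolving shared state across {self.methods_planned} entry points")


def finalize(memory_planning_pass: Union[MyPlanningPass, Dict[str, MyPlanningPass]]) -> None:
    # Mirrors the new block in to_executorch: support both a single pass and a
    # per-method dict, and only call the hook when the pass defines it.
    passes = (
        memory_planning_pass.values()
        if isinstance(memory_planning_pass, dict)
        else [memory_planning_pass]
    )
    for p in passes:
        if hasattr(p, "run_multimethod"):
            p.run_multimethod()


shared_pass = MyPlanningPass()
for _ in ("forward", "get_state", "set_state"):
    shared_pass.run(graph_module=None)
finalize(shared_pass)  # prints: resolving shared state across 3 entry points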

exir/tests/test_memory_planning.py

Lines changed: 52 additions & 0 deletions
@@ -14,6 +14,7 @@

 import torch
 from executorch.exir import ExecutorchBackendConfig, to_edge
+from executorch.exir.capture._capture import patch_forward
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.memory_planning import (
     _do_user_inputs_exist,

@@ -93,6 +94,24 @@ def get_random_inputs(self) -> Tuple[torch.Tensor, ...]:
         return (torch.randn(10), torch.randn(10))


+class MultiEntryPointStatefulModel(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.register_buffer("state", torch.zeros(2, 2))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.state.add_(x).view(-1) * 2
+
+    def set_state(self, state: torch.Tensor) -> None:
+        self.state.copy_(state)
+
+    def get_state(self) -> torch.Tensor:
+        return self.state
+
+    def get_example_inputs(self) -> Tuple[torch.Tensor, ...]:
+        return (torch.ones(1),)
+
+
 class ModelWithDifferentTensorSizes(torch.nn.Module):
     def __init__(self) -> None:
         super(ModelWithDifferentTensorSizes, self).__init__()

@@ -1081,3 +1100,36 @@ def test_multi_map(self) -> None:
                 verifier.storage_overlap(outer_spec, inner_spec),
                 f"Outer spec {outer_spec.shape=} {outer_spec.dtype=} {outer_spec.lifetime=} and inner spec {inner_spec} have storage overlap",
             )
+
+    def test_multi_state_plan(self) -> None:
+        eager_module = MultiEntryPointStatefulModel().eval()
+        forward = export(eager_module, eager_module.get_example_inputs())
+        with patch_forward(eager_module, eager_module.get_state):
+            get_state = export(eager_module, ())
+        with patch_forward(eager_module, eager_module.set_state):
+            set_state = export(eager_module, (torch.zeros(1),))
+        edge = to_edge(
+            {"forward": forward, "set_state": set_state, "get_state": get_state}
+        )
+        et = edge.to_executorch(
+            ExecutorchBackendConfig(
+                memory_planning_pass=MemoryPlanningPass(share_mutable_buffers=True),
+                emit_mutable_buffer_names=True,
+            )
+        )
+        et_prog = et.executorch_program
+        count = 0
+        for plan in et_prog.execution_plan:
+            for value in plan.values:
+                if (
+                    hasattr(value.val, "allocation_info")
+                    and value.val.allocation_info is not None
+                    and value.val.allocation_info.memory_id == 2
+                ):
+                    count += 1
+                    self.assertEqual(value.val.allocation_info.memory_offset_low, 0)
+                    self.assertTrue(value.val.extra_tensor_info is not None)
+                    self.assertEqual(
+                        value.val.extra_tensor_info.fully_qualified_name, "state"
+                    )
+        self.assertEqual(count, 3)
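
Beyond the offsets the test already checks, one could also assert that every entry point ends up with the same shared-state arena size. The snippet below is a follow-up sketch, not part of this commit; it assumes the arena appended in run_multimethod lands as the final entry of each serialized plan's non_const_buffer_sizes.

# Reusing `et_prog` from test_multi_state_plan above.
last_sizes = [plan.non_const_buffer_sizes[-1] for plan in et_prog.execution_plan]
# All three entry points (forward, get_state, set_state) should report the
# same arena size for the shared state buffer.
assert len(set(last_sizes)) == 1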
