
Commit 77d31ee

JacobSzwejbka authored and facebook-github-bot committed
Add option in memory planning to put shared state on same location across entry points (#14230)
Summary: Adds an API that lets you place the same state tensor at the same memory id and offset across entry points. This lets you get and set state more natively in the runtime when the underlying arenas are the same.

Reviewed By: GregoryComer

Differential Revision: D82250153
1 parent bd653ba commit 77d31ee
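
For orientation, here is a minimal usage sketch distilled from the test added in this commit (see exir/tests/test_memory_planning.py below); the exported-program variables (forward_ep, get_state_ep, set_state_ep) and the MemoryPlanningPass import path are illustrative assumptions, not part of this diff:

    from executorch.exir import ExecutorchBackendConfig, to_edge
    from executorch.exir.passes import MemoryPlanningPass

    # forward_ep, get_state_ep and set_state_ep are torch.export.ExportedProgram
    # objects for the three entry points of one stateful module.
    edge = to_edge({"forward": forward_ep, "get_state": get_state_ep, "set_state": set_state_ep})

    # share_mutable_buffers=True asks memory planning to give the shared state tensor
    # the same mem_id and mem_offset in every entry point, so the runtime can get and
    # set state in place when the underlying arenas are the same.
    et = edge.to_executorch(
        ExecutorchBackendConfig(
            memory_planning_pass=MemoryPlanningPass(share_mutable_buffers=True),
            emit_mutable_buffer_names=True,
        )
    )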

File tree

5 files changed: +210 / -6 lines changed


exir/emit/_emitter.py

Lines changed: 24 additions & 1 deletion
@@ -93,7 +93,8 @@
 from executorch.exir.types import LeafValueSpec, ValueSpec
 from torch._subclasses.fake_tensor import FakeTensor
 
-from torch.export.exported_program import ExportedProgram
+from torch.export.exported_program import ExportedProgram, ExportGraphSignature
+from torch.fx.node import Node
 from torch.utils import _pytree as pytree
 
 from typing_extensions import TypeAlias
@@ -1633,6 +1634,28 @@ def placeholder( # noqa: C901
         if isinstance(target, str) and isinstance(spec, TensorSpec):
             fqn, is_mutable_buffer = self._find_fqn_for_placeholder(target, spec)
 
+            def _is_buffer(node: Node, graph_signature: ExportGraphSignature) -> bool:
+                """
+                Check if the node is a buffer according to the provided graph signature.
+                """
+                if node.op == "placeholder":
+                    if isinstance(node.target, str):
+                        if node.target in graph_signature.inputs_to_buffers:
+                            return True
+                return False
+
+            # If the spec does not appear in the mutable section of the graph signature, it still might
+            # overall be considered a mutable buffer if it has already been memory planned. This would
+            # suggest that the same abstract buffer is mutable in another entry point, so we should
+            # compel it to be considered mutable in all entry points at emission, just as the user did
+            # with memory planning.
+            is_mutable_buffer |= (
+                _is_buffer(self.node, self.exported_program.graph_signature)
+                and spec.mem_id is not None
+                and spec.mem_offset is not None
+            )
+
             # If the placeholder has a constant_tag, it is external to the PTE file
             # and requires a fqn and location=TensorDataLocation.EXTERNAL
             if constant_tag is not None:
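
In short, the emitter now promotes a buffer to mutable whenever cross-entry-point planning has already assigned it a location, even if this particular entry point never writes it (think a get_state() method). A standalone sketch of that predicate (a hypothetical helper, not the ExecuTorch source):

    def treated_as_mutable(in_mutable_signature: bool, is_buffer: bool, mem_id, mem_offset) -> bool:
        # A buffer that already carries a planned mem_id/mem_offset was shared across
        # entry points by memory planning, so emission must treat it as mutable too.
        return in_mutable_signature or (is_buffer and mem_id is not None and mem_offset is not None)

    assert treated_as_mutable(False, True, 2, 0)             # planned shared state -> mutable
    assert not treated_as_mutable(False, True, None, None)   # ordinary read-only buffer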

exir/memory_planning.py

Lines changed: 3 additions & 0 deletions
@@ -245,6 +245,8 @@ def verify_graph_input_output(self) -> None:
             assert len(specs) > 0, "Expect tensor specs"
             specs = list(filter(lambda spec: not spec.const, specs))
             if len(specs) == 0:
+                # All outputs are const, so there is no memory to allocate; just record that we succeeded.
+                graph_output_allocated = self.alloc_graph_output
                 continue
             allocated = any(
                 spec is None or spec.mem_offset is not None for spec in specs
@@ -408,6 +410,7 @@ def collect_specs_from_nodes( # noqa: C901
     ignore_graph_input: bool = False,
     ignore_graph_output: bool = False,
     ignore_mutable_buffers: bool = False,
+    share_mutable_buffers: bool = False,
    ignore_const: bool = True,
     ignore_out_var_node: bool = True,
     dedup: bool = True,

exir/passes/memory_planning_pass.py

Lines changed: 132 additions & 3 deletions
@@ -4,10 +4,12 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import itertools
 import logging
 import warnings
+from dataclasses import dataclass, field
 from functools import partial
-from typing import Any, Callable, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple
 
 import torch
 from executorch.exir._warnings import deprecated
@@ -16,14 +18,20 @@
 from executorch.exir.memory_planning import (
     _is_out_var_node,
     apply_algo,
+    collect_specs_from_nodes,
+    filter_nodes,
+    get_node_tensor_specs,
     get_node_tensor_specs,
     MemoryPlanningAlgorithmSuite,
+    naive,
     Verifier,
 )
 from executorch.exir.operator.convert import get_out_args_from_opoverload
 from executorch.exir.pass_base import PassBase, PassResult
-from executorch.exir.tensor import ALIGNMENT
+from executorch.exir.tensor import ALIGNMENT, TensorSpec
+from torch import fx
 from torch.export.exported_program import ExportGraphSignature
+from torch.fx import Node
 
 
 # copied from https://stackoverflow.com/questions/75582932/python-how-can-i-print-the-function-name-of-a-partial-function
@@ -36,6 +44,84 @@ def _callable_name(any_callable: Callable[..., Any]) -> str:
     except AttributeError:
         return str(any_callable)
 
+def _is_buffer(node: Node, graph_signature: ExportGraphSignature) -> Tuple[bool, Optional[str]]:
+    """
+    Check if the node is a buffer according to the provided graph signature.
+    If it is one, return its fqn as well.
+    """
+    if node.op == "placeholder":
+        if isinstance(node.target, str):
+            if node.target in graph_signature.inputs_to_buffers:
+                fqn = graph_signature.inputs_to_buffers[node.target]
+                return True, fqn
+    return False, None
+
+def _is_mutable_buffer(
+    node: Node, graph_signature: ExportGraphSignature
+) -> Tuple[bool, Optional[str]]:
+    """
+    Check if the node is a mutable buffer according to the provided graph signature.
+    If it is one, return its fqn as well.
+    """
+    if node.op == "placeholder":
+        if isinstance(node.target, str):
+            if node.target in graph_signature.inputs_to_buffers:
+                fqn = graph_signature.inputs_to_buffers[node.target]
+                # If the buffer is mutated then record that.
+                if fqn in graph_signature.buffers_to_mutate.values():
+                    return True, fqn
+    return False, None
+
+def _get_spec_from_node(node: fx.Node) -> TensorSpec:
+    specs = get_node_tensor_specs(node)
+    assert len(specs) == 1
+    return specs[0]
+
+def _insert_mutable_buffer_specs(state: "_MemoryPlanningState", gm: torch.fx.GraphModule, gs: ExportGraphSignature):
+    for node in gm.graph.nodes:
+        is_mutable, fqn = _is_mutable_buffer(node, gs)
+        if is_mutable:
+            assert fqn
+            spec = _get_spec_from_node(node)
+            if getattr(spec, "mem_id", None) is not None or getattr(spec, "mem_offset", None) is not None:
+                raise ValueError("Cannot share mutable buffers if they already have a mem_id or mem_offset assigned")
+            if fqn not in state.mutable_buffers.keys():
+                state.mutable_buffers[fqn] = set()
+            state.mutable_buffers[fqn].add(spec)
+            continue
+        is_buffer, fqn = _is_buffer(node, gs)
+        # If it is not a mutable buffer it might just appear to be a buffer in this entry point
+        # (think model.get_state()), so cache it and later double check that this buffer never appears mutable.
+        if is_buffer:
+            assert fqn
+            spec = _get_spec_from_node(node)
+            if getattr(spec, "mem_id", None) is not None or getattr(spec, "mem_offset", None) is not None:
+                raise ValueError("Cannot share mutable buffers if they already have a mem_id or mem_offset assigned")
+            if fqn not in state.maybe_mutable_buffers.keys():
+                state.maybe_mutable_buffers[fqn] = set()
+            state.maybe_mutable_buffers[fqn].add(spec)
+
+def _check_default_mem_ids(gm: torch.fx.GraphModule):
+    for node in gm.graph.nodes:
+        for spec in collect_specs_from_nodes(
+            filter_nodes(itertools.chain([node], node.args, node.kwargs.values())),
+            None,
+            ignore_graph_input=False,
+            ignore_const=False,
+            ignore_out_var_node=False,
+            dedup=False,
+            do_assertion=False,
+            ignore_dynamic_unbound_tensor=False,
+        ):
+            mem_id = getattr(spec, "mem_id", None)
+            if mem_id is not None and mem_id != 1:
+                raise ValueError("Cannot share mutable buffers if all other tensors are not on the default mem_id of 1")
+
+@dataclass
+class _MemoryPlanningState:
+    mutable_buffers: Dict[str, Set[TensorSpec]] = field(default_factory=dict)
+    maybe_mutable_buffers: Dict[str, Set[TensorSpec]] = field(default_factory=dict)
+    graph_modules: List[torch.fx.GraphModule] = field(default_factory=list)
 
 class MemoryPlanningPass(PassBase):
     def __init__(
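
Roughly, these helpers feed _MemoryPlanningState as follows: a buffer that is written in any entry point lands in mutable_buffers, while one that is only read (think model.get_state()) is parked in maybe_mutable_buffers until some other entry point proves it mutable. A dict-based stand-in with invented names, for illustration only:

    mutable_buffers = {}        # fqn -> entry points where the buffer is written
    maybe_mutable_buffers = {}  # fqn -> entry points where it is only read

    observations = [
        ("forward", "state", True),     # graph signature lists "state" in buffers_to_mutate
        ("get_state", "state", False),  # "state" is just a plain input buffer here
    ]
    for entry_point, fqn, written in observations:
        bucket = mutable_buffers if written else maybe_mutable_buffers
        bucket.setdefault(fqn, set()).add(entry_point)

    # run_multimethod later merges the two buckets per fqn: one mutation anywhere makes
    # the buffer shared mutable state in every entry point.
    print(sorted(mutable_buffers["state"] | maybe_mutable_buffers["state"]))  # ['forward', 'get_state']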
@@ -45,6 +131,7 @@ def __init__(
         alloc_graph_input: bool = True,
         alloc_graph_output: bool = True,
         alloc_mutable_buffers: bool = True,
+        share_mutable_buffers: bool = False,
         alignment: int = ALIGNMENT,
     ) -> None:
         r"""
@@ -55,12 +142,18 @@
         """
         if memory_planning_algo is None:
             memory_planning_algo = MemoryPlanningAlgorithmSuite()
+        if share_mutable_buffers and not alloc_mutable_buffers:
+            raise ValueError(
+                "share_mutable_buffers is only meaningful when alloc_mutable_buffers is True"
+            )
         self.memory_planning_algo: Callable[..., List[int]] = memory_planning_algo
         self.allow_lifetime_and_storage_overlap = allow_lifetime_and_storage_overlap
         self.alloc_graph_input = alloc_graph_input
         self.alloc_graph_output = alloc_graph_output
         self.alloc_mutable_buffers = alloc_mutable_buffers
+        self.share_mutable_buffers = share_mutable_buffers
         self.alignment = alignment
+        self.state = _MemoryPlanningState()
 
     def _set_alloc_node_spec(self, graph_module: torch.fx.GraphModule) -> None:
         """
@@ -134,9 +227,17 @@ def run(
             graph_signature,
             self.alloc_graph_input,
             self.alloc_graph_output,
-            self.alloc_mutable_buffers,
+            # If we are sharing the mutable buffers, then do not allocate them in the
+            # memory planning algo; instead collect all of the specs over all of the
+            # entry points and allocate them directly in the run_multimethod call.
+            self.alloc_mutable_buffers and not self.share_mutable_buffers,
         )
 
+        if self.share_mutable_buffers and graph_signature is not None:
+            self.state.graph_modules.append(graph_module)
+            _check_default_mem_ids(graph_module)
+            _insert_mutable_buffer_specs(self.state, graph_module, graph_signature)
+
         # TODO: make the verifier do the work recursively to handle
         # control flow
         verifier = Verifier(
@@ -164,3 +265,31 @@ def run(
         # I dont know if that is a valid thing but if it is we should adjust verify_storage_reuse function
         verifier.verify_storage_reuse()
         return PassResult(graph_module, True)
+
+    def run_multimethod(self):
+        "Resolve any memory planning done across entry points"
+        if self.share_mutable_buffers:
+            arena: int = 0
+
+            # Every spec that shares an fqn is the same tensor! So we give it the same id and offset
+            # anywhere it appears.
+            for fqn, specs_set in self.state.mutable_buffers.items():
+                specs = list(specs_set)
+                # If the same buffer appears in mutable and maybe-mutable, then we know it is in fact mutable.
+                if fqn in self.state.maybe_mutable_buffers.keys():
+                    specs.extend(self.state.maybe_mutable_buffers[fqn])
+                for spec in specs:
+                    # Assume default memory planning placed all activations on mem_id 1; place shared state on 2.
+                    spec.mem_id = 2
+                    spec.realign(self.alignment)
+                    # State is persistent, so the memory never overlaps.
+                    spec.mem_offset = arena
+                # The specs should all be the same size since they are the same tensor, so bump by the first one.
+                arena += specs[0].allocated_memory
+
+            for graph_module in self.state.graph_modules:
+                if len(graph_module.meta["non_const_buffer_sizes"]) != 2:
+                    raise ValueError("Cannot share mutable state if not using default memory ids")
+                graph_module.meta["non_const_buffer_sizes"].append(arena)
+
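
To make the placement rule concrete, here is a small self-contained sketch of the same loop using plain dicts in place of TensorSpec; the buffer name, size, and 16-byte alignment are made up for illustration:

    ALIGNMENT = 16

    def align(n: int, a: int = ALIGNMENT) -> int:
        return (n + a - 1) // a * a

    # fqn -> one "spec" per entry point the buffer appears in, plus its byte size.
    shared = {"state": [{}, {}, {}]}  # e.g. forward / get_state / set_state
    sizes = {"state": 16}             # 2x2 float32

    arena = 0
    for fqn, specs in shared.items():
        for spec in specs:
            spec["mem_id"] = 2          # shared state goes on its own arena, id 2
            spec["mem_offset"] = arena  # identical offset everywhere the buffer appears
        arena += align(sizes[fqn])      # bump once per buffer, not once per entry point

    # Every entry point then reports `arena` as the size of the new shared arena.
    print(arena)  # -> 16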

exir/program/_program.py

Lines changed: 12 additions & 2 deletions
@@ -1745,11 +1745,11 @@ def to_executorch(
             memory_planning_pass = config.memory_planning_pass
             # TODO(jakeszwe): Follow up with compiler on if the deepcopy is necessary and if so how to make it work
             if hasattr(memory_planning_pass, "run"):
-                new_gm_res = memory_planning_pass.run(  # pyre-ignore[16]
+                new_gm_res = memory_planning_pass.run(
                     new_gm, new_signature
                 )
             else:
-                new_gm_res = memory_planning_pass(new_gm)  # pyre-ignore[29]
+                new_gm_res = memory_planning_pass(new_gm)
 
             # WARNING: DO NOT ADD ANY MORE PASSES AFTER MEMORY PLANNING PASS.
             # THERE ARE A LOT OF ASSUMPTIONS IN THE STACK THAT MEMORY PLANNING IS THE LAST PASS BEFORE THE EMITTER.
@@ -1758,6 +1758,16 @@
 
             _copy_module(program.graph_module, new_gm)
             execution_programs[name] = program
+        # After running memory planning on all entry points, we can run the cross-entry-point memory planning.
+        if isinstance(config.memory_planning_pass, dict):
+            for memory_planning_pass in config.memory_planning_pass.values():
+                if hasattr(memory_planning_pass, "run_multimethod"):
+                    memory_planning_pass.run_multimethod()
+        else:
+            memory_planning_pass = config.memory_planning_pass
+            if hasattr(memory_planning_pass, "run_multimethod"):
+                memory_planning_pass.run_multimethod()
+
 
         et_pm = ExecutorchProgramManager(
             execution_programs,
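
The dict branch above covers the case where ExecutorchBackendConfig.memory_planning_pass is a per-method mapping; either way, run_multimethod() only fires after every entry point has gone through run(). A hypothetical manual flow mirroring that ordering (planned_methods is an invented name, not from this diff):

    mpp = MemoryPlanningPass(share_mutable_buffers=True)
    for name, (gm, signature) in planned_methods.items():
        mpp.run(gm, signature)  # per-entry-point planning; shared buffers are deferred
    mpp.run_multimethod()       # cross-entry-point placement of the shared state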

exir/tests/test_memory_planning.py

Lines changed: 39 additions & 0 deletions
@@ -14,6 +14,7 @@
 
 import torch
 from executorch.exir import ExecutorchBackendConfig, to_edge
+from executorch.exir.capture._capture import patch_forward
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.memory_planning import (
     _do_user_inputs_exist,
@@ -33,6 +34,7 @@
     ToOutVarPass,
 )
 from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
+from executorch.exir.schema import KernelTypes
 from executorch.exir.tensor import TensorSpec
 from functorch.experimental.control_flow import map as torch_map
 from parameterized import parameterized
@@ -59,6 +61,7 @@
 from torch.fx import Graph, GraphModule, Node
 from torch.nn import functional as F
 from torch.utils import _pytree as pytree
+from torchao.quantization.pt2e.export_utils import WrapperModule
 
 torch.ops.load_library("//executorch/kernels/portable:custom_ops_generated_lib")
 
@@ -92,6 +95,23 @@ def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
     def get_random_inputs(self) -> Tuple[torch.Tensor, ...]:
         return (torch.randn(10), torch.randn(10))
 
+class MultiEntryPointStatefulModel(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.register_buffer("state", torch.zeros(2, 2))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.state.add_(x).view(-1) * 2
+
+    def set_state(self, state: torch.Tensor) -> None:
+        self.state.copy_(state)
+
+    def get_state(self) -> torch.Tensor:
+        return self.state
+
+    def get_example_inputs(self) -> Tuple[torch.Tensor, ...]:
+        return (torch.ones(1),)
+
 
 class ModelWithDifferentTensorSizes(torch.nn.Module):
     def __init__(self) -> None:
@@ -1081,3 +1101,22 @@ def test_multi_map(self) -> None:
                 verifier.storage_overlap(outer_spec, inner_spec),
                 f"Outer spec {outer_spec.shape=} {outer_spec.dtype=} {outer_spec.lifetime=} and inner spec {inner_spec} have storage overlap",
             )
+    def test_multi_state_plan(self) -> None:
+        eager_module = MultiEntryPointStatefulModel().eval()
+        forward = export(eager_module, eager_module.get_example_inputs())
+        with patch_forward(eager_module, eager_module.get_state):
+            get_state = export(eager_module, ())
+        with patch_forward(eager_module, eager_module.set_state):
+            set_state = export(eager_module, (torch.zeros(1),))
+        edge = to_edge({"forward": forward, "set_state": set_state, "get_state": get_state})
+        et = edge.to_executorch(ExecutorchBackendConfig(memory_planning_pass=MemoryPlanningPass(share_mutable_buffers=True), emit_mutable_buffer_names=True))
+        et_prog = et.executorch_program
+        count = 0
+        for plan in et_prog.execution_plan:
+            for value in plan.values:
+                if hasattr(value.val, "allocation_info") and value.val.allocation_info is not None and value.val.allocation_info.memory_id == 2:
+                    count += 1
+                    self.assertEqual(value.val.allocation_info.memory_offset_low, 0)
+                    self.assertTrue(value.val.extra_tensor_info is not None)
+                    self.assertEqual(value.val.extra_tensor_info.fully_qualified_name, "state")
+        self.assertEqual(count, 3)
