Skip to content

Commit 9da6f8c

Browse files
hsharma35facebook-github-bot
authored andcommitted
Improve memory planning for submodule hierarchies. (pytorch#11860)
Summary: Improves the memory planning across hierarchies in apply_algo in memory_planning.py: 1. Plan memory bottom-to-top, starting with the leaf submodules and ending at top-level graph module (root). This is now consistent with how delegates are compiled / memory planned. Future PRs/diffs will add support for planned buffers in delegates. 2. Allocate max bufsize for all submodules as `graph_module.meta['input_mem_buffer_sizes']`, rather than sum. This allows us to reclaim the space used by one submodule for another submodule. Before this change the apply_algo in memory_planning.py would: 1. Plan memory top-to-bottom, starting with the top-level graph module (root). 2. Populate the `input_mem_buffer_sizes` so that each new submodule will allocate memory after the max buffer size of previous memory. For example: ``` root [A bytes] - root.child0 [B bytes] - root.child0.child0 [C bytes] - root.child1 [D bytes] ``` (before this diff) Planned memory looks like: ``` --- A + B + C + D ---------------- Space for root.child1 --- A + B + C -------------------- Space for root.child0.child0 --- A + B ------------------------ Space for root.child0 --- A ---------------------------- Space for root --- 0 ---------------------------- ``` Note that tensors for child0 and child1 have no overlap but still use completely different space. (after this diff) Planned memory looks like: ``` --- max(C + B, D) + A ---------- root --- max(C + B, D) -------------- root.child0 | --- C ------------ | root.child1 root.child0.child0 | --- 0 -------------------------- ``` Note: We can update memory planning algo to plan nodes with submodules (while/map/cond or even delegate) to use `graph_module.meta['non_const_buffer_size']` and reduce space even further. Implementation for this is not part of this PR/Diff. This will allow us to reuse space for `root.child0.child0` in `root.child0`, and space for `root.child0`/`root.child1` in `root. Differential Revision: D76940237
1 parent f072e64 commit 9da6f8c

File tree

2 files changed

+232
-60
lines changed

2 files changed

+232
-60
lines changed

exir/memory_planning.py

Lines changed: 134 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,20 @@
1010
import itertools
1111
import logging
1212
import operator
13-
import typing
1413
from collections import defaultdict
1514
from dataclasses import dataclass, field
16-
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
15+
from typing import (
16+
Any,
17+
Callable,
18+
cast,
19+
Dict,
20+
Iterable,
21+
List,
22+
Optional,
23+
Set,
24+
Tuple,
25+
Union,
26+
)
1727

1828
import torch
1929
from executorch.exir import memory
@@ -949,7 +959,7 @@ def _allocate_buf(bufsizes: List[int], mem_id: int, allocated: int) -> int:
949959
bufsizes = getattr(graph_module, "input_mem_buffer_sizes", None)
950960
if bufsizes is None:
951961
bufsizes = [0, 0]
952-
bufsizes = typing.cast(List[int], bufsizes)
962+
bufsizes = cast(List[int], bufsizes)
953963

954964
for spec in specs:
955965
spec_alloc_result = naive_result.spec_dict.get(spec, SpecAllocResult(0, 0, 0))
@@ -1051,92 +1061,158 @@ def insert_calls_to_free(
10511061
graph_module.recompile()
10521062

10531063

1064+
def _merge_bufsizes(bufsizes: list[int], new_bufsizes: list[int]) -> list[int]:
1065+
"""Combine two buffer size lists."""
1066+
if len(bufsizes) < len(new_bufsizes):
1067+
bufsizes.extend([0] * (len(new_bufsizes) - len(bufsizes)))
1068+
for i in range(len(new_bufsizes)):
1069+
bufsizes[i] = max(bufsizes[i], new_bufsizes[i])
1070+
return bufsizes
1071+
1072+
1073+
def _handle_submodule(
1074+
algo: Callable[..., list[int]],
1075+
parent_graph_module: torch.fx.GraphModule,
1076+
alignment: int,
1077+
submodule_node: torch.fx.Node,
1078+
graph_signature: Optional[ExportGraphSignature] = None,
1079+
alloc_graph_input: bool = False,
1080+
) -> list[int]:
1081+
"""Apply algo to nodes in a submodule of the graph module."""
1082+
assert submodule_node.op == "get_attr"
1083+
submodule = getattr(parent_graph_module, submodule_node.target)
1084+
1085+
logging.debug(f"Planning memory for submodule {submodule_node.name}...")
1086+
bufsizes = apply_algo(
1087+
algo,
1088+
submodule,
1089+
alignment,
1090+
graph_signature,
1091+
alloc_graph_input=alloc_graph_input,
1092+
alloc_graph_output=True,
1093+
)
1094+
submodule.meta.update({"non_const_buffer_sizes": bufsizes})
1095+
logging.debug(f"Buffer sizes for submodule {submodule_node.name}: {bufsizes}")
1096+
return bufsizes
1097+
1098+
1099+
def _apply_algo_to_submodules(
1100+
algo: Callable[..., list[int]],
1101+
graph_module: torch.fx.GraphModule,
1102+
alignment: int,
1103+
graph_signature: Optional[ExportGraphSignature] = None,
1104+
) -> list[int]:
1105+
"""Apply algo to map/cond/while nodes in the graph module.
1106+
1107+
This method will popuate graph_module.meta["non_const_buffer_sizes"] for
1108+
all submodules and return a bufsizes list that is the maximum size of all
1109+
buffers.
1110+
"""
1111+
1112+
# Bufsizes for submodules.
1113+
bufsizes: list[int] = []
1114+
1115+
def _handle(
1116+
submodule_node: torch.fx.Node,
1117+
alloc_graph_input: bool = False,
1118+
) -> None:
1119+
current_bufsizes = _handle_submodule(
1120+
algo,
1121+
graph_module,
1122+
alignment,
1123+
submodule_node,
1124+
graph_signature,
1125+
alloc_graph_input=alloc_graph_input,
1126+
)
1127+
nonlocal bufsizes
1128+
_merge_bufsizes(bufsizes, current_bufsizes)
1129+
1130+
for cond_node in get_cond_nodes(graph_module):
1131+
_handle(cast(torch.fx.Node, cond_node.args[1]))
1132+
_handle(cast(torch.fx.Node, cond_node.args[2]))
1133+
1134+
for while_node in get_while_nodes(graph_module):
1135+
_handle(cast(torch.fx.Node, while_node.args[0]))
1136+
_handle(cast(torch.fx.Node, while_node.args[1]))
1137+
1138+
for map_node in get_map_nodes(graph_module):
1139+
_handle(cast(torch.fx.Node, map_node.args[0]), alloc_graph_input=True)
1140+
1141+
# TODO: We can handle delegates the same way as map/cond/while.
1142+
# Maybe populate the graph_module.meta["non_const_buffer_sizes"] for delegates.
1143+
1144+
return bufsizes
1145+
1146+
10541147
def apply_algo(
1055-
algo: Callable[
1056-
...,
1057-
List[int],
1058-
],
1148+
algo: Callable[..., list[int]],
10591149
graph_module: torch.fx.GraphModule,
10601150
alignment: int,
10611151
graph_signature: Optional[ExportGraphSignature] = None,
10621152
alloc_graph_input: bool = True,
10631153
alloc_graph_output: bool = True,
10641154
alloc_mutable_buffers: bool = True,
1065-
) -> List[int]:
1155+
) -> list[int]:
10661156
"""
10671157
Recursively apply algo to graph_module and its submodules for control flow.
10681158
1069-
Quite naively right now since it does not take the following optimizations
1070-
into considerating:
1071-
1. for conditional structure, true branch and false true does not overlap
1072-
in lifetime and can share tensor storage
1073-
2. tensors inside a submodule (e.g. true branch) has opportunities to share
1074-
storage with tensors in the outer module.
1075-
TODO: make these optimizations once we have some baseline working.
1159+
Algo implementation should handle one of two meta entries for submodules:
1160+
1. input_mem_buffer_sizes: List of int offset bytes. Memory allocated by
1161+
`algo` should start at the offset specified by this list;
1162+
OR
1163+
2. non_const_buffer_sizes: List of bufsizes for planned memory in submodule.
1164+
`algo` should reserve the space specified by this list for the lifetime
1165+
of the submodule node (e.g. cond, while, map).
1166+
1167+
TODO: Missing optimizations:
1168+
1. To handle maps, we set `alloc_graph_input=True`, which allocates
1169+
appropriate space for mapped arg but ends up allocating extra space for
1170+
`operand` arg. The memory for operands is unused.
10761171
"""
10771172
# Extract the nodes and their lifespans from the graph_module
10781173
# Difficult to just filter the list of specs returned by this due to
10791174
# how we flag trainable weights.
10801175
_ = update_all_tensors_lifetime(graph_module, graph_signature)
1176+
10811177
# Filter specs based on alloc_graph_input and alloc_graph_output
1082-
specs = collect_specs_from_nodes(
1083-
graph_module.graph.nodes,
1084-
graph_signature,
1085-
do_assertion=False,
1086-
ignore_graph_input=not alloc_graph_input,
1087-
ignore_graph_output=not alloc_graph_output,
1088-
ignore_mutable_buffers=not alloc_mutable_buffers,
1178+
specs = set(
1179+
collect_specs_from_nodes(
1180+
graph_module.graph.nodes,
1181+
graph_signature,
1182+
do_assertion=False,
1183+
ignore_graph_input=not alloc_graph_input,
1184+
ignore_graph_output=not alloc_graph_output,
1185+
ignore_mutable_buffers=not alloc_mutable_buffers,
1186+
)
10891187
)
10901188

1189+
# Get temporary specs for submodules to set aside space during execution
1190+
# of submodules.
1191+
submodule_bufsizes = _apply_algo_to_submodules(
1192+
algo, graph_module, alignment, graph_signature
1193+
)
1194+
1195+
# Update `input_mem_buffer_sizes` in graph_module. This will allow existing
1196+
# algos to work using `input_mem_buffer_sizes` or use
1197+
# `non_const_buffer_sizes` directly.
1198+
# pyre-ignore[16]: `torch.fx.GraphModule` has no attribute `input_mem_buffer_sizes`.
1199+
graph_module.input_mem_buffer_sizes = submodule_bufsizes
1200+
10911201
# Get extra padding for XNNPACK if needed
10921202
extra_padding = 0
10931203
if _contains_xnnpack_delegate(graph_module):
10941204
extra_padding = 64
10951205

10961206
# Pass the filtered specs to the algorithm
1097-
bufsizes: List[int] = algo(
1207+
bufsizes: list[int] = algo(
10981208
alignment,
10991209
specs,
11001210
graph_module,
11011211
graph_signature,
11021212
extra_padding,
11031213
)
11041214

1105-
insert_calls_to_free(graph_module, set(specs))
1106-
1107-
def handle_submodule(
1108-
submodule_nd: torch.fx.Node, alloc_graph_input: bool = False
1109-
) -> None:
1110-
nonlocal bufsizes
1111-
assert submodule_nd.op == "get_attr"
1112-
submodule = getattr(graph_module, submodule_nd.target)
1113-
# memory planning for submodule need to be aware of the amount of
1114-
# buffer already allocated.
1115-
submodule.input_mem_buffer_sizes = bufsizes
1116-
1117-
bufsizes = apply_algo(
1118-
algo,
1119-
submodule,
1120-
alignment,
1121-
graph_signature,
1122-
alloc_graph_input=alloc_graph_input,
1123-
alloc_graph_output=True,
1124-
)
1125-
submodule.meta.update({"non_const_buffer_sizes": bufsizes})
1126-
1127-
for cond_node in get_cond_nodes(graph_module):
1128-
handle_submodule(typing.cast(torch.fx.Node, cond_node.args[1]))
1129-
handle_submodule(typing.cast(torch.fx.Node, cond_node.args[2]))
1130-
1131-
for while_node in get_while_nodes(graph_module):
1132-
handle_submodule(typing.cast(torch.fx.Node, while_node.args[0]))
1133-
handle_submodule(typing.cast(torch.fx.Node, while_node.args[1]))
1134-
# TODO: Add test coverage for map operator once dynamo tracing is
1135-
# fully supported for this. T142287208
1136-
for map_node in get_map_nodes(graph_module):
1137-
handle_submodule(
1138-
typing.cast(torch.fx.Node, map_node.args[0]), alloc_graph_input=True
1139-
)
1215+
insert_calls_to_free(graph_module, specs)
11401216

11411217
graph_module.meta.update({"non_const_buffer_sizes": bufsizes})
11421218
return bufsizes

exir/tests/test_memory_planning.py

Lines changed: 98 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232
ToOutVarPass,
3333
)
3434
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
35+
from executorch.exir.tensor import TensorSpec
36+
from functorch.experimental.control_flow import map as torch_map
3537
from parameterized import parameterized
3638

3739
from torch import nn
@@ -55,6 +57,7 @@
5557
from torch.export.exported_program import ExportGraphSignature
5658
from torch.fx import Graph, GraphModule, Node
5759
from torch.nn import functional as F
60+
from torch.utils import _pytree as pytree
5861

5962
torch.ops.load_library("//executorch/kernels/portable:custom_ops_generated_lib")
6063

@@ -420,13 +423,13 @@ def test_graph_input_output(self) -> None:
420423
alloc_graph_output,
421424
alloc_mutable_buffers,
422425
) in itertools.product([True, False], [True, False], [True, False]):
423-
case = maketest(
426+
test = maketest(
424427
ModelWithDifferentTensorSizes,
425428
alloc_graph_input=alloc_graph_input,
426429
alloc_graph_output=alloc_graph_output,
427430
alloc_mutable_buffer=alloc_mutable_buffers,
428431
)
429-
case(self)
432+
test(self)
430433

431434

432435
class TestVerifier(unittest.TestCase):
@@ -788,3 +791,96 @@ def forward(self, input, label):
788791
.val.allocation_info, # pyright: ignore
789792
None,
790793
)
794+
795+
def _get_specs(gm: torch.fx.GraphModule) -> list[TensorSpec]:
796+
return list(
797+
filter(
798+
None,
799+
pytree.tree_flatten(
800+
pytree.tree_map_only(
801+
torch.fx.Node,
802+
lambda n: n.meta.get("spec", None),
803+
list(gm.graph.nodes),
804+
)
805+
)[0],
806+
)
807+
)
808+
809+
class TestMap(unittest.TestCase):
810+
class MapModel(torch.nn.Module):
811+
def __init__(self) -> None:
812+
super().__init__()
813+
814+
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
815+
# Use actual torch.map function for memory planning testing
816+
def add_fn(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
817+
return a + b
818+
819+
# Use torch.map to apply function over first dimension
820+
# pyre-ignore[6]: For 3rd argument expected `TypeVarTuple` but got `Tensor`.
821+
map_output = torch_map(add_fn, x, y)
822+
823+
return map_output + y
824+
825+
def get_random_inputs(self) -> Tuple[torch.Tensor, ...]:
826+
return (torch.randn(5, 3), torch.randn(3))
827+
828+
def test_map(self) -> None:
829+
"""Test memory planning for torch.map operations."""
830+
831+
eager_module = self.MapModel().eval()
832+
inputs = eager_module.get_random_inputs()
833+
834+
# Export and convert to edge
835+
graph_module = (
836+
to_edge(export(eager_module, inputs, strict=True))
837+
.exported_program()
838+
.graph_module
839+
)
840+
841+
# Apply memory planning.
842+
mem_algo = MemoryPlanningAlgorithmSuite(algo_list=[naive, greedy])
843+
graph_module = PassManager(
844+
passes=[
845+
SpecPropPass(),
846+
ToOutVarPass(),
847+
],
848+
)(graph_module).graph_module
849+
mem_planning_pass = MemoryPlanningPass(
850+
mem_algo,
851+
alloc_graph_input=True,
852+
alloc_graph_output=True,
853+
alloc_mutable_buffers=True,
854+
)
855+
graph_module = mem_planning_pass.run(graph_module).graph_module
856+
857+
# Verify memory planning results
858+
verifier = Verifier(
859+
graph_module,
860+
alloc_graph_input=True,
861+
alloc_graph_output=True,
862+
alloc_mutable_buffers=True,
863+
)
864+
verifier.verify_graph_input_output()
865+
verifier.verify_storage_reuse(allow_lifetime_and_storage_overlap=False)
866+
867+
map_node = graph_module.graph.find_nodes(
868+
op="call_function", target=torch.ops.higher_order.map_impl
869+
)[0]
870+
map_fn_node = map_node.args[0]
871+
self.assertEqual(map_fn_node.op, "get_attr")
872+
map_fn = getattr(graph_module, map_fn_node.target)
873+
874+
875+
map_lifetime = map_node.meta.get("spec", None)[0].lifetime[0]
876+
877+
# Check that there is no storage overlap between nodes of the outer program and submodule of map.
878+
for outer_spec in _get_specs(graph_module):
879+
for inner_spec in _get_specs(map_fn):
880+
self.assertFalse(
881+
verifier.has_overlap(
882+
outer_spec.lifetime, [map_lifetime, map_lifetime]
883+
)
884+
and (verifier.storage_overlap(outer_spec, inner_spec)),
885+
f"Outer spec {outer_spec.shape=} {outer_spec.dtype=} {outer_spec.lifetime=} and inner spec {inner_spec} have storage overlap",
886+
)

0 commit comments

Comments
 (0)