
Commit f632fa6

hsharma35 authored and facebook-github-bot committed
Improve memory planning for submodule hierarchies. (#11860)
Summary:
Improves the memory planning across hierarchies in `apply_algo` in memory_planning.py:

1. Plan memory bottom-to-top, starting with the leaf submodules and ending at the top-level graph module (root). This is now consistent with how delegates are compiled / memory planned. Future PRs/diffs will add support for planned buffers in delegates.
2. Allocate the max bufsize across all submodules as `graph_module.meta['input_mem_buffer_sizes']`, rather than the sum. This allows us to reclaim the space used by one submodule for another submodule.

Before this change, `apply_algo` in memory_planning.py would:

1. Plan memory top-to-bottom, starting with the top-level graph module (root).
2. Populate `input_mem_buffer_sizes` so that each new submodule allocates memory after the max buffer size of the previously planned memory.

For example:

```
root [A bytes]
- root.child0 [B bytes]
  - root.child0.child0 [C bytes]
- root.child1 [D bytes]
```

(before this diff) Planned memory looks like:

```
--- A + B + C + D ----------------
Space for root.child1
--- A + B + C --------------------
Space for root.child0.child0
--- A + B ------------------------
Space for root.child0
--- A ----------------------------
Space for root
--- 0 ----------------------------
```

Note that the tensors for child0 and child1 have no lifetime overlap but still occupy completely different space.

(after this diff) Planned memory looks like:

```
--- max(C + B, D) + A ----------
root
--- max(C + B, D) --------------
root.child0        |
--- C ------------ | root.child1
root.child0.child0 |
--- 0 --------------------------
```

Note: We can update the memory planning algos to plan nodes with submodules (while/map/cond or even delegate) using `graph_module.meta['non_const_buffer_sizes']` and reduce space even further. This would allow us to reuse the space for `root.child0.child0` in `root.child0`, and the space for `root.child0`/`root.child1` in `root`. The implementation for this is not part of this PR/diff.

Differential Revision: D76940237
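The space saving comes from combining sibling submodules' buffer sizes with an element-wise max instead of stacking them end to end. Below is a minimal sketch of that accounting, mirroring the new `_merge_bufsizes` helper in this diff; the byte values for A-D are hypothetical and only illustrate the diagrams above.

```python
def merge_bufsizes(bufsizes: list[int], new_bufsizes: list[int]) -> list[int]:
    # Element-wise max per mem_id: sibling submodules reuse the same region.
    if len(bufsizes) < len(new_bufsizes):
        bufsizes.extend([0] * (len(new_bufsizes) - len(bufsizes)))
    for i in range(len(new_bufsizes)):
        bufsizes[i] = max(bufsizes[i], new_bufsizes[i])
    return bufsizes


# Hypothetical sizes for the hierarchy above (illustrative values only):
# A = root, B = child0, C = child0.child0, D = child1.
A, B, C, D = 128, 64, 32, 80

old_total = A + B + C + D                        # 304: submodules stacked end to end
new_total = A + merge_bufsizes([B + C], [D])[0]  # 128 + max(96, 80) = 224
print(old_total, new_total)
```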
1 parent 9ee10b8 commit f632fa6

File tree

2 files changed: +233 -60 lines changed


exir/memory_planning.py

Lines changed: 134 additions & 58 deletions
```diff
@@ -10,10 +10,20 @@
 import itertools
 import logging
 import operator
-import typing
 from collections import defaultdict
 from dataclasses import dataclass, field
-from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
+from typing import (
+    Any,
+    Callable,
+    cast,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    Union,
+)
 
 import torch
 from executorch.exir import memory
@@ -960,7 +970,7 @@ def _allocate_buf(bufsizes: List[int], mem_id: int, allocated: int) -> int:
     bufsizes = getattr(graph_module, "input_mem_buffer_sizes", None)
     if bufsizes is None:
         bufsizes = [0, 0]
-    bufsizes = typing.cast(List[int], bufsizes)
+    bufsizes = cast(List[int], bufsizes)
 
     for spec in specs:
         spec_alloc_result = naive_result.spec_dict.get(spec, SpecAllocResult(0, 0, 0))
@@ -1062,92 +1072,158 @@ def insert_calls_to_free(
     graph_module.recompile()
 
 
+def _merge_bufsizes(bufsizes: list[int], new_bufsizes: list[int]) -> list[int]:
+    """Combine two buffer size lists."""
+    if len(bufsizes) < len(new_bufsizes):
+        bufsizes.extend([0] * (len(new_bufsizes) - len(bufsizes)))
+    for i in range(len(new_bufsizes)):
+        bufsizes[i] = max(bufsizes[i], new_bufsizes[i])
+    return bufsizes
+
+
+def _handle_submodule(
+    algo: Callable[..., list[int]],
+    parent_graph_module: torch.fx.GraphModule,
+    alignment: int,
+    submodule_node: torch.fx.Node,
+    graph_signature: Optional[ExportGraphSignature] = None,
+    alloc_graph_input: bool = False,
+) -> list[int]:
+    """Apply algo to nodes in a submodule of the graph module."""
+    assert submodule_node.op == "get_attr"
+    submodule = getattr(parent_graph_module, submodule_node.target)
+
+    logging.debug(f"Planning memory for submodule {submodule_node.name}...")
+    bufsizes = apply_algo(
+        algo,
+        submodule,
+        alignment,
+        graph_signature,
+        alloc_graph_input=alloc_graph_input,
+        alloc_graph_output=True,
+    )
+    submodule.meta.update({"non_const_buffer_sizes": bufsizes})
+    logging.debug(f"Buffer sizes for submodule {submodule_node.name}: {bufsizes}")
+    return bufsizes
+
+
+def _apply_algo_to_submodules(
+    algo: Callable[..., list[int]],
+    graph_module: torch.fx.GraphModule,
+    alignment: int,
+    graph_signature: Optional[ExportGraphSignature] = None,
+) -> list[int]:
+    """Apply algo to map/cond/while nodes in the graph module.
+
+    This method will populate graph_module.meta["non_const_buffer_sizes"] for
+    all submodules and return a bufsizes list that is the maximum size of all
+    buffers.
+    """
+
+    # Bufsizes for submodules.
+    bufsizes: list[int] = []
+
+    def _handle(
+        submodule_node: torch.fx.Node,
+        alloc_graph_input: bool = False,
+    ) -> None:
+        current_bufsizes = _handle_submodule(
+            algo,
+            graph_module,
+            alignment,
+            submodule_node,
+            graph_signature,
+            alloc_graph_input=alloc_graph_input,
+        )
+        nonlocal bufsizes
+        _merge_bufsizes(bufsizes, current_bufsizes)
+
+    for cond_node in get_cond_nodes(graph_module):
+        _handle(cast(torch.fx.Node, cond_node.args[1]))
+        _handle(cast(torch.fx.Node, cond_node.args[2]))
+
+    for while_node in get_while_nodes(graph_module):
+        _handle(cast(torch.fx.Node, while_node.args[0]))
+        _handle(cast(torch.fx.Node, while_node.args[1]))
+
+    for map_node in get_map_nodes(graph_module):
+        _handle(cast(torch.fx.Node, map_node.args[0]), alloc_graph_input=True)
+
+    # TODO: We can handle delegates the same way as map/cond/while.
+    # Maybe populate the graph_module.meta["non_const_buffer_sizes"] for delegates.
+
+    return bufsizes
+
+
 def apply_algo(
-    algo: Callable[
-        ...,
-        List[int],
-    ],
+    algo: Callable[..., list[int]],
     graph_module: torch.fx.GraphModule,
     alignment: int,
     graph_signature: Optional[ExportGraphSignature] = None,
     alloc_graph_input: bool = True,
     alloc_graph_output: bool = True,
     alloc_mutable_buffers: bool = True,
-) -> List[int]:
+) -> list[int]:
     """
     Recursively apply algo to graph_module and its submodules for control flow.
 
-    Quite naively right now since it does not take the following optimizations
-    into considerating:
-    1. for conditional structure, true branch and false true does not overlap
-    in lifetime and can share tensor storage
-    2. tensors inside a submodule (e.g. true branch) has opportunities to share
-    storage with tensors in the outer module.
-    TODO: make these optimizations once we have some baseline working.
+    Algo implementation should handle one of two meta entries for submodules:
+    1. input_mem_buffer_sizes: List of int offset bytes. Memory allocated by
+       `algo` should start at the offset specified by this list;
+    OR
+    2. non_const_buffer_sizes: List of bufsizes for planned memory in submodule.
+       `algo` should reserve the space specified by this list for the lifetime
+       of the submodule node (e.g. cond, while, map).
+
+    TODO: Missing optimizations:
+    1. To handle maps, we set `alloc_graph_input=True`, which allocates
+       appropriate space for mapped arg but ends up allocating extra space for
+       `operand` arg. The memory for operands is unused.
     """
     # Extract the nodes and their lifespans from the graph_module
     # Difficult to just filter the list of specs returned by this due to
     # how we flag trainable weights.
     _ = update_all_tensors_lifetime(graph_module, graph_signature)
+
     # Filter specs based on alloc_graph_input and alloc_graph_output
-    specs = collect_specs_from_nodes(
-        graph_module.graph.nodes,
-        graph_signature,
-        do_assertion=False,
-        ignore_graph_input=not alloc_graph_input,
-        ignore_graph_output=not alloc_graph_output,
-        ignore_mutable_buffers=not alloc_mutable_buffers,
+    specs = set(
+        collect_specs_from_nodes(
+            graph_module.graph.nodes,
+            graph_signature,
+            do_assertion=False,
+            ignore_graph_input=not alloc_graph_input,
+            ignore_graph_output=not alloc_graph_output,
+            ignore_mutable_buffers=not alloc_mutable_buffers,
+        )
     )
 
+    # Get temporary specs for submodules to set aside space during execution
+    # of submodules.
+    submodule_bufsizes = _apply_algo_to_submodules(
+        algo, graph_module, alignment, graph_signature
+    )
+
+    # Update `input_mem_buffer_sizes` in graph_module. This will allow existing
+    # algos to work using `input_mem_buffer_sizes` or use
+    # `non_const_buffer_sizes` directly.
+    # pyre-ignore[16]: `torch.fx.GraphModule` has no attribute `input_mem_buffer_sizes`.
+    graph_module.input_mem_buffer_sizes = submodule_bufsizes
+
     # Get extra padding for XNNPACK if needed
     extra_padding = 0
     if _contains_xnnpack_delegate(graph_module):
         extra_padding = 64
 
     # Pass the filtered specs to the algorithm
-    bufsizes: List[int] = algo(
+    bufsizes: list[int] = algo(
         alignment,
         specs,
         graph_module,
         graph_signature,
         extra_padding,
     )
 
-    insert_calls_to_free(graph_module, set(specs))
-
-    def handle_submodule(
-        submodule_nd: torch.fx.Node, alloc_graph_input: bool = False
-    ) -> None:
-        nonlocal bufsizes
-        assert submodule_nd.op == "get_attr"
-        submodule = getattr(graph_module, submodule_nd.target)
-        # memory planning for submodule need to be aware of the amount of
-        # buffer already allocated.
-        submodule.input_mem_buffer_sizes = bufsizes
-
-        bufsizes = apply_algo(
-            algo,
-            submodule,
-            alignment,
-            graph_signature,
-            alloc_graph_input=alloc_graph_input,
-            alloc_graph_output=True,
-        )
-        submodule.meta.update({"non_const_buffer_sizes": bufsizes})
-
-    for cond_node in get_cond_nodes(graph_module):
-        handle_submodule(typing.cast(torch.fx.Node, cond_node.args[1]))
-        handle_submodule(typing.cast(torch.fx.Node, cond_node.args[2]))
-
-    for while_node in get_while_nodes(graph_module):
-        handle_submodule(typing.cast(torch.fx.Node, while_node.args[0]))
-        handle_submodule(typing.cast(torch.fx.Node, while_node.args[1]))
-
-    # TODO: Add test coverage for map operator once dynamo tracing is
-    # fully supported for this. T142287208
-    for map_node in get_map_nodes(graph_module):
-        handle_submodule(
-            typing.cast(torch.fx.Node, map_node.args[0]), alloc_graph_input=True
-        )
+    insert_calls_to_free(graph_module, specs)
 
     graph_module.meta.update({"non_const_buffer_sizes": bufsizes})
     return bufsizes
```
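For context on the contract spelled out in the new `apply_algo` docstring, the sketch below shows one way a planning algorithm could honor option 1: starting its allocations at the offsets recorded in `input_mem_buffer_sizes`, which now hold the max space needed by any submodule. The `toy_bump_algo` function and its simplified signature are hypothetical and are not part of this diff.

```python
from typing import List, Sequence

import torch


def toy_bump_algo(
    alignment: int,
    spec_sizes: Sequence[int],  # stand-in for the allocated bytes of each TensorSpec
    graph_module: torch.fx.GraphModule,
) -> List[int]:
    # Option 1 from the docstring: begin allocating after the space reserved
    # for submodules instead of at offset 0.
    bufsizes = list(getattr(graph_module, "input_mem_buffer_sizes", None) or [0, 0])
    offset = bufsizes[1]
    for nbytes in spec_sizes:
        # Align up, then bump-allocate; a real algo would also record
        # per-spec offsets and mem_ids on the specs themselves.
        offset = (offset + alignment - 1) // alignment * alignment
        offset += nbytes
    bufsizes[1] = offset
    return bufsizes
```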

exir/tests/test_memory_planning.py

Lines changed: 99 additions & 2 deletions
```diff
@@ -33,6 +33,8 @@
     ToOutVarPass,
 )
 from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
+from executorch.exir.tensor import TensorSpec
+from functorch.experimental.control_flow import map as torch_map
 from parameterized import parameterized
 
 from torch import nn
@@ -56,6 +58,7 @@
 from torch.export.exported_program import ExportGraphSignature
 from torch.fx import Graph, GraphModule, Node
 from torch.nn import functional as F
+from torch.utils import _pytree as pytree
 
 torch.ops.load_library("//executorch/kernels/portable:custom_ops_generated_lib")
 
@@ -471,13 +474,13 @@ def test_graph_input_output(self) -> None:
             alloc_graph_output,
             alloc_mutable_buffers,
         ) in itertools.product([True, False], [True, False], [True, False]):
-            case = maketest(
+            test = maketest(
                 ModelWithDifferentTensorSizes,
                 alloc_graph_input=alloc_graph_input,
                 alloc_graph_output=alloc_graph_output,
                 alloc_mutable_buffer=alloc_mutable_buffers,
            )
-            case(self)
+            test(self)
 
 
 class TestVerifier(unittest.TestCase):
@@ -839,3 +842,97 @@ def forward(self, input, label):
            .val.allocation_info,  # pyright: ignore
            None,
        )
+
+
+def _get_specs(gm: torch.fx.GraphModule) -> list[TensorSpec]:
+    return list(
+        filter(
+            None,
+            pytree.tree_flatten(
+                pytree.tree_map_only(
+                    torch.fx.Node,
+                    lambda n: n.meta.get("spec", None),
+                    list(gm.graph.nodes),
+                )
+            )[0],
+        )
+    )
+
+
+class TestMap(unittest.TestCase):
+    class MapModel(torch.nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+
+        def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            # Use actual torch.map function for memory planning testing
+            def add_fn(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+                return a + b
+
+            # Use torch.map to apply function over first dimension
+            # pyre-ignore[6]: For 3rd argument expected `TypeVarTuple` but got `Tensor`.
+            map_output = torch_map(add_fn, x, y)
+
+            return map_output + y
+
+        def get_random_inputs(self) -> Tuple[torch.Tensor, ...]:
+            return (torch.randn(5, 3), torch.randn(3))
+
+    def test_map(self) -> None:
+        """Test memory planning for torch.map operations."""
+
+        eager_module = self.MapModel().eval()
+        inputs = eager_module.get_random_inputs()
+
+        # Export and convert to edge
+        graph_module = (
+            to_edge(export(eager_module, inputs, strict=True))
+            .exported_program()
+            .graph_module
+        )
+
+        # Apply memory planning.
+        mem_algo = MemoryPlanningAlgorithmSuite(algo_list=[naive, greedy])
+        graph_module = PassManager(
+            passes=[
+                SpecPropPass(),
+                ToOutVarPass(),
+            ],
+        )(graph_module).graph_module
+        mem_planning_pass = MemoryPlanningPass(
+            mem_algo,
+            alloc_graph_input=True,
+            alloc_graph_output=True,
+            alloc_mutable_buffers=True,
+        )
+        graph_module = mem_planning_pass.run(graph_module).graph_module
+
+        # Verify memory planning results
+        verifier = Verifier(
+            graph_module,
+            alloc_graph_input=True,
+            alloc_graph_output=True,
+            alloc_mutable_buffers=True,
+        )
+        verifier.verify_graph_input_output()
+        verifier.verify_storage_reuse(allow_lifetime_and_storage_overlap=False)
+
+        map_node = graph_module.graph.find_nodes(
+            op="call_function", target=torch.ops.higher_order.map_impl
+        )[0]
+        map_fn_node = map_node.args[0]
+        self.assertEqual(map_fn_node.op, "get_attr")
+        map_fn = getattr(graph_module, map_fn_node.target)
+
+        map_lifetime = map_node.meta.get("spec", None)[0].lifetime[0]
+
+        # Check that there is no storage overlap between nodes of the outer program and submodule of map.
+        for outer_spec in _get_specs(graph_module):
+            for inner_spec in _get_specs(map_fn):
+                self.assertFalse(
+                    verifier.has_overlap(
+                        outer_spec.lifetime, [map_lifetime, map_lifetime]
+                    )
+                    and (verifier.storage_overlap(outer_spec, inner_spec)),
+                    f"Outer spec {outer_spec.shape=} {outer_spec.dtype=} {outer_spec.lifetime=} and inner spec {inner_spec} have storage overlap",
+                )
```

0 commit comments
