
Commit 98bbc9b

hsharma35 authored and facebook-github-bot committed
Improve memory planning for submodule hierarchies.
Summary:
Improves the memory planning across hierarchies in apply_algo in memory_planning.py:

1. Plan memory bottom-to-top, starting with the leaf submodules and ending at the top-level graph module (root). This is now consistent with how delegates are compiled / memory planned. Future PRs/diffs will add support for planned buffers in delegates.
2. Allocate the max bufsize across all submodules as `graph_module.meta['input_mem_buffer_sizes']`, rather than the sum. This allows us to reclaim the space used by one submodule for another submodule.

Before this change, apply_algo in memory_planning.py would:

1. Plan memory top-to-bottom, starting with the top-level graph module (root).
2. Populate `input_mem_buffer_sizes` so that each new submodule allocates memory after the maximum buffer size already planned.

For example:

```
root [A bytes]
- root.child0 [B bytes]
  - root.child0.child0 [C bytes]
- root.child1 [D bytes]
```

(before this diff) Planned memory looks like:

```
--- A + B + C + D ----------------
Space for root.child1
--- A + B + C --------------------
Space for root.child0.child0
--- A + B ------------------------
Space for root.child0
--- A ----------------------------
Space for root
--- 0 ----------------------------
```

Note that tensors for child0 and child1 have no overlap in lifetime but still use completely different space.

(after this diff) Planned memory looks like:

```
--- max(C + B, D) + A ----------
root
--- max(C + B, D) --------------
root.child0        |
--- C ------------ |  root.child1
root.child0.child0 |
--- 0 --------------------------
```

Note: We can update the memory planning algo to plan nodes with submodules (while/map/cond or even delegate) using `graph_module.meta['non_const_buffer_sizes']` and reduce space even further. This would allow us to reuse the space for `root.child0.child0` in `root.child0`, and the space for `root.child0`/`root.child1` in `root`. The implementation for this is not part of this PR/Diff.

Differential Revision: D76940237
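
As a quick illustration of the arithmetic in the example above, here is a minimal sketch (not part of the commit; the byte counts are made up) comparing the old sum-style stacking with the new max-style sharing of sibling submodule regions:

```python
# Hypothetical sizes for the example hierarchy: root (A), root.child0 (B),
# root.child0.child0 (C), root.child1 (D). Not taken from the commit.
A, B, C, D = 16, 32, 8, 24

# Old planning: every submodule is stacked after all previously planned memory.
peak_before = A + B + C + D  # 80 bytes

# New planning: sibling regions (the child0 subtree vs. child1) do not overlap
# in lifetime, so only the larger one is reserved, plus root's own A bytes.
peak_after = max(B + C, D) + A  # max(40, 24) + 16 = 56 bytes

print(peak_before, peak_after)
```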
1 parent 7f2fcb0 commit 98bbc9b

File tree: 1 file changed, +131 -58 lines


exir/memory_planning.py

Lines changed: 131 additions & 58 deletions
```diff
@@ -10,10 +10,20 @@
 import itertools
 import logging
 import operator
-import typing
 from collections import defaultdict
 from dataclasses import dataclass, field
-from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
+from typing import (
+    Any,
+    Callable,
+    cast,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    Union,
+)
 
 import torch
 from executorch.exir import memory
```
```diff
@@ -949,7 +959,7 @@ def _allocate_buf(bufsizes: List[int], mem_id: int, allocated: int) -> int:
     bufsizes = getattr(graph_module, "input_mem_buffer_sizes", None)
     if bufsizes is None:
         bufsizes = [0, 0]
-    bufsizes = typing.cast(List[int], bufsizes)
+    bufsizes = cast(List[int], bufsizes)
 
     for spec in specs:
         spec_alloc_result = naive_result.spec_dict.get(spec, SpecAllocResult(0, 0, 0))
```
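
The hunk above reads `input_mem_buffer_sizes` and falls back to `[0, 0]`. As a rough sketch of how a planning algorithm might honor that list, the hypothetical helper below (not from the diff) returns the starting offset for a given `mem_id`, assuming the list stores one offset per memory id:

```python
from typing import List, Optional


def starting_offset(mem_id: int, input_mem_buffer_sizes: Optional[List[int]]) -> int:
    """Hypothetical helper: where an algorithm should begin allocating for mem_id.

    Bytes below this offset are treated as already reserved (for example,
    space set aside for submodules), matching the [0, 0] fallback above.
    """
    if input_mem_buffer_sizes is None:
        input_mem_buffer_sizes = [0, 0]
    if 0 <= mem_id < len(input_mem_buffer_sizes):
        return input_mem_buffer_sizes[mem_id]
    return 0
```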
```diff
@@ -1051,92 +1061,155 @@ def insert_calls_to_free(
     graph_module.recompile()
 
 
+def _merge_bufsizes(bufsizes: list[int], new_bufsizes: list[int]) -> list[int]:
+    """Combine two buffer size lists."""
+    if len(bufsizes) < len(new_bufsizes):
+        bufsizes.extend([0] * (len(new_bufsizes) - len(bufsizes)))
+    for i in range(len(new_bufsizes)):
+        bufsizes[i] = max(bufsizes[i], new_bufsizes[i])
+    return bufsizes
+
+
+def _handle_submodule(
+    algo: Callable[..., list[int]],
+    parent_graph_module: torch.fx.GraphModule,
+    alignment: int,
+    submodule_node: torch.fx.Node,
+    graph_signature: Optional[ExportGraphSignature] = None,
+    alloc_graph_input: bool = False,
+) -> list[int]:
+    """Apply algo to nodes in a submodule of the graph module."""
+    assert submodule_node.op == "get_attr"
+    submodule = getattr(parent_graph_module, submodule_node.target)
+
+    logging.debug(f"Planning memory for submodule {submodule_node.name}...")
+    bufsizes = apply_algo(
+        algo,
+        submodule,
+        alignment,
+        graph_signature,
+        alloc_graph_input=alloc_graph_input,
+        alloc_graph_output=True,
+    )
+    submodule.meta.update({"non_const_buffer_sizes": bufsizes})
+    logging.debug(f"Buffer sizes for submodule {submodule_node.name}: {bufsizes}")
+    return bufsizes
+
+
+def _apply_algo_to_submodules(
+    algo: Callable[..., list[int]],
+    graph_module: torch.fx.GraphModule,
+    alignment: int,
+    graph_signature: Optional[ExportGraphSignature] = None,
+) -> list[int]:
+    """Apply algo to map/cond/while nodes in the graph module.
+
+    This method will populate graph_module.meta["non_const_buffer_sizes"] for
+    all submodules and return a bufsizes list that is the maximum size of all
+    buffers.
+    """
+
+    # Bufsizes for submodules.
+    bufsizes: list[int] = []
+
+    def _handle(
+        submodule_node: torch.fx.Node,
+        alloc_graph_input: bool = False,
+    ) -> None:
+        current_bufsizes = _handle_submodule(
+            algo,
+            graph_module,
+            alignment,
+            submodule_node,
+            graph_signature,
+            alloc_graph_input=alloc_graph_input,
+        )
+        nonlocal bufsizes
+        _merge_bufsizes(bufsizes, current_bufsizes)
+
+    for cond_node in get_cond_nodes(graph_module):
+        _handle(cast(torch.fx.Node, cond_node.args[1]))
+        _handle(cast(torch.fx.Node, cond_node.args[2]))
+
+    for while_node in get_while_nodes(graph_module):
+        _handle(cast(torch.fx.Node, while_node.args[0]))
+        _handle(cast(torch.fx.Node, while_node.args[1]))
+
+    for map_node in get_map_nodes(graph_module):
+        _handle(cast(torch.fx.Node, map_node.args[0]), alloc_graph_input=True)
+
+    # TODO: We can handle delegates the same way as map/cond/while.
+    # Maybe populate the graph_module.meta["non_const_buffer_sizes"] for delegates.
+
+    return bufsizes
+
+
 def apply_algo(
-    algo: Callable[
-        ...,
-        List[int],
-    ],
+    algo: Callable[..., list[int]],
     graph_module: torch.fx.GraphModule,
     alignment: int,
     graph_signature: Optional[ExportGraphSignature] = None,
     alloc_graph_input: bool = True,
     alloc_graph_output: bool = True,
     alloc_mutable_buffers: bool = True,
-) -> List[int]:
+) -> list[int]:
     """
     Recursively apply algo to graph_module and its submodules for control flow.
 
-    Quite naively right now since it does not take the following optimizations
-    into considerating:
-    1. for conditional structure, true branch and false true does not overlap
-    in lifetime and can share tensor storage
-    2. tensors inside a submodule (e.g. true branch) has opportunities to share
-    storage with tensors in the outer module.
-    TODO: make these optimizations once we have some baseline working.
+    Algo implementation should handle one of two meta entries for submodules:
+    1. input_mem_buffer_sizes: List of int offset bytes. Memory allocated by
+       `algo` should start at the offset specified by this list;
+    OR
+    2. non_const_buffer_sizes: List of bufsizes for planned memory in submodule.
+       `algo` should reserve the space specified by this list for the lifetime
+       of the submodule node (e.g. cond, while, map).
+
+    TODO: Missing optimizations:
+    1. To handle maps, we set `alloc_graph_input=True`, which allocates
+       appropriate space for mapped arg but ends up allocating extra space for
+       `operand` arg. The memory for operands is unused.
     """
     # Extract the nodes and their lifespans from the graph_module
     # Difficult to just filter the list of specs returned by this due to
     # how we flag trainable weights.
     _ = update_all_tensors_lifetime(graph_module, graph_signature)
+
     # Filter specs based on alloc_graph_input and alloc_graph_output
-    specs = collect_specs_from_nodes(
-        graph_module.graph.nodes,
-        graph_signature,
-        do_assertion=False,
-        ignore_graph_input=not alloc_graph_input,
-        ignore_graph_output=not alloc_graph_output,
-        ignore_mutable_buffers=not alloc_mutable_buffers,
+    specs = set(
+        collect_specs_from_nodes(
+            graph_module.graph.nodes,
+            graph_signature,
+            do_assertion=False,
+            ignore_graph_input=not alloc_graph_input,
+            ignore_graph_output=not alloc_graph_output,
+            ignore_mutable_buffers=not alloc_mutable_buffers,
+        )
     )
 
+    # Get temporary specs for submodules to set aside space during execution
+    # of submodules.
+    submodule_bufsizes = _apply_algo_to_submodules(algo, graph_module, alignment, graph_signature)
+
+    # Update `input_mem_buffer_sizes` in graph_module.meta. This will allow
+    # existing algos to work using `input_mem_buffer_sizes` or use
+    # `non_const_buffer_sizes` directly.
+    graph_module.meta.update({"input_mem_buffer_sizes": submodule_bufsizes})
+
     # Get extra padding for XNNPACK if needed
     extra_padding = 0
     if _contains_xnnpack_delegate(graph_module):
         extra_padding = 64
 
     # Pass the filtered specs to the algorithm
-    bufsizes: List[int] = algo(
+    bufsizes: list[int] = algo(
         alignment,
         specs,
         graph_module,
         graph_signature,
         extra_padding,
     )
 
-    insert_calls_to_free(graph_module, set(specs))
-
-    def handle_submodule(
-        submodule_nd: torch.fx.Node, alloc_graph_input: bool = False
-    ) -> None:
-        nonlocal bufsizes
-        assert submodule_nd.op == "get_attr"
-        submodule = getattr(graph_module, submodule_nd.target)
-        # memory planning for submodule need to be aware of the amount of
-        # buffer already allocated.
-        submodule.input_mem_buffer_sizes = bufsizes
-
-        bufsizes = apply_algo(
-            algo,
-            submodule,
-            alignment,
-            graph_signature,
-            alloc_graph_input=alloc_graph_input,
-            alloc_graph_output=True,
-        )
-        submodule.meta.update({"non_const_buffer_sizes": bufsizes})
-
-    for cond_node in get_cond_nodes(graph_module):
-        handle_submodule(typing.cast(torch.fx.Node, cond_node.args[1]))
-        handle_submodule(typing.cast(torch.fx.Node, cond_node.args[2]))
-
-    for while_node in get_while_nodes(graph_module):
-        handle_submodule(typing.cast(torch.fx.Node, while_node.args[0]))
-        handle_submodule(typing.cast(torch.fx.Node, while_node.args[1]))
-    # TODO: Add test coverage for map operator once dynamo tracing is
-    # fully supported for this. T142287208
-    for map_node in get_map_nodes(graph_module):
-        handle_submodule(
-            typing.cast(torch.fx.Node, map_node.args[0]), alloc_graph_input=True
-        )
+    insert_calls_to_free(graph_module, specs)
 
     graph_module.meta.update({"non_const_buffer_sizes": bufsizes})
     return bufsizes
```
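
For reference, a small standalone check (not part of the commit) of the element-wise max merge that `_merge_bufsizes` performs when combining buffer sizes from sibling submodules; the sample values are arbitrary:

```python
def merge_bufsizes(bufsizes, new_bufsizes):
    # Mirrors _merge_bufsizes from the diff: pad the shorter list with zeros,
    # then keep the per-mem_id maximum so sibling submodules share one region.
    if len(bufsizes) < len(new_bufsizes):
        bufsizes.extend([0] * (len(new_bufsizes) - len(bufsizes)))
    for i in range(len(new_bufsizes)):
        bufsizes[i] = max(bufsizes[i], new_bufsizes[i])
    return bufsizes


# Two submodules that peak at different sizes per memory id: the merged
# result reserves enough for whichever is larger in each slot.
print(merge_bufsizes([0, 64], [0, 32, 128]))  # [0, 64, 128]
```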
