
Commit cb3b99a

Improve memory planning for submodule hierarchies.
Differential Revision: D76940237
Pull Request resolved: #11860
1 parent 4e59c4e commit cb3b99a

File tree: 2 files changed (+332, -57 lines)


exir/memory_planning.py

Lines changed: 127 additions & 51 deletions
@@ -10,10 +10,20 @@
 import itertools
 import logging
 import operator
-import typing
 from collections import defaultdict
 from dataclasses import dataclass, field
-from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
+from typing import (
+    Any,
+    Callable,
+    cast,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    Union,
+)

 import torch
 from executorch.exir import memory
@@ -960,7 +970,7 @@ def _allocate_buf(bufsizes: List[int], mem_id: int, allocated: int) -> int:
     bufsizes = getattr(graph_module, "input_mem_buffer_sizes", None)
     if bufsizes is None:
         bufsizes = [0, 0]
-    bufsizes = typing.cast(List[int], bufsizes)
+    bufsizes = cast(List[int], bufsizes)

     for spec in specs:
         spec_alloc_result = naive_result.spec_dict.get(spec, SpecAllocResult(0, 0, 0))
@@ -1062,33 +1072,119 @@ def insert_calls_to_free(
     graph_module.recompile()


+def _merge_bufsizes(bufsizes: list[int], new_bufsizes: list[int]) -> list[int]:
+    """Combine two buffer size lists."""
+    if len(bufsizes) < len(new_bufsizes):
+        bufsizes.extend([0] * (len(new_bufsizes) - len(bufsizes)))
+    for i in range(len(new_bufsizes)):
+        bufsizes[i] = max(bufsizes[i], new_bufsizes[i])
+    return bufsizes
+
+
+def _handle_submodule(
+    algo: Callable[..., list[int]],
+    parent_graph_module: torch.fx.GraphModule,
+    alignment: int,
+    submodule_node: torch.fx.Node,
+    graph_signature: Optional[ExportGraphSignature] = None,
+    alloc_graph_input: bool = False,
+) -> list[int]:
+    """Apply algo to nodes in a submodule of the graph module."""
+    assert submodule_node.op == "get_attr"
+    submodule = getattr(parent_graph_module, submodule_node.target)
+
+    logging.debug(f"Planning memory for submodule {submodule_node.name}...")
+    bufsizes = apply_algo(
+        algo,
+        submodule,
+        alignment,
+        graph_signature,
+        alloc_graph_input=alloc_graph_input,
+        alloc_graph_output=True,
+    )
+    submodule.meta.update({"non_const_buffer_sizes": bufsizes})
+    logging.debug(f"Buffer sizes for submodule {submodule_node.name}: {bufsizes}")
+    return bufsizes
+
+
+def _apply_algo_to_submodules(
+    algo: Callable[..., list[int]],
+    graph_module: torch.fx.GraphModule,
+    alignment: int,
+    graph_signature: Optional[ExportGraphSignature] = None,
+) -> list[int]:
+    """Apply algo to map/cond/while nodes in the graph module.
+
+    This method will populate graph_module.meta["non_const_buffer_sizes"] for
+    all submodules and return a bufsizes list that is the maximum size of all
+    buffers.
+    """
+
+    # Bufsizes for submodules.
+    bufsizes: list[int] = []
+
+    def _handle(
+        submodule_node: torch.fx.Node,
+        alloc_graph_input: bool = False,
+    ) -> None:
+        current_bufsizes = _handle_submodule(
+            algo,
+            graph_module,
+            alignment,
+            submodule_node,
+            graph_signature,
+            alloc_graph_input=alloc_graph_input,
+        )
+        nonlocal bufsizes
+        _merge_bufsizes(bufsizes, current_bufsizes)
+
+    for cond_node in get_cond_nodes(graph_module):
+        _handle(cast(torch.fx.Node, cond_node.args[1]))
+        _handle(cast(torch.fx.Node, cond_node.args[2]))
+
+    for while_node in get_while_nodes(graph_module):
+        _handle(cast(torch.fx.Node, while_node.args[0]))
+        _handle(cast(torch.fx.Node, while_node.args[1]))
+
+    for map_node in get_map_nodes(graph_module):
+        _handle(cast(torch.fx.Node, map_node.args[0]), alloc_graph_input=True)
+
+    # TODO: We can handle delegates the same way as map/cond/while.
+    # Maybe populate the graph_module.meta["non_const_buffer_sizes"] for delegates.
+
+    return bufsizes
+
+
 def apply_algo(
-    algo: Callable[
-        ...,
-        List[int],
-    ],
+    algo: Callable[..., list[int]],
     graph_module: torch.fx.GraphModule,
     alignment: int,
     graph_signature: Optional[ExportGraphSignature] = None,
     alloc_graph_input: bool = True,
     alloc_graph_output: bool = True,
     alloc_mutable_buffers: bool = True,
-) -> List[int]:
+) -> list[int]:
     """
     Recursively apply algo to graph_module and its submodules for control flow.

-    Quite naively right now since it does not take the following optimizations
-    into considerating:
-    1. for conditional structure, true branch and false true does not overlap
-    in lifetime and can share tensor storage
-    2. tensors inside a submodule (e.g. true branch) has opportunities to share
-    storage with tensors in the outer module.
-    TODO: make these optimizations once we have some baseline working.
+    Algo implementation should handle one of two meta entries for submodules:
+    1. input_mem_buffer_sizes: List of int offset bytes. Memory allocated by
+       `algo` should start at the offset specified by this list;
+    OR
+    2. non_const_buffer_sizes: List of bufsizes for planned memory in submodule.
+       `algo` should reserve the space specified by this list for the lifetime
+       of the submodule node (e.g. cond, while, map).
+
+    TODO: Missing optimizations:
+    1. To handle maps, we set `alloc_graph_input=True`, which allocates
+       appropriate space for mapped arg but ends up allocating extra space for
+       `operand` arg. The memory for operands is unused.
     """
     # Extract the nodes and their lifespans from the graph_module
     # Difficult to just filter the list of specs returned by this due to
     # how we flag trainable weights.
     _ = update_all_tensors_lifetime(graph_module, graph_signature)
+
     # Filter specs based on alloc_graph_input and alloc_graph_output
     specs = collect_specs_from_nodes(
         graph_module.graph.nodes,
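The helpers added in this hunk merge the per-mem_id buffer requirements of sibling submodules by padding the shorter list with zeros and taking an element-wise maximum, so branches that are never live at the same time share one reservation. The following is a minimal standalone sketch of that merge behaviour with illustrative names (`merge_bufsizes`, `true_branch`, `false_branch` are assumptions for this example, not the ExecuTorch source):

```python
# Sketch of the merge semantics for submodule buffer sizes: pad the shorter
# list with zeros, then keep the element-wise max per mem_id.
def merge_bufsizes(bufsizes: list[int], new_bufsizes: list[int]) -> list[int]:
    """Grow `bufsizes` in place so each mem_id holds the max of both plans."""
    if len(bufsizes) < len(new_bufsizes):
        bufsizes.extend([0] * (len(new_bufsizes) - len(bufsizes)))
    for i, size in enumerate(new_bufsizes):
        bufsizes[i] = max(bufsizes[i], size)
    return bufsizes


# Hypothetical per-branch plans for a cond node; only one branch runs at a
# time, so the parent only needs the larger reservation per mem_id.
true_branch = [0, 1024, 256]
false_branch = [0, 512, 512, 64]
merged: list[int] = []
merge_bufsizes(merged, true_branch)
merge_bufsizes(merged, false_branch)
print(merged)  # -> [0, 1024, 512, 64]
```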
@@ -1099,55 +1195,35 @@ def apply_algo(
         ignore_mutable_buffers=not alloc_mutable_buffers,
     )

+    # Get temporary specs for submodules to set aside space during execution
+    # of submodules.
+    submodule_bufsizes = _apply_algo_to_submodules(
+        algo, graph_module, alignment, graph_signature
+    )
+
+    # Update `input_mem_buffer_sizes` in graph_module. This will allow existing
+    # algos to work using `input_mem_buffer_sizes` or use
+    # `non_const_buffer_sizes` directly.
+    # pyre-ignore[16]: `torch.fx.GraphModule` has no attribute `input_mem_buffer_sizes`.
+    graph_module.input_mem_buffer_sizes = submodule_bufsizes
+
     # Get extra padding for XNNPACK if needed
     extra_padding = 0
     if _contains_xnnpack_delegate(graph_module):
         extra_padding = 64

     # Pass the filtered specs to the algorithm
-    bufsizes: List[int] = algo(
+    bufsizes: list[int] = algo(
         alignment,
         specs,
         graph_module,
         graph_signature,
         extra_padding,
     )

-    insert_calls_to_free(graph_module, set(specs))
-
-    def handle_submodule(
-        submodule_nd: torch.fx.Node, alloc_graph_input: bool = False
-    ) -> None:
-        nonlocal bufsizes
-        assert submodule_nd.op == "get_attr"
-        submodule = getattr(graph_module, submodule_nd.target)
-        # memory planning for submodule need to be aware of the amount of
-        # buffer already allocated.
-        submodule.input_mem_buffer_sizes = bufsizes
-
-        bufsizes = apply_algo(
-            algo,
-            submodule,
-            alignment,
-            graph_signature,
-            alloc_graph_input=alloc_graph_input,
-            alloc_graph_output=True,
-        )
-        submodule.meta.update({"non_const_buffer_sizes": bufsizes})
-
-    for cond_node in get_cond_nodes(graph_module):
-        handle_submodule(typing.cast(torch.fx.Node, cond_node.args[1]))
-        handle_submodule(typing.cast(torch.fx.Node, cond_node.args[2]))
-
-    for while_node in get_while_nodes(graph_module):
-        handle_submodule(typing.cast(torch.fx.Node, while_node.args[0]))
-        handle_submodule(typing.cast(torch.fx.Node, while_node.args[1]))
-    # TODO: Add test coverage for map operator once dynamo tracing is
-    # fully supported for this. T142287208
-    for map_node in get_map_nodes(graph_module):
-        handle_submodule(
-            typing.cast(torch.fx.Node, map_node.args[0]), alloc_graph_input=True
-        )
+    # pyre-ignore[6]: Incompatible parameter type [6]
+    # In call `insert_calls_to_free`, for 2nd positional argument, expected `Set[TensorSpec]` but got `Iterable[TensorSpec]`
+    insert_calls_to_free(graph_module, specs)

     graph_module.meta.update({"non_const_buffer_sizes": bufsizes})
     return bufsizes
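With the merged submodule sizes stored on `graph_module.input_mem_buffer_sizes`, an existing planning algorithm can honor the first contract from the updated `apply_algo` docstring by starting each arena at the reserved offset. The toy allocator below is a hedged sketch under assumed names (`toy_greedy_algo`, specs as `(mem_id, nbytes)` pairs); it is not ExecuTorch's algo interface and only illustrates the offset convention:

```python
# Toy allocator: start each mem_id's allocations at the offset reserved for
# submodules, then stack the remaining tensors back to back with alignment.
def toy_greedy_algo(
    alignment: int,
    specs: list[tuple[int, int]],       # (mem_id, nbytes) per tensor -- illustrative only
    input_mem_buffer_sizes: list[int],  # bytes already reserved per mem_id
) -> list[int]:
    def align(n: int) -> int:
        return (n + alignment - 1) // alignment * alignment

    # Begin every arena at the reserved offset instead of zero, so parent
    # tensors never overlap storage set aside for submodules.
    bufsizes = [align(offset) for offset in input_mem_buffer_sizes]
    for mem_id, nbytes in specs:
        if mem_id >= len(bufsizes):
            bufsizes.extend([0] * (mem_id - len(bufsizes) + 1))
        bufsizes[mem_id] += align(nbytes)
    return bufsizes


print(toy_greedy_algo(16, [(1, 100), (2, 30)], [0, 4096, 0]))  # -> [0, 4208, 32]
```

An algorithm that instead reads `non_const_buffer_sizes` would reserve the same span only for the lifetime of the cond/while/map node, as described in the docstring above.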
