10 | 10 | import itertools |
11 | 11 | import logging |
12 | 12 | import operator |
13 | | -import typing |
14 | 13 | from collections import defaultdict |
15 | 14 | from dataclasses import dataclass, field |
16 | | -from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union |
| 15 | +from typing import ( |
| 16 | + Any, |
| 17 | + Callable, |
| 18 | + cast, |
| 19 | + Dict, |
| 20 | + Iterable, |
| 21 | + List, |
| 22 | + Optional, |
| 23 | + Set, |
| 24 | + Tuple, |
| 25 | + Union, |
| 26 | +) |
17 | 27 |
18 | 28 | import torch |
19 | 29 | from executorch.exir import memory |
@@ -949,7 +959,7 @@ def _allocate_buf(bufsizes: List[int], mem_id: int, allocated: int) -> int: |
949 | 959 | bufsizes = getattr(graph_module, "input_mem_buffer_sizes", None) |
950 | 960 | if bufsizes is None: |
951 | 961 | bufsizes = [0, 0] |
952 | | - bufsizes = typing.cast(List[int], bufsizes) |
| 962 | + bufsizes = cast(List[int], bufsizes) |
953 | 963 |
954 | 964 | for spec in specs: |
955 | 965 | spec_alloc_result = naive_result.spec_dict.get(spec, SpecAllocResult(0, 0, 0)) |
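
The fallback above seeds `bufsizes` with `[0, 0]` when no caller has reserved space. The body of `_allocate_buf(bufsizes, mem_id, allocated)` from the hunk header is not shown here, so the bump allocator below is only a sketch of the assumed contract, not the actual implementation:

```python
from typing import List

def bump_alloc(bufsizes: List[int], mem_id: int, allocated: int) -> int:
    """Hypothetical stand-in for _allocate_buf: reserve `allocated` bytes
    in arena `mem_id` and return the offset where the reservation starts."""
    if mem_id >= len(bufsizes):
        bufsizes.extend([0] * (mem_id - len(bufsizes) + 1))
    start = bufsizes[mem_id]
    bufsizes[mem_id] += allocated
    return start

sizes = [0, 0]  # the [0, 0] fallback: both arenas start empty
assert bump_alloc(sizes, 1, 128) == 0
assert sizes == [0, 128]
```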
@@ -1051,92 +1061,155 @@ def insert_calls_to_free( |
1051 | 1061 | graph_module.recompile() |
1052 | 1062 |
1053 | 1063 |
| 1064 | +def _merge_bufsizes(bufsizes: list[int], new_bufsizes: list[int]) -> list[int]: |
| 1065 | + """Combine two buffer size lists.""" |
| 1066 | + if len(bufsizes) < len(new_bufsizes): |
| 1067 | + bufsizes.extend([0] * (len(new_bufsizes) - len(bufsizes))) |
| 1068 | + for i in range(len(new_bufsizes)): |
| 1069 | + bufsizes[i] = max(bufsizes[i], new_bufsizes[i]) |
| 1070 | + return bufsizes |
| 1071 | + |
| 1072 | + |
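
`_merge_bufsizes` takes the elementwise maximum of two size lists, zero-padding the shorter one in place. A quick check with made-up sizes, assuming the function above is in scope:

```python
bufsizes = [0, 1024]
bufsizes = _merge_bufsizes(bufsizes, [0, 512, 256])
assert bufsizes == [0, 1024, 256]  # per-arena max; third arena adopted from the new list
```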
| 1073 | +def _handle_submodule( |
| 1074 | + algo: Callable[..., list[int]], |
| 1075 | + parent_graph_module: torch.fx.GraphModule, |
| 1076 | + alignment: int, |
| 1077 | + submodule_node: torch.fx.Node, |
| 1078 | + graph_signature: Optional[ExportGraphSignature] = None, |
| 1079 | + alloc_graph_input: bool = False, |
| 1080 | +) -> list[int]: |
| 1081 | + """Apply algo to nodes in a submodule of the graph module.""" |
| 1082 | + assert submodule_node.op == "get_attr" |
| 1083 | + submodule = getattr(parent_graph_module, submodule_node.target) |
| 1084 | + |
| 1085 | + logging.debug(f"Planning memory for submodule {submodule_node.name}...") |
| 1086 | + bufsizes = apply_algo( |
| 1087 | + algo, |
| 1088 | + submodule, |
| 1089 | + alignment, |
| 1090 | + graph_signature, |
| 1091 | + alloc_graph_input=alloc_graph_input, |
| 1092 | + alloc_graph_output=True, |
| 1093 | + ) |
| 1094 | + submodule.meta.update({"non_const_buffer_sizes": bufsizes}) |
| 1095 | + logging.debug(f"Buffer sizes for submodule {submodule_node.name}: {bufsizes}") |
| 1096 | + return bufsizes |
| 1097 | + |
| 1098 | + |
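
`_handle_submodule` leans on the FX convention that a control-flow submodule is referenced by a `get_attr` node whose `target` holds the attribute name on the parent module. A minimal, runnable illustration (the name `true_branch` is made up):

```python
import torch.fx

graph = torch.fx.Graph()
node = graph.get_attr("true_branch")  # op == "get_attr", target == attribute name
assert node.op == "get_attr"
assert node.target == "true_branch"
# _handle_submodule then resolves it with: getattr(parent_graph_module, node.target)
```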
| 1099 | +def _apply_algo_to_submodules( |
| 1100 | + algo: Callable[..., list[int]], |
| 1101 | + graph_module: torch.fx.GraphModule, |
| 1102 | + alignment: int, |
| 1103 | + graph_signature: Optional[ExportGraphSignature] = None, |
| 1104 | +) -> list[int]: |
| 1105 | + """Apply algo to map/cond/while nodes in the graph module. |
| 1106 | +
| 1108 | + This method will populate graph_module.meta["non_const_buffer_sizes"] for |
| 1108 | + all submodules and return a bufsizes list that is the maximum size of all |
| 1109 | + buffers. |
| 1110 | + """ |
| 1111 | + |
| 1112 | + # Bufsizes for submodules. |
| 1113 | + bufsizes: list[int] = [] |
| 1114 | + |
| 1115 | + def _handle( |
| 1116 | + submodule_node: torch.fx.Node, |
| 1117 | + alloc_graph_input: bool = False, |
| 1118 | + ) -> None: |
| 1119 | + current_bufsizes = _handle_submodule( |
| 1120 | + algo, |
| 1121 | + graph_module, |
| 1122 | + alignment, |
| 1123 | + submodule_node, |
| 1124 | + graph_signature, |
| 1125 | + alloc_graph_input=alloc_graph_input, |
| 1126 | + ) |
| 1127 | + nonlocal bufsizes |
| 1128 | + _merge_bufsizes(bufsizes, current_bufsizes) |
| 1129 | + |
| 1130 | + for cond_node in get_cond_nodes(graph_module): |
| 1131 | + _handle(cast(torch.fx.Node, cond_node.args[1])) |
| 1132 | + _handle(cast(torch.fx.Node, cond_node.args[2])) |
| 1133 | + |
| 1134 | + for while_node in get_while_nodes(graph_module): |
| 1135 | + _handle(cast(torch.fx.Node, while_node.args[0])) |
| 1136 | + _handle(cast(torch.fx.Node, while_node.args[1])) |
| 1137 | + |
| 1138 | + for map_node in get_map_nodes(graph_module): |
| 1139 | + _handle(cast(torch.fx.Node, map_node.args[0]), alloc_graph_input=True) |
| 1140 | + |
| 1141 | + # TODO: We can handle delegates the same way as map/cond/while. |
| 1142 | + # Maybe populate the graph_module.meta["non_const_buffer_sizes"] for delegates. |
| 1143 | + |
| 1144 | + return bufsizes |
| 1145 | + |
| 1146 | + |
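
The arg indices mirror the higher-order-op calling conventions: `cond(pred, true_fn, false_fn, operands)` keeps its branches at `args[1]`/`args[2]`, `while_loop(cond_fn, body_fn, ...)` keeps its submodules at `args[0]`/`args[1]`, and `map(fn, xs, ...)` keeps its body at `args[0]`. Merging with the elementwise max is presumably sound because these submodules never run concurrently, so one reservation sized to the largest requirement covers all of them. A toy check with invented sizes:

```python
true_fn_sizes = [0, 2048, 128]
false_fn_sizes = [0, 1024, 512]
reserved = [max(a, b) for a, b in zip(true_fn_sizes, false_fn_sizes)]
assert reserved == [0, 2048, 512]  # either branch fits in the reserved region
```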
1054 | 1147 | def apply_algo( |
1055 | | - algo: Callable[ |
1056 | | - ..., |
1057 | | - List[int], |
1058 | | - ], |
| 1148 | + algo: Callable[..., list[int]], |
1059 | 1149 | graph_module: torch.fx.GraphModule, |
1060 | 1150 | alignment: int, |
1061 | 1151 | graph_signature: Optional[ExportGraphSignature] = None, |
1062 | 1152 | alloc_graph_input: bool = True, |
1063 | 1153 | alloc_graph_output: bool = True, |
1064 | 1154 | alloc_mutable_buffers: bool = True, |
1065 | | -) -> List[int]: |
| 1155 | +) -> list[int]: |
1066 | 1156 | """ |
1067 | 1157 | Recursively apply algo to graph_module and its submodules for control flow. |
1068 | 1158 |
1069 | | - Quite naively right now since it does not take the following optimizations |
1070 | | - into considerating: |
1071 | | - 1. for conditional structure, true branch and false true does not overlap |
1072 | | - in lifetime and can share tensor storage |
1073 | | - 2. tensors inside a submodule (e.g. true branch) has opportunities to share |
1074 | | - storage with tensors in the outer module. |
1075 | | - TODO: make these optimizations once we have some baseline working. |
| 1159 | + An algo implementation should handle one of two meta entries for submodules: |
| 1160 | + 1. input_mem_buffer_sizes: List of int offset bytes. Memory allocated by |
| 1161 | + `algo` should start at the offset specified by this list; |
| 1162 | + OR |
| 1163 | + 2. non_const_buffer_sizes: List of bufsizes for planned memory in submodule. |
| 1164 | + `algo` should reserve the space specified by this list for the lifetime |
| 1165 | + of the submodule node (e.g. cond, while, map). |
| 1166 | +
| 1167 | + TODO: Missing optimizations: |
| 1168 | + 1. To handle maps, we set `alloc_graph_input=True`, which allocates |
| 1169 | + appropriate space for mapped arg but ends up allocating extra space for |
| 1170 | + `operand` arg. The memory for operands is unused. |
1076 | 1171 | """ |
1077 | 1172 | # Extract the nodes and their lifespans from the graph_module |
1078 | 1173 | # Difficult to just filter the list of specs returned by this due to |
1079 | 1174 | # how we flag trainable weights. |
1080 | 1175 | _ = update_all_tensors_lifetime(graph_module, graph_signature) |
| 1176 | + |
1081 | 1177 | # Filter specs based on alloc_graph_input and alloc_graph_output |
1082 | | - specs = collect_specs_from_nodes( |
1083 | | - graph_module.graph.nodes, |
1084 | | - graph_signature, |
1085 | | - do_assertion=False, |
1086 | | - ignore_graph_input=not alloc_graph_input, |
1087 | | - ignore_graph_output=not alloc_graph_output, |
1088 | | - ignore_mutable_buffers=not alloc_mutable_buffers, |
| 1178 | + specs = set( |
| 1179 | + collect_specs_from_nodes( |
| 1180 | + graph_module.graph.nodes, |
| 1181 | + graph_signature, |
| 1182 | + do_assertion=False, |
| 1183 | + ignore_graph_input=not alloc_graph_input, |
| 1184 | + ignore_graph_output=not alloc_graph_output, |
| 1185 | + ignore_mutable_buffers=not alloc_mutable_buffers, |
| 1186 | + ) |
1089 | 1187 | ) |
1090 | 1188 |
| 1189 | + # Get temporary specs for submodules to set aside space during execution |
| 1190 | + # of submodules. |
| 1191 | + submodule_bufsizes = _apply_algo_to_submodules(algo, graph_module, alignment, graph_signature) |
| 1192 | + |
| 1193 | + # Update `input_mem_buffer_sizes` in graph_module.meta. This will allow |
| 1194 | + # existing algos to work using `input_mem_buffer_sizes` or use |
| 1195 | + # `non_const_buffer_sizes` directly. |
| 1196 | + graph_module.meta.update({"input_mem_buffer_sizes": submodule_bufsizes}) |
| 1197 | + |
1091 | 1198 | # Get extra padding for XNNPACK if needed |
1092 | 1199 | extra_padding = 0 |
1093 | 1200 | if _contains_xnnpack_delegate(graph_module): |
1094 | 1201 | extra_padding = 64 |
1095 | 1202 |
1096 | 1203 | # Pass the filtered specs to the algorithm |
1097 | | - bufsizes: List[int] = algo( |
| 1204 | + bufsizes: list[int] = algo( |
1098 | 1205 | alignment, |
1099 | 1206 | specs, |
1100 | 1207 | graph_module, |
1101 | 1208 | graph_signature, |
1102 | 1209 | extra_padding, |
1103 | 1210 | ) |
1104 | 1211 |
1105 | | - insert_calls_to_free(graph_module, set(specs)) |
1106 | | - |
1107 | | - def handle_submodule( |
1108 | | - submodule_nd: torch.fx.Node, alloc_graph_input: bool = False |
1109 | | - ) -> None: |
1110 | | - nonlocal bufsizes |
1111 | | - assert submodule_nd.op == "get_attr" |
1112 | | - submodule = getattr(graph_module, submodule_nd.target) |
1113 | | - # memory planning for submodule need to be aware of the amount of |
1114 | | - # buffer already allocated. |
1115 | | - submodule.input_mem_buffer_sizes = bufsizes |
1116 | | - |
1117 | | - bufsizes = apply_algo( |
1118 | | - algo, |
1119 | | - submodule, |
1120 | | - alignment, |
1121 | | - graph_signature, |
1122 | | - alloc_graph_input=alloc_graph_input, |
1123 | | - alloc_graph_output=True, |
1124 | | - ) |
1125 | | - submodule.meta.update({"non_const_buffer_sizes": bufsizes}) |
1126 | | - |
1127 | | - for cond_node in get_cond_nodes(graph_module): |
1128 | | - handle_submodule(typing.cast(torch.fx.Node, cond_node.args[1])) |
1129 | | - handle_submodule(typing.cast(torch.fx.Node, cond_node.args[2])) |
1130 | | - |
1131 | | - for while_node in get_while_nodes(graph_module): |
1132 | | - handle_submodule(typing.cast(torch.fx.Node, while_node.args[0])) |
1133 | | - handle_submodule(typing.cast(torch.fx.Node, while_node.args[1])) |
1134 | | - # TODO: Add test coverage for map operator once dynamo tracing is |
1135 | | - # fully supported for this. T142287208 |
1136 | | - for map_node in get_map_nodes(graph_module): |
1137 | | - handle_submodule( |
1138 | | - typing.cast(torch.fx.Node, map_node.args[0]), alloc_graph_input=True |
1139 | | - ) |
| 1212 | + insert_calls_to_free(graph_module, specs) |
1140 | 1213 |
1141 | 1214 | graph_module.meta.update({"non_const_buffer_sizes": bufsizes}) |
1142 | 1215 | return bufsizes |
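
Putting it together, a hedged driver sketch. Only the positional planner signature `(alignment, specs, graph_module, graph_signature, extra_padding)` comes from the call site above; `my_algo`, the 16-byte alignment, and the in-scope `graph_module`/`graph_signature` are illustrative assumptions:

```python
from typing import List

def my_algo(alignment, specs, graph_module, graph_signature, extra_padding) -> List[int]:
    # Honor submodule reservations by starting each arena at the offsets
    # apply_algo stored in meta["input_mem_buffer_sizes"].
    bufsizes = list(graph_module.meta.get("input_mem_buffer_sizes", [0, 0]))
    ...  # assign spec.mem_id / spec.mem_offset and grow bufsizes per spec
    return bufsizes

bufsizes = apply_algo(my_algo, graph_module, 16, graph_signature)
assert graph_module.meta["non_const_buffer_sizes"] == bufsizes
```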