import itertools
import logging
import operator
- import typing
from collections import defaultdict
from dataclasses import dataclass, field
- from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
+ from typing import (
+     Any,
+     Callable,
+     cast,
+     Dict,
+     Iterable,
+     List,
+     Optional,
+     Set,
+     Tuple,
+     Union,
+ )

import torch
from executorch.exir import memory
@@ -960,7 +970,7 @@ def _allocate_buf(bufsizes: List[int], mem_id: int, allocated: int) -> int:
    bufsizes = getattr(graph_module, "input_mem_buffer_sizes", None)
    if bufsizes is None:
        bufsizes = [0, 0]
-     bufsizes = typing.cast(List[int], bufsizes)
+     bufsizes = cast(List[int], bufsizes)

    for spec in specs:
        spec_alloc_result = naive_result.spec_dict.get(spec, SpecAllocResult(0, 0, 0))
@@ -1062,33 +1072,119 @@ def insert_calls_to_free(
    graph_module.recompile()


+ def _merge_bufsizes(bufsizes: list[int], new_bufsizes: list[int]) -> list[int]:
+     """Combine two buffer size lists."""
+     if len(bufsizes) < len(new_bufsizes):
+         bufsizes.extend([0] * (len(new_bufsizes) - len(bufsizes)))
+     for i in range(len(new_bufsizes)):
+         bufsizes[i] = max(bufsizes[i], new_bufsizes[i])
+     return bufsizes
+
+
+ def _handle_submodule(
+     algo: Callable[..., list[int]],
+     parent_graph_module: torch.fx.GraphModule,
+     alignment: int,
+     submodule_node: torch.fx.Node,
+     graph_signature: Optional[ExportGraphSignature] = None,
+     alloc_graph_input: bool = False,
+ ) -> list[int]:
+     """Apply algo to nodes in a submodule of the graph module."""
+     assert submodule_node.op == "get_attr"
+     submodule = getattr(parent_graph_module, submodule_node.target)
+
+     logging.debug(f"Planning memory for submodule {submodule_node.name}...")
+     bufsizes = apply_algo(
+         algo,
+         submodule,
+         alignment,
+         graph_signature,
+         alloc_graph_input=alloc_graph_input,
+         alloc_graph_output=True,
+     )
+     submodule.meta.update({"non_const_buffer_sizes": bufsizes})
+     logging.debug(f"Buffer sizes for submodule {submodule_node.name}: {bufsizes}")
+     return bufsizes
+
+
+ def _apply_algo_to_submodules(
+     algo: Callable[..., list[int]],
+     graph_module: torch.fx.GraphModule,
+     alignment: int,
+     graph_signature: Optional[ExportGraphSignature] = None,
+ ) -> list[int]:
+     """Apply algo to map/cond/while nodes in the graph module.
+
+     This method will populate graph_module.meta["non_const_buffer_sizes"] for
+     all submodules and return a bufsizes list that is the maximum size of all
+     buffers.
+     """
+
+     # Bufsizes for submodules.
+     bufsizes: list[int] = []
+
+     def _handle(
+         submodule_node: torch.fx.Node,
+         alloc_graph_input: bool = False,
+     ) -> None:
+         current_bufsizes = _handle_submodule(
+             algo,
+             graph_module,
+             alignment,
+             submodule_node,
+             graph_signature,
+             alloc_graph_input=alloc_graph_input,
+         )
+         nonlocal bufsizes
+         _merge_bufsizes(bufsizes, current_bufsizes)
+
+     for cond_node in get_cond_nodes(graph_module):
+         _handle(cast(torch.fx.Node, cond_node.args[1]))
+         _handle(cast(torch.fx.Node, cond_node.args[2]))
+
+     for while_node in get_while_nodes(graph_module):
+         _handle(cast(torch.fx.Node, while_node.args[0]))
+         _handle(cast(torch.fx.Node, while_node.args[1]))
+
+     for map_node in get_map_nodes(graph_module):
+         _handle(cast(torch.fx.Node, map_node.args[0]), alloc_graph_input=True)
+
+     # TODO: We can handle delegates the same way as map/cond/while.
+     # Maybe populate the graph_module.meta["non_const_buffer_sizes"] for delegates.
+
+     return bufsizes
+
+
def apply_algo(
-     algo: Callable[
-         ...,
-         List[int],
-     ],
+     algo: Callable[..., list[int]],
    graph_module: torch.fx.GraphModule,
    alignment: int,
    graph_signature: Optional[ExportGraphSignature] = None,
    alloc_graph_input: bool = True,
    alloc_graph_output: bool = True,
    alloc_mutable_buffers: bool = True,
- ) -> List[int]:
+ ) -> list[int]:
    """
    Recursively apply algo to graph_module and its submodules for control flow.

-     Quite naively right now since it does not take the following optimizations
-     into considerating:
-     1. for conditional structure, true branch and false true does not overlap
-     in lifetime and can share tensor storage
-     2. tensors inside a submodule (e.g. true branch) has opportunities to share
-     storage with tensors in the outer module.
-     TODO: make these optimizations once we have some baseline working.
+     Algo implementation should handle one of two meta entries for submodules:
+     1. input_mem_buffer_sizes: List of int offset bytes. Memory allocated by
+        `algo` should start at the offset specified by this list;
+     OR
+     2. non_const_buffer_sizes: List of bufsizes for planned memory in submodule.
+        `algo` should reserve the space specified by this list for the lifetime
+        of the submodule node (e.g. cond, while, map).
+
+     TODO: Missing optimizations:
+     1. To handle maps, we set `alloc_graph_input=True`, which allocates
+        appropriate space for mapped arg but ends up allocating extra space for
+        `operand` arg. The memory for operands is unused.
    """
    # Extract the nodes and their lifespans from the graph_module
    # Difficult to just filter the list of specs returned by this due to
    # how we flag trainable weights.
    _ = update_all_tensors_lifetime(graph_module, graph_signature)
+
    # Filter specs based on alloc_graph_input and alloc_graph_output
    specs = collect_specs_from_nodes(
        graph_module.graph.nodes,
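
A minimal standalone sketch of the merge contract above (made-up branch sizes, plain Python rather than part of the patch): the shorter list is zero-extended and the per-mem_id arenas are combined by elementwise max, which is what `_merge_bufsizes` computes and what `_apply_algo_to_submodules` returns for the parent graph.

# Hypothetical bufsizes reported by the two branches of a cond submodule.
true_branch_sizes = [0, 1024, 256]
false_branch_sizes = [0, 512, 512, 128]

# Mirrors _merge_bufsizes: zero-extend the shorter list, then take the
# elementwise max so the parent reserves enough space for either branch.
merged = list(true_branch_sizes)
if len(merged) < len(false_branch_sizes):
    merged.extend([0] * (len(false_branch_sizes) - len(merged)))
for i in range(len(false_branch_sizes)):
    merged[i] = max(merged[i], false_branch_sizes[i])

assert merged == [0, 1024, 512, 128]

The merged list is what `apply_algo` later stores on the parent as `input_mem_buffer_sizes`, so the parent's planner can start its arenas above the space the submodules need.
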
@@ -1099,55 +1195,35 @@ def apply_algo(
        ignore_mutable_buffers=not alloc_mutable_buffers,
    )

+     # Get temporary specs for submodules to set aside space during execution
+     # of submodules.
+     submodule_bufsizes = _apply_algo_to_submodules(
+         algo, graph_module, alignment, graph_signature
+     )
+
+     # Update `input_mem_buffer_sizes` in graph_module. This will allow existing
+     # algos to work using `input_mem_buffer_sizes` or use
+     # `non_const_buffer_sizes` directly.
+     # pyre-ignore[16]: `torch.fx.GraphModule` has no attribute `input_mem_buffer_sizes`.
+     graph_module.input_mem_buffer_sizes = submodule_bufsizes
+
    # Get extra padding for XNNPACK if needed
    extra_padding = 0
    if _contains_xnnpack_delegate(graph_module):
        extra_padding = 64

    # Pass the filtered specs to the algorithm
-     bufsizes: List[int] = algo(
+     bufsizes: list[int] = algo(
        alignment,
        specs,
        graph_module,
        graph_signature,
        extra_padding,
    )

-     insert_calls_to_free(graph_module, set(specs))
-
-     def handle_submodule(
-         submodule_nd: torch.fx.Node, alloc_graph_input: bool = False
-     ) -> None:
-         nonlocal bufsizes
-         assert submodule_nd.op == "get_attr"
-         submodule = getattr(graph_module, submodule_nd.target)
-         # memory planning for submodule need to be aware of the amount of
-         # buffer already allocated.
-         submodule.input_mem_buffer_sizes = bufsizes
-
-         bufsizes = apply_algo(
-             algo,
-             submodule,
-             alignment,
-             graph_signature,
-             alloc_graph_input=alloc_graph_input,
-             alloc_graph_output=True,
-         )
-         submodule.meta.update({"non_const_buffer_sizes": bufsizes})
-
-     for cond_node in get_cond_nodes(graph_module):
-         handle_submodule(typing.cast(torch.fx.Node, cond_node.args[1]))
-         handle_submodule(typing.cast(torch.fx.Node, cond_node.args[2]))
-
-     for while_node in get_while_nodes(graph_module):
-         handle_submodule(typing.cast(torch.fx.Node, while_node.args[0]))
-         handle_submodule(typing.cast(torch.fx.Node, while_node.args[1]))
-     # TODO: Add test coverage for map operator once dynamo tracing is
-     # fully supported for this. T142287208
-     for map_node in get_map_nodes(graph_module):
-         handle_submodule(
-             typing.cast(torch.fx.Node, map_node.args[0]), alloc_graph_input=True
-         )
+     # pyre-ignore[6]: Incompatible parameter type [6]
+     # In call `insert_calls_to_free`, for 2nd positional argument, expected `Set[TensorSpec]` but got `Iterable[TensorSpec]`
+     insert_calls_to_free(graph_module, specs)

    graph_module.meta.update({"non_const_buffer_sizes": bufsizes})
    return bufsizes
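
To make the `input_mem_buffer_sizes` contract from the new `apply_algo` docstring concrete, a standalone illustrative sketch of an `algo` callable follows, using the positional signature `apply_algo` passes (`alignment, specs, graph_module, graph_signature, extra_padding`). It is a toy bump allocator, not the actual planning algorithms used in this file, and the `TensorSpec` fields it touches (`mem_id`, `mem_offset`, `allocated_memory`) are assumptions following executorch conventions rather than anything defined in this diff.

from typing import List


def toy_bump_algo(
    alignment: int,
    specs,            # iterable of TensorSpec-like objects (assumed fields noted below)
    graph_module,     # torch.fx.GraphModule
    graph_signature=None,  # Optional[ExportGraphSignature], unused in this sketch
    extra_padding: int = 0,
) -> List[int]:
    # Start every arena at the space already reserved for submodules;
    # apply_algo stores that list as graph_module.input_mem_buffer_sizes.
    bufsizes: List[int] = list(
        getattr(graph_module, "input_mem_buffer_sizes", None) or [0, 0]
    )
    for spec in specs:
        mem_id = spec.mem_id if spec.mem_id is not None else 1  # assumed field
        while len(bufsizes) <= mem_id:
            bufsizes.append(0)
        # Round the arena top up to `alignment`, place the tensor there,
        # then bump the arena top past the tensor plus any extra padding.
        offset = -(-bufsizes[mem_id] // alignment) * alignment
        spec.mem_offset = offset  # assumed field
        bufsizes[mem_id] = offset + spec.allocated_memory + extra_padding  # assumed field
    return bufsizes

Passed as `algo` to `apply_algo`, a callable like this sees the submodule reservations through `input_mem_buffer_sizes`; alternatively it could read `non_const_buffer_sizes` from submodule meta directly, the second option the docstring describes.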