 import itertools
 import logging
 import operator
-import typing
 from collections import defaultdict
 from dataclasses import dataclass, field
-from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
+from typing import (
+    Any,
+    Callable,
+    cast,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    Union,
+)

 import torch
 from executorch.exir import memory
@@ -960,7 +970,7 @@ def _allocate_buf(bufsizes: List[int], mem_id: int, allocated: int) -> int:
     bufsizes = getattr(graph_module, "input_mem_buffer_sizes", None)
     if bufsizes is None:
         bufsizes = [0, 0]
-    bufsizes = typing.cast(List[int], bufsizes)
+    bufsizes = cast(List[int], bufsizes)

     for spec in specs:
         spec_alloc_result = naive_result.spec_dict.get(spec, SpecAllocResult(0, 0, 0))
@@ -1062,33 +1072,119 @@ def insert_calls_to_free(
     graph_module.recompile()


+def _merge_bufsizes(bufsizes: list[int], new_bufsizes: list[int]) -> list[int]:
+    """Merge two buffer-size lists, taking the elementwise maximum."""
+    if len(bufsizes) < len(new_bufsizes):
+        bufsizes.extend([0] * (len(new_bufsizes) - len(bufsizes)))
+    for i in range(len(new_bufsizes)):
+        bufsizes[i] = max(bufsizes[i], new_bufsizes[i])
+    return bufsizes
+
+
+def _handle_submodule(
+    algo: Callable[..., list[int]],
+    parent_graph_module: torch.fx.GraphModule,
+    alignment: int,
+    submodule_node: torch.fx.Node,
+    graph_signature: Optional[ExportGraphSignature] = None,
+    alloc_graph_input: bool = False,
+) -> list[int]:
+    """Apply algo to nodes in a submodule of the graph module."""
+    assert submodule_node.op == "get_attr"
+    submodule = getattr(parent_graph_module, submodule_node.target)
+
+    logging.debug(f"Planning memory for submodule {submodule_node.name}...")
+    bufsizes = apply_algo(
+        algo,
+        submodule,
+        alignment,
+        graph_signature,
+        alloc_graph_input=alloc_graph_input,
+        alloc_graph_output=True,
+    )
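+    # Record the planned sizes on the submodule; the parent planner can read
+    # `non_const_buffer_sizes` to reserve this space for the lifetime of the
+    # submodule node.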
+    submodule.meta.update({"non_const_buffer_sizes": bufsizes})
+    logging.debug(f"Buffer sizes for submodule {submodule_node.name}: {bufsizes}")
+    return bufsizes
+
+
+def _apply_algo_to_submodules(
+    algo: Callable[..., list[int]],
+    graph_module: torch.fx.GraphModule,
+    alignment: int,
+    graph_signature: Optional[ExportGraphSignature] = None,
+) -> list[int]:
+    """Apply algo to map/cond/while nodes in the graph module.
+
+    This method populates meta["non_const_buffer_sizes"] on every control-flow
+    submodule and returns a bufsizes list that is the elementwise maximum across
+    all submodules.
+    """
+
+    # Bufsizes for submodules.
+    bufsizes: list[int] = []
+
+    def _handle(
+        submodule_node: torch.fx.Node,
+        alloc_graph_input: bool = False,
+    ) -> None:
+        current_bufsizes = _handle_submodule(
+            algo,
+            graph_module,
+            alignment,
+            submodule_node,
+            graph_signature,
+            alloc_graph_input=alloc_graph_input,
+        )
+        nonlocal bufsizes
+        _merge_bufsizes(bufsizes, current_bufsizes)
+
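+    # For `cond`, args[1] and args[2] are the true- and false-branch submodules.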
+    for cond_node in get_cond_nodes(graph_module):
+        _handle(cast(torch.fx.Node, cond_node.args[1]))
+        _handle(cast(torch.fx.Node, cond_node.args[2]))
+
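+    # For `while`, args[0] is the loop-condition submodule and args[1] is the loop body.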
+    for while_node in get_while_nodes(graph_module):
+        _handle(cast(torch.fx.Node, while_node.args[0]))
+        _handle(cast(torch.fx.Node, while_node.args[1]))
+
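+    # For `map`, args[0] is the mapped submodule; its graph inputs are planned as
+    # well (alloc_graph_input=True) so space is allocated for the mapped arg (see
+    # the TODO in `apply_algo` about the unused `operand` space).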
+    for map_node in get_map_nodes(graph_module):
+        _handle(cast(torch.fx.Node, map_node.args[0]), alloc_graph_input=True)
+
+    # TODO: We can handle delegates the same way as map/cond/while.
+    # Maybe populate the graph_module.meta["non_const_buffer_sizes"] for delegates.
+
+    return bufsizes
+
+
 def apply_algo(
-    algo: Callable[
-        ...,
-        List[int],
-    ],
+    algo: Callable[..., list[int]],
     graph_module: torch.fx.GraphModule,
     alignment: int,
     graph_signature: Optional[ExportGraphSignature] = None,
     alloc_graph_input: bool = True,
     alloc_graph_output: bool = True,
     alloc_mutable_buffers: bool = True,
-) -> List[int]:
+) -> list[int]:
     """
     Recursively apply algo to graph_module and its submodules for control flow.

-    Quite naively right now since it does not take the following optimizations
-    into considerating:
-    1. for conditional structure, true branch and false true does not overlap
-    in lifetime and can share tensor storage
-    2. tensors inside a submodule (e.g. true branch) has opportunities to share
-    storage with tensors in the outer module.
-    TODO: make these optimizations once we have some baseline working.
+    The algo implementation should handle one of two meta entries for submodules:
+    1. input_mem_buffer_sizes: List of int offset bytes. Memory allocated by
+       `algo` should start at the offsets specified by this list;
+    OR
+    2. non_const_buffer_sizes: List of bufsizes for planned memory in the
+       submodule. `algo` should reserve the space specified by this list for the
+       lifetime of the submodule node (e.g. cond, while, map).
+
+    TODO: Missing optimizations:
+    1. To handle maps, we set `alloc_graph_input=True`, which allocates
+       appropriate space for the mapped arg but ends up allocating extra space
+       for the `operand` arg. The memory for operands is unused.
     """
     # Extract the nodes and their lifespans from the graph_module
     # Difficult to just filter the list of specs returned by this due to
     # how we flag trainable weights.
     _ = update_all_tensors_lifetime(graph_module, graph_signature)
+
     # Filter specs based on alloc_graph_input and alloc_graph_output
     specs = collect_specs_from_nodes(
         graph_module.graph.nodes,
@@ -1099,55 +1195,35 @@ def apply_algo(
         ignore_mutable_buffers=not alloc_mutable_buffers,
     )

+    # Plan memory for the submodules first so space can be set aside for their
+    # execution.
+    submodule_bufsizes = _apply_algo_to_submodules(
+        algo, graph_module, alignment, graph_signature
+    )
+
+    # Update `input_mem_buffer_sizes` in graph_module. This lets existing algos
+    # keep working via `input_mem_buffer_sizes`, or use `non_const_buffer_sizes`
+    # directly.
+    # pyre-ignore[16]: `torch.fx.GraphModule` has no attribute `input_mem_buffer_sizes`.
+    graph_module.input_mem_buffer_sizes = submodule_bufsizes
+
     # Get extra padding for XNNPACK if needed
     extra_padding = 0
     if _contains_xnnpack_delegate(graph_module):
         extra_padding = 64

     # Pass the filtered specs to the algorithm
-    bufsizes: List[int] = algo(
+    bufsizes: list[int] = algo(
         alignment,
         specs,
         graph_module,
         graph_signature,
         extra_padding,
     )

-    insert_calls_to_free(graph_module, set(specs))
-
-    def handle_submodule(
-        submodule_nd: torch.fx.Node, alloc_graph_input: bool = False
-    ) -> None:
-        nonlocal bufsizes
-        assert submodule_nd.op == "get_attr"
-        submodule = getattr(graph_module, submodule_nd.target)
-        # memory planning for submodule need to be aware of the amount of
-        # buffer already allocated.
-        submodule.input_mem_buffer_sizes = bufsizes
-
-        bufsizes = apply_algo(
-            algo,
-            submodule,
-            alignment,
-            graph_signature,
-            alloc_graph_input=alloc_graph_input,
-            alloc_graph_output=True,
-        )
-        submodule.meta.update({"non_const_buffer_sizes": bufsizes})
-
-    for cond_node in get_cond_nodes(graph_module):
-        handle_submodule(typing.cast(torch.fx.Node, cond_node.args[1]))
-        handle_submodule(typing.cast(torch.fx.Node, cond_node.args[2]))
-
-    for while_node in get_while_nodes(graph_module):
-        handle_submodule(typing.cast(torch.fx.Node, while_node.args[0]))
-        handle_submodule(typing.cast(torch.fx.Node, while_node.args[1]))
-    # TODO: Add test coverage for map operator once dynamo tracing is
-    # fully supported for this. T142287208
-    for map_node in get_map_nodes(graph_module):
-        handle_submodule(
-            typing.cast(torch.fx.Node, map_node.args[0]), alloc_graph_input=True
-        )
+    # pyre-ignore[6]: Incompatible parameter type [6]
+    # In call `insert_calls_to_free`, for 2nd positional argument, expected `Set[TensorSpec]` but got `Iterable[TensorSpec]`
+    insert_calls_to_free(graph_module, specs)

     graph_module.meta.update({"non_const_buffer_sizes": bufsizes})
     return bufsizes