@@ -46,6 +46,7 @@ def get_aligned_offset(pre_aligned_offset: int, alignment: int) -> int:
 
 def collect_specs_from_graph_module(
     graph_module: torch.fx.GraphModule,
+    graph_signature: ExportGraphSignature,
     alloc_graph_input: bool,
     alloc_graph_output: bool,
 ) -> Iterable[TensorSpec]:
@@ -56,6 +57,7 @@ def collect_specs_from_graph_module(
     # Collect the specs from all the nodes in the graph module, and return it
     return collect_specs_from_nodes(
         graph_module.graph.nodes,
+        graph_signature,
         ignore_graph_input=not alloc_graph_input,
         ignore_graph_output=not alloc_graph_output,
     )
@@ -107,7 +109,7 @@ def memory_available(spec: TensorSpec) -> bool:
     # Iterate over all the specs in sorted order
     for spec in sorted(
         collect_specs_from_graph_module(
-            graph_module, alloc_graph_input, alloc_graph_output
+            graph_module, graph_signature, alloc_graph_input, alloc_graph_output
         ),
         key=lambda spec: spec.allocated_memory,
         reverse=True,
@@ -182,7 +184,7 @@ def greedy_by_size_for_offset_calculation_with_hierarchy(
     # Iterate over all the specs in sorted order
     for spec in sorted(
         collect_specs_from_graph_module(
-            graph_module, alloc_graph_input, alloc_graph_output
+            graph_module, graph_signature, alloc_graph_input, alloc_graph_output
        ),
        key=lambda spec: spec.allocated_memory,
        reverse=True,
@@ -250,6 +252,7 @@ def greedy_by_size_for_offset_calculation_with_hierarchy(
 
 def find_peak_memory_usages_per_memory(
     graph_module: torch.fx.GraphModule,
+    graph_signature: ExportGraphSignature,
     alloc_graph_input: bool,
     alloc_graph_output: bool,
     mem_constraints: Optional[MemConstraints] = None,
@@ -265,7 +268,7 @@ def find_peak_memory_usages_per_memory(
 
     # go through all nodes in the graph, collect memory usage per spec.mem_id
     for spec in collect_specs_from_graph_module(
-        graph_module, alloc_graph_input, alloc_graph_output
+        graph_module, graph_signature, alloc_graph_input, alloc_graph_output
     ):
         if mem_constraints is not None and mem_constraints.skipped_spec(spec):
             continue
@@ -288,6 +291,7 @@ def find_peak_memory_usages_per_memory(
 
 def find_peak_memory_usage(
     graph_module: torch.fx.GraphModule,
+    graph_signature: ExportGraphSignature,
     alloc_graph_input: bool,
     alloc_graph_output: bool,
     mem_constraints: Optional[MemConstraints] = None,
@@ -303,7 +307,7 @@ def find_peak_memory_usage(
 
     # Iterate over all the node specs
     for spec in collect_specs_from_graph_module(
-        graph_module, alloc_graph_input, alloc_graph_output
+        graph_module, graph_signature, alloc_graph_input, alloc_graph_output
     ):
         if spec.lifetime[0] is None or (
             mem_constraints is not None and mem_constraints.skipped_spec(spec)
@@ -358,6 +362,7 @@ def print_memory_planning_info(
     # Get the peak memory usages per memory space
     peak_memory_usages_per_memory = find_peak_memory_usages_per_memory(
         executorch_prog.exported_program().graph_module,
+        executorch_prog.exported_program().graph_signature,
         alloc_graph_input,
         alloc_graph_output,
         mem_constraints,
@@ -393,6 +398,7 @@ def print_memory_planning_info(
     # Get the total peak memory usage across all memory spaces
     total_peak_memory_usage = find_peak_memory_usage(
         executorch_prog.exported_program().graph_module,
+        executorch_prog.exported_program().graph_signature,
         alloc_graph_input,
         alloc_graph_output,
         mem_constraints,
@@ -453,7 +459,17 @@ def _init_mem_algos(self) -> None:
             greedy_by_size_for_offset_calculation_with_hierarchy,
         ]
 
-    def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
+    def __call__(
+        self,
+        graph_module: torch.fx.GraphModule,
+    ) -> PassResult:
+        return self.run(graph_module)
+
+    def run(
+        self,
+        graph_module: torch.fx.GraphModule,
+        graph_signature: Optional[ExportGraphSignature] = None,
+    ) -> PassResult:
         mem_constraints = MemConstraints(
             opt_level=self.opt_level,
             alloc_graph_input=self.alloc_graph_input,
@@ -475,6 +491,6 @@ def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
             alloc_graph_output=self.alloc_graph_output,
             alignment=self.mem_alignment,
         )
-        mem_planning(graph_module)
+        mem_planning.run(graph_module, graph_signature)
 
         return PassResult(graph_module, True)