44# This source code is licensed under the BSD-style license found in the
55# LICENSE file in the root directory of this source tree.
66
7+ # pyre-unsafe
8+
79import collections
810import itertools
911import logging
@@ -331,14 +333,15 @@ def find_peak_memory_usage(
331333# | Peak memory usage across all spaces | 2380032 bytes | Node 86 |
332334# +-------------------------------------+---------------+---------+
333335def print_memory_planning_info (
334- # pyre-fixme[11]: Annotation `ExecutorchProgramManager` is not defined as a type.
335336 executorch_prog : ExecutorchProgramManager ,
336337 memory_config : MemoryConfig ,
338+ opt_level : int ,
337339 alloc_graph_input : bool ,
338340 alloc_graph_output : bool ,
339341) -> None :
340342 # Get the peak memory usages per memory space
341343 mem_constraints = MemConstraints (
344+ opt_level = opt_level ,
342345 alloc_graph_input = alloc_graph_input ,
343346 alloc_graph_output = alloc_graph_output ,
344347 )
@@ -406,6 +409,7 @@ class CadenceMemoryPlanning:
406409 def __init__ (
407410 self ,
408411 memory_config : MemoryConfig ,
412+ opt_level : int ,
409413 mem_algo : int ,
410414 alloc_graph_input : bool = True ,
411415 alloc_graph_output : bool = True ,
@@ -421,6 +425,7 @@ def __init__(
421425 self ._init_mem_algos ()
422426
423427 self .memory_config = memory_config
428+ self .opt_level = opt_level
424429 self .mem_algo = mem_algo
425430 self .alloc_graph_input = alloc_graph_input
426431 self .alloc_graph_output = alloc_graph_output
@@ -434,6 +439,7 @@ def _init_mem_algos(self) -> None:
434439
435440 def __call__ (self , graph_module : torch .fx .GraphModule ) -> PassResult :
436441 mem_constraints = MemConstraints (
442+ opt_level = self .opt_level ,
437443 alloc_graph_input = self .alloc_graph_input ,
438444 alloc_graph_output = self .alloc_graph_output ,
439445 )
@@ -448,7 +454,7 @@ def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
448454 # True.
449455 mem_planning = MemoryPlanningPass (
450456 algo ,
451- allow_lifetime_and_storage_overlap = False ,
457+ allow_lifetime_and_storage_overlap = ( self . opt_level >= 2 ) ,
452458 alloc_graph_input = self .alloc_graph_input ,
453459 alloc_graph_output = self .alloc_graph_output ,
454460 )
0 commit comments