44# This source code is licensed under the BSD-style license found in the 
55# LICENSE file in the root directory of this source tree. 
66
7+ # pyre-unsafe 
8+ 
79import  collections 
810import  itertools 
911import  logging 
@@ -331,14 +333,15 @@ def find_peak_memory_usage(
331333# | Peak memory usage across all spaces | 2380032 bytes | Node 86 | 
332334# +-------------------------------------+---------------+---------+ 
333335def  print_memory_planning_info (
334-     # pyre-fixme[11]: Annotation `ExecutorchProgramManager` is not defined as a type. 
335336    executorch_prog : ExecutorchProgramManager ,
336337    memory_config : MemoryConfig ,
338+     opt_level : int ,
337339    alloc_graph_input : bool ,
338340    alloc_graph_output : bool ,
339341) ->  None :
340342    # Get the peak memory usages per memory space 
341343    mem_constraints  =  MemConstraints (
344+         opt_level = opt_level ,
342345        alloc_graph_input = alloc_graph_input ,
343346        alloc_graph_output = alloc_graph_output ,
344347    )
@@ -406,6 +409,7 @@ class CadenceMemoryPlanning:
406409    def  __init__ (
407410        self ,
408411        memory_config : MemoryConfig ,
412+         opt_level : int ,
409413        mem_algo : int ,
410414        alloc_graph_input : bool  =  True ,
411415        alloc_graph_output : bool  =  True ,
@@ -421,6 +425,7 @@ def __init__(
421425        self ._init_mem_algos ()
422426
423427        self .memory_config  =  memory_config 
428+         self .opt_level  =  opt_level 
424429        self .mem_algo  =  mem_algo 
425430        self .alloc_graph_input  =  alloc_graph_input 
426431        self .alloc_graph_output  =  alloc_graph_output 
@@ -434,6 +439,7 @@ def _init_mem_algos(self) -> None:
434439
435440    def  __call__ (self , graph_module : torch .fx .GraphModule ) ->  PassResult :
436441        mem_constraints  =  MemConstraints (
442+             opt_level = self .opt_level ,
437443            alloc_graph_input = self .alloc_graph_input ,
438444            alloc_graph_output = self .alloc_graph_output ,
439445        )
@@ -448,7 +454,7 @@ def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
448454        # True. 
449455        mem_planning  =  MemoryPlanningPass (
450456            algo ,
451-             allow_lifetime_and_storage_overlap = False ,
457+             allow_lifetime_and_storage_overlap = ( self . opt_level   >=   2 ) ,
452458            alloc_graph_input = self .alloc_graph_input ,
453459            alloc_graph_output = self .alloc_graph_output ,
454460        )
0 commit comments