Commit 6b72663

remove duplication
Differential Revision: D65236267
Pull Request resolved: #7186
1 parent 44e31fb

2 files changed (+42, -2 lines)


backends/cadence/aot/TARGETS

Lines changed: 34 additions & 0 deletions
@@ -374,3 +374,37 @@ python_unittest(
         "//executorch/exir/dialects:lib",
     ],
 )
+
+
+python_library(
+    name = "memory_planning",
+    srcs = [
+        "memory_planning.py",
+    ],
+    deps = [
+        "fbsource//third-party/pypi/tabulate:tabulate",
+        ":memory_constraints",
+        ":pass_utils",
+        "//caffe2:torch",
+        "//executorch/exir:lib",
+        "//executorch/exir:memory_planning",
+        "//executorch/exir:tensor",
+        "//executorch/exir/passes:lib",
+    ],
+)
+
+
+python_library(
+    name = "memory_constraints",
+    srcs = [
+        "memory_constraints.py",
+    ],
+    deps = [
+        ":pass_utils",
+        ":utils",
+        "//caffe2:torch",
+        "//executorch/exir:memory",
+        "//executorch/exir:pass_manager",
+        "//executorch/exir:tensor",
+    ],
+)
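
Presumably this is the duplication the commit title refers to: memory_planning.py and memory_constraints.py now sit behind named Buck targets, so downstream rules can share one dependency declaration instead of each re-listing the sources and transitive deps. A hedged sketch of a consumer in the same TARGETS file (the target name and srcs below are hypothetical, not part of this commit):

python_library(
    name = "example_consumer",  # hypothetical target, for illustration only
    srcs = ["example_consumer.py"],  # hypothetical source file
    deps = [
        # Same-package labels, matching the style used elsewhere in this file:
        ":memory_constraints",
        ":memory_planning",
    ],
)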

backends/cadence/aot/memory_planning.py

Lines changed: 8 additions & 2 deletions
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-unsafe
+
 import collections
 import itertools
 import logging
@@ -331,14 +333,15 @@ def find_peak_memory_usage(
 # | Peak memory usage across all spaces | 2380032 bytes | Node 86 |
 # +-------------------------------------+---------------+---------+
 def print_memory_planning_info(
-    # pyre-fixme[11]: Annotation `ExecutorchProgramManager` is not defined as a type.
     executorch_prog: ExecutorchProgramManager,
     memory_config: MemoryConfig,
+    opt_level: int,
     alloc_graph_input: bool,
     alloc_graph_output: bool,
 ) -> None:
     # Get the peak memory usages per memory space
     mem_constraints = MemConstraints(
+        opt_level=opt_level,
         alloc_graph_input=alloc_graph_input,
         alloc_graph_output=alloc_graph_output,
     )
@@ -406,6 +409,7 @@ class CadenceMemoryPlanning:
     def __init__(
         self,
         memory_config: MemoryConfig,
+        opt_level: int,
         mem_algo: int,
         alloc_graph_input: bool = True,
         alloc_graph_output: bool = True,
@@ -421,6 +425,7 @@ def __init__(
         self._init_mem_algos()
 
         self.memory_config = memory_config
+        self.opt_level = opt_level
         self.mem_algo = mem_algo
         self.alloc_graph_input = alloc_graph_input
         self.alloc_graph_output = alloc_graph_output
@@ -434,6 +439,7 @@ def _init_mem_algos(self) -> None:
 
     def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
         mem_constraints = MemConstraints(
+            opt_level=self.opt_level,
             alloc_graph_input=self.alloc_graph_input,
             alloc_graph_output=self.alloc_graph_output,
         )
@@ -448,7 +454,7 @@ def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
         # True.
         mem_planning = MemoryPlanningPass(
             algo,
-            allow_lifetime_and_storage_overlap=False,
+            allow_lifetime_and_storage_overlap=(self.opt_level >= 2),
             alloc_graph_input=self.alloc_graph_input,
             alloc_graph_output=self.alloc_graph_output,
         )
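
Net effect of this file's changes: opt_level is threaded from CadenceMemoryPlanning through MemConstraints and into MemoryPlanningPass, so lifetime/storage overlap is only permitted at opt level 2 and above. A minimal usage sketch, assuming memory_config, executorch_prog, and graph_module already exist (only the signatures come from this diff; the argument values are placeholders):

# Hypothetical call sites; values are placeholders, signatures are from the diff.
planning = CadenceMemoryPlanning(
    memory_config,
    opt_level=2,  # >= 2 turns on allow_lifetime_and_storage_overlap
    mem_algo=0,   # assumed to select one of the algorithms registered by _init_mem_algos()
)
result = planning(graph_module)  # torch.fx.GraphModule in, PassResult out

print_memory_planning_info(
    executorch_prog,  # ExecutorchProgramManager, assumed to come from an earlier lowering step
    memory_config,
    opt_level=2,
    alloc_graph_input=True,
    alloc_graph_output=True,
)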
