
Commit 8eed6f3

tarun292 authored and facebook-github-bot committed
Memory planning updates (#8440)
Summary: Pull Request resolved: #8440

Differential Revision: D69515056
1 parent c8311e6 commit 8eed6f3

5 files changed: +71, -28 lines

backends/cadence/aot/memory_planning.py

Lines changed: 1 addition & 1 deletion
@@ -469,7 +469,7 @@ def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
         # (output) tensors if alloc_graph_input (alloc_graph_output) is
         # True.
         mem_planning = MemoryPlanningPass(
-            algo,
+            [algo],
             allow_lifetime_and_storage_overlap=(self.opt_level >= 2),
             alloc_graph_input=self.alloc_graph_input,
             alloc_graph_output=self.alloc_graph_output,
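
The change itself is mechanical: MemoryPlanningPass now takes a list of planning algorithms instead of a single callable. A minimal migration sketch, assuming the imports below match this revision of ExecuTorch:

from executorch.exir.memory_planning import greedy, naive
from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass

# Before this commit, MemoryPlanningPass(greedy, ...) took one algorithm.
# Now it takes a list of candidates; apply_algo runs each one and keeps
# the plan whose per-pool buffer sizes sum to the smallest total.
mem_planning = MemoryPlanningPass(
    [greedy, naive],
    alloc_graph_input=True,
    alloc_graph_output=True,
)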

backends/vulkan/vulkan_preprocess.py

Lines changed: 1 addition & 1 deletion
@@ -203,7 +203,7 @@ def preprocess(  # noqa: C901
         program,
         [
             ConstraintBasedSymShapeEvalPass(),
-            MemoryPlanningPass(memory_planning_algo=greedy_memory_planning),
+            MemoryPlanningPass(memory_planning_algo=[greedy_memory_planning]),
         ],
     )

exir/memory_planning.py

Lines changed: 61 additions & 20 deletions
@@ -6,6 +6,7 @@
 
 # pyre-strict
 
+import collections
 import itertools
 import logging
 import operator
@@ -503,6 +504,17 @@ class SharedObject:
     def __repr__(self) -> str:
         return f"SharedObject(idx={self.idx}, offset={self.offset}, size={self.size}, lifetime=[{self.first_used_index, self.last_used_index}])"
 
+@dataclass
+class SpecAllocResult:
+    mem_id: int
+    mem_obj_id: int
+    mem_offset: int
+
+@dataclass
+class MemoryAlgoResult:
+    spec_dict: Dict[TensorSpec, SpecAllocResult]
+    bufsizes: List[int]
+
 
 def materialize_buffer(
     shared_objects: List[SharedObject], input_total_size: int = 0
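
The two new dataclasses decouple planning from committing: an algorithm records its per-tensor decisions in spec_dict rather than writing mem_id, mem_obj_id, and mem_offset straight onto each TensorSpec, so several candidate plans can coexist before one is applied. A self-contained sketch of the data shape, with a hypothetical FakeSpec standing in for exir's TensorSpec:

from dataclasses import dataclass
from typing import Dict, List

@dataclass
class SpecAllocResult:
    mem_id: int      # which memory pool the tensor lives in
    mem_obj_id: int  # which shared object inside that pool
    mem_offset: int  # byte offset inside the shared object

@dataclass
class MemoryAlgoResult:
    spec_dict: Dict["FakeSpec", SpecAllocResult]  # per-tensor placements
    bufsizes: List[int]                           # total size of each pool

@dataclass(frozen=True)  # frozen, hence hashable and usable as a dict key
class FakeSpec:  # hypothetical stand-in for exir's TensorSpec
    name: str
    allocated_memory: int

spec = FakeSpec("activation_0", 64)
plan = MemoryAlgoResult({spec: SpecAllocResult(mem_id=1, mem_obj_id=0, mem_offset=0)}, [0, 64])
# The spec itself is untouched; the placement lives only in the plan.
assert plan.spec_dict[spec].mem_offset == 0
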
@@ -692,7 +704,7 @@ def greedy(
     alloc_graph_input: bool = True,
     alloc_graph_output: bool = True,
     allow_overlapping_allocations: bool = True,
-) -> List[int]:
+) -> MemoryAlgoResult:
     r"""Greedy algorithm to allocate memory for tensors in the graph.
     alloc_graph_input: If set to true, the algorithm will allocate memory for graph input.
     alloc_graph_output: If set to true, the algorithm will allocate memory for graph output.
@@ -701,6 +713,7 @@ def greedy(
     This flag is added to allow for Vulkan to use MemoryPlanningPass with overlapping
     allocations disabled
     """
+    greedy_result = MemoryAlgoResult({}, [])
     # padding allocation with 64 bytes.
     # this requirement is really for XNNPACK backend which can read tensors
     # beyond the end of the tensor. This is done for performance
@@ -735,12 +748,14 @@ def greedy(
     sorted_specs.reverse()
 
     for spec in sorted_specs:
+        spec_alloc_result = greedy_result.spec_dict.get(spec, SpecAllocResult(0, 0, 0))
         if spec.mem_id is None:
-            spec.mem_id = 1
+            spec_alloc_result.mem_id = 1
+        else:
+            spec_alloc_result.mem_id = spec.mem_id
+        greedy_result.spec_dict[spec] = spec_alloc_result
         spec.realign(alignment)
-        spec2obj[spec] = pick_shared_obj(
-            shared_objects[spec.mem_id], spec, allow_overlapping_allocations
-        )
+        spec2obj[spec] = pick_shared_obj(shared_objects[spec_alloc_result.mem_id], spec, allow_overlapping_allocations)
 
     if len(shared_objects) == 0:
         # Cannot find any tensor in the graph that needs to be allocated.
@@ -768,15 +783,18 @@ def greedy(
         for sobj in shared_objects[mem_id]:
             for alloc in sobj.allocations:
                 spec = alloc.spec
-                alloc.spec.mem_obj_id = sobj.idx
-                alloc.spec.mem_offset = sobj.offset + alloc.offset
+                spec_alloc_result = greedy_result.spec_dict.get(spec, None)
+                assert spec_alloc_result is not None, f"Spec {spec} not found."
+                spec_alloc_result.mem_obj_id = sobj.idx
+                spec_alloc_result.mem_offset = sobj.offset + alloc.offset
                 num_specs_processed += 1
     assert (
         len(spec2obj) == num_specs_processed
     ), f"All specs should be processed but there were {len(spec2obj)} specs and processed {num_specs_processed} specs"
 
     logging.debug(f"greedy algorithm returns bufsizes: {total_sizes}")
-    return total_sizes
+    greedy_result.bufsizes = total_sizes
+    return greedy_result
 
 
 def naive(
@@ -785,7 +803,9 @@ def naive(
     graph_signature: Optional[ExportGraphSignature] = None,
     alloc_graph_input: bool = True,
     alloc_graph_output: bool = True,
-) -> List[int]:
+) -> MemoryAlgoResult:
+
+    naive_result = MemoryAlgoResult({}, [])
 
     # allocate 'allocated' bytes from buffer with id mem_id.
     # return the starting offset of the allocated buffer.
@@ -807,16 +827,22 @@ def _allocate_buf(bufsizes: List[int], mem_id: int, allocated: int) -> int:
         ignore_graph_input=not alloc_graph_input,
         ignore_graph_output=not alloc_graph_output,
     ):
+        spec_alloc_result = naive_result.spec_dict.get(spec, SpecAllocResult(0, 0, 0))
         # assume a single memory layer which has mem_id 1
         if spec.mem_id is None:
-            spec.mem_id = 1
+            spec_alloc_result.mem_id = 1
+        else:
+            spec_alloc_result.mem_id = spec.mem_id
+        naive_result.spec_dict[spec] = spec_alloc_result
+
         # allocate spec.allocated_memory bytes in the buffer
         # with the corresponding mem_id
         spec.realign(alignment)
-        spec.mem_offset = _allocate_buf(bufsizes, spec.mem_id, spec.allocated_memory)
+        spec_alloc_result.mem_offset = _allocate_buf(bufsizes, spec.mem_id, spec.allocated_memory)
 
     logging.debug(f"naive algorithm returns bufsizes: {bufsizes}")
-    return bufsizes
+    naive_result.bufsizes = bufsizes
+    return naive_result
 
 
 def get_cond_nodes(graph_module: torch.fx.GraphModule) -> Iterable[Node]:
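
naive still drives a simple bump allocator: per the hunk header above, _allocate_buf(bufsizes, mem_id, allocated) grows pool mem_id by the tensor's size and returns the previous end of the pool as the offset. A sketch of that behavior in isolation (the body is a plausible reconstruction from the comments in the diff, not copied from the source):

from typing import List

def _allocate_buf(bufsizes: List[int], mem_id: int, allocated: int) -> int:
    # Grow the pool list on demand, then bump-allocate: the old end of
    # the pool becomes the new tensor's starting offset.
    if mem_id >= len(bufsizes):
        bufsizes.extend([0] * (mem_id - len(bufsizes) + 1))
    offset = bufsizes[mem_id]
    bufsizes[mem_id] += allocated
    return offset

bufsizes: List[int] = []
assert _allocate_buf(bufsizes, 1, 64) == 0   # first tensor at offset 0
assert _allocate_buf(bufsizes, 1, 32) == 64  # next tensor right behind it
assert bufsizes == [0, 96]                   # pool 0 unused, pool 1 holds 96 bytes
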
@@ -899,10 +925,10 @@ def insert_calls_to_free(
 
 
 def apply_algo(
-    algo: Callable[
+    algo_list: List[Callable[
         [torch.fx.GraphModule, int, Optional[ExportGraphSignature], bool, bool],
-        List[int],
-    ],
+        MemoryAlgoResult,
+    ]],
     graph_module: torch.fx.GraphModule,
     alignment: int,
     graph_signature: Optional[ExportGraphSignature] = None,
@@ -922,9 +948,24 @@
     """
 
     specs = update_all_tensors_lifetime(graph_module, graph_signature)
-    bufsizes: List[int] = algo(
-        graph_module, alignment, graph_signature, alloc_graph_input, alloc_graph_output
-    )
+    mem_algo_results = {}
+    for algo in algo_list:
+        mem_algo_results[getattr(algo, "__name__")] = algo(
+            graph_module, alignment, graph_signature, alloc_graph_input, alloc_graph_output
+        )
+
+    assert len({len(mem_algo_result.bufsizes) for mem_algo_result in mem_algo_results.values()}) == 1, "Different memory planning algorithms should have the same number of buffers allocated."
+
+    best_algo = min(mem_algo_results, key=lambda k: sum(mem_algo_results[k].bufsizes))
+    logging.debug(f"Best memory planning algo for this model is {best_algo}")
+    bufsizes = mem_algo_results[best_algo].bufsizes
+
+    for spec in mem_algo_results[best_algo].spec_dict:
+        spec_alloc_result = mem_algo_results[best_algo].spec_dict[spec]
+        spec.mem_id = spec_alloc_result.mem_id
+        spec.mem_offset = spec_alloc_result.mem_offset
+        spec.mem_obj_id = spec_alloc_result.mem_obj_id
+
     insert_calls_to_free(graph_module, specs)
 
     def handle_submodule(
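
The selection logic is plain Python: run every candidate, key the results by __name__, check that the plans agree on the number of pools, then commit the one with the smallest summed footprint. The same flow with stub algorithms standing in for greedy and naive:

from dataclasses import dataclass
from typing import Dict, List

@dataclass
class MemoryAlgoResult:  # simplified stand-in for the real dataclass
    spec_dict: Dict
    bufsizes: List[int]

def greedy_stub() -> MemoryAlgoResult:
    return MemoryAlgoResult({}, [0, 96])   # reuses storage, smaller total

def naive_stub() -> MemoryAlgoResult:
    return MemoryAlgoResult({}, [0, 160])  # no reuse, larger total

mem_algo_results = {algo.__name__: algo() for algo in [greedy_stub, naive_stub]}

# Every candidate must report the same number of pools...
assert len({len(r.bufsizes) for r in mem_algo_results.values()}) == 1

# ...and the winner is the plan with the smallest total footprint.
best_algo = min(mem_algo_results, key=lambda k: sum(mem_algo_results[k].bufsizes))
assert best_algo == "greedy_stub"
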
@@ -937,7 +978,7 @@ def handle_submodule(
         # buffer already allocated.
         submodule.input_mem_buffer_sizes = bufsizes
         bufsizes = apply_algo(
-            algo,
+            algo_list,
             submodule,
             alignment,
             graph_signature,
@@ -961,5 +1002,5 @@ def handle_submodule(
         )
 
     graph_module.meta.update({"non_const_buffer_sizes": bufsizes})
-
+    print(f"bufsizes = {bufsizes}")
     return bufsizes

exir/passes/memory_planning_pass.py

Lines changed: 2 additions & 1 deletion
@@ -18,6 +18,7 @@
     apply_algo,
     get_node_tensor_specs,
     greedy,
+    MemoryAlgoResult,
     Verifier,
 )
 from executorch.exir.operator.convert import get_out_args_from_opoverload
@@ -40,7 +41,7 @@ def _callable_name(any_callable: Callable[..., Any]) -> str:
 class MemoryPlanningPass(PassBase):
     def __init__(
         self,
-        memory_planning_algo: Callable[..., List[int]] = greedy,
+        memory_planning_algo: List[Callable[..., Any]] = [greedy],
         allow_lifetime_and_storage_overlap: bool = False,
         alloc_graph_input: bool = True,
         alloc_graph_output: bool = True,
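
Callers that relied on the old default still get greedy, since the default is now [greedy]; passing several candidates opts into the compare-and-pick behavior. (A mutable default list is shared across instances, which is harmless here as long as the pass never mutates it.) A usage sketch, assuming edge_program is an edge dialect program from the usual exir export flow:

from executorch import exir
from executorch.exir.memory_planning import greedy, naive
from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass

# edge_program is assumed to come from the standard exir export flow.
et_program = edge_program.to_executorch(
    exir.ExecutorchBackendConfig(
        memory_planning_pass=MemoryPlanningPass([greedy, naive]),
    )
)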

exir/tests/test_memory_planning.py

Lines changed: 6 additions & 5 deletions
@@ -19,6 +19,7 @@
     filter_nodes,
     get_node_tensor_specs,
     greedy,
+    MemoryAlgoResult,
     naive,
     Verifier,
 )
@@ -234,7 +235,7 @@ def forward(self, a: torch.Tensor) -> torch.Tensor:
 
 def maketest(
     module_cls: Type[torch.nn.Module],
-    criteria: Optional[List[Tuple[Callable[..., List[int]], bool]]] = None,
+    criteria: Optional[List[Tuple[Callable[..., MemoryAlgoResult], bool]]] = None,
     extra_check: Optional[Callable[..., None]] = None,
     use_functionalization: bool = True,
     alloc_graph_input: bool = True,
@@ -272,7 +273,7 @@ def wrapper(self: "TestMemoryPlanning") -> None:
             SpecPropPass(),
             ToOutVarPass(),
             MemoryPlanningPass(
-                algo,
+                [algo],
                 alloc_graph_input=alloc_graph_input,
                 alloc_graph_output=alloc_graph_output,
             ),
@@ -522,7 +523,7 @@ def test_multiple_pools(
         edge_program.to_executorch(
             exir.ExecutorchBackendConfig(
                 memory_planning_pass=CustomPoolMemoryPlanningPass(
-                    memory_planning_algo=algo,
+                    memory_planning_algo=[algo],
                     alignment=1,
                 ),
             )
@@ -708,10 +709,10 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         et_program = et.executorch_program
         inputs = et_program.execution_plan[0].inputs
         self.assertNotEqual(
-            et_program.execution_plan[0]  # pyre-ignore
+            et_program.execution_plan[0]
             .values[inputs[0]]
             .val.allocation_info.memory_offset_low,
-            et_program.execution_plan[0]  # pyre-ignore
+            et_program.execution_plan[0]
             .values[inputs[1]]
             .val.allocation_info.memory_offset_low,
         )
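
Only the element type of criteria changed, since both algorithms now return MemoryAlgoResult instead of a bare List[int]. A hypothetical criteria list in the new shape (the meaning of the bool, whether the algorithm is expected to reuse storage, is an assumption based on how the tests use it):

from typing import Callable, List, Tuple

from executorch.exir.memory_planning import greedy, MemoryAlgoResult, naive

criteria: List[Tuple[Callable[..., MemoryAlgoResult], bool]] = [
    (naive, False),  # assumed: naive does not reuse storage
    (greedy, True),  # assumed: greedy is expected to reuse storage
]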
