
Commit 69081ff

tarun292 authored and facebook-github-bot committed
Refactoring memory planning to allow running multiple algorithms (#8440)
Summary: This diff introduces `memory_planning_algorithm_suite`, a method that lets us run multiple memory planning algorithms and pick the one that gives the best result, i.e., the least memory consumed. Each of these algorithms is required to return a `MemoryAlgoResult` that contains the results of the memory planning done by that algorithm. As before, the algorithms do not update the `TensorSpec`s directly; instead, `memory_planning_algorithm_suite` determines which algorithm produced the best result and then updates the `TensorSpec`s with the values (offsets etc.) returned by that algorithm.

Reviewed By: JacobSzwejbka

Differential Revision: D69515056
1 parent 1a9a59b commit 69081ff
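For orientation, here is a minimal sketch of how the suite is composed at a call site. It is based on the patterns visible in this diff (`vulkan_preprocess.py` and the updated tests); the variable names are illustrative, not part of the commit:

```python
# Sketch only: composing the suite from the names this diff introduces.
from functools import partial

from executorch.exir.memory_planning import (
    greedy,
    memory_planning_algorithm_suite,
    naive,
)
from executorch.exir.passes import MemoryPlanningPass

# Each algorithm in algo_list must return a MemoryAlgoResult. The suite runs
# every algorithm, keeps the one whose summed bufsizes is smallest, and only
# then writes mem_id/mem_obj_id/mem_offset back into the TensorSpecs.
mem_planning_suite = partial(
    memory_planning_algorithm_suite,
    algo_list=[naive, partial(greedy, allow_overlapping_allocations=False)],
)
memory_planning_pass = MemoryPlanningPass(memory_planning_algo=mem_planning_suite)
```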

File tree

4 files changed: +111 −22 lines


backends/vulkan/vulkan_preprocess.py

Lines changed: 3 additions & 2 deletions

@@ -47,7 +47,7 @@
 )
 from executorch.exir.backend.utils import DelegateMappingBuilder

-from executorch.exir.memory_planning import greedy
+from executorch.exir.memory_planning import greedy, memory_planning_algorithm_suite
 from executorch.exir.pass_base import ExportPass, PassBase

 from executorch.exir.passes import MemoryPlanningPass, SpecPropPass
@@ -199,11 +199,12 @@ def preprocess(  # noqa: C901
         # Finally, apply dynamic shape passes and memory planning pass. These passes
         # must be applied only when the graph structure is finalized.
         greedy_memory_planning = partial(greedy, allow_overlapping_allocations=False)
+        mem_planning_suite = partial(memory_planning_algorithm_suite, algo_list=[greedy_memory_planning])
         program = apply_passes(
             program,
             [
                 ConstraintBasedSymShapeEvalPass(),
-                MemoryPlanningPass(memory_planning_algo=greedy_memory_planning),
+                MemoryPlanningPass(memory_planning_algo=mem_planning_suite),
             ],
         )

exir/memory_planning.py

Lines changed: 95 additions & 13 deletions

@@ -6,6 +6,8 @@

 # pyre-strict

+import collections
+import functools
 import itertools
 import logging
 import operator
@@ -522,6 +524,27 @@ class SharedObject:
     def __repr__(self) -> str:
         return f"SharedObject(idx={self.idx}, offset={self.offset}, size={self.size}, lifetime=[{self.first_used_index, self.last_used_index}])"

+@dataclass
+class SpecAllocResult:
+    """These are the values that a memory planning algorithm assigns to a spec.
+    They are not directly written back into the spec object, but are used to
+    track the allocation decisions and are assigned back to the spec object at
+    the end, based on which algorithm is picked as the best performing one.
+    """
+    mem_id: int
+    mem_obj_id: int
+    mem_offset: int
+
+@dataclass
+class MemoryAlgoResult:
+    """This is the result returned by a memory planning algorithm that is
+    invoked by memory_planning_algorithm_suite. It contains the allocation
+    decisions of that algorithm for all the specs, and the size of the buffer
+    that was used for each memory hierarchy.
+    """
+    spec_dict: Dict[TensorSpec, SpecAllocResult]
+    bufsizes: List[int]

 def materialize_buffer(
     shared_objects: List[SharedObject], input_total_size: int = 0
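To make the new contract concrete, here is a toy illustration (not part of the diff) of what an algorithm hands back to the suite. Plain `object()` instances stand in for `TensorSpec` keys, since only hashability matters for the sketch:

```python
# Toy illustration of the MemoryAlgoResult contract; not an ExecuTorch code path.
from executorch.exir.memory_planning import MemoryAlgoResult, SpecAllocResult

spec_a, spec_b = object(), object()  # hypothetical stand-ins for TensorSpec

result = MemoryAlgoResult(spec_dict={}, bufsizes=[])
# Record decisions per spec instead of mutating spec.mem_id / spec.mem_offset.
result.spec_dict[spec_a] = SpecAllocResult(mem_id=1, mem_obj_id=0, mem_offset=0)
result.spec_dict[spec_b] = SpecAllocResult(mem_id=1, mem_obj_id=1, mem_offset=128)
# One total buffer size per memory hierarchy, indexed by mem_id.
result.bufsizes = [0, 256]
```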
@@ -711,7 +734,7 @@ def greedy(
     alloc_graph_input: bool = True,
     alloc_graph_output: bool = True,
     allow_overlapping_allocations: bool = True,
-) -> List[int]:
+) -> MemoryAlgoResult:
     r"""Greedy algorithm to allocate memory for tensors in the graph.
     alloc_graph_input: If set to true, the algorithm will allocate memory for graph input.
     alloc_graph_output: If set to true, the algorithm will allocate memory for graph output.
@@ -720,6 +743,7 @@
     This flag is added to allow for Vulkan to use MemoryPlanningPass with overlapping
     allocations disabled
     """
+    greedy_result = MemoryAlgoResult({}, [])
     # padding allocation with 64 bytes.
     # this requirement is really for XNNPACK backend which can read tensors
     # beyond the end of the tensor. This is done for performance
@@ -754,12 +778,16 @@
     sorted_specs.reverse()

     for spec in sorted_specs:
+        # Create an entry for this TensorSpec in the result object that we'll be
+        # returning from this algorithm.
+        spec_alloc_result = greedy_result.spec_dict.get(spec, SpecAllocResult(0, 0, 0))
         if spec.mem_id is None:
-            spec.mem_id = 1
+            spec_alloc_result.mem_id = 1
+        else:
+            spec_alloc_result.mem_id = spec.mem_id
+        greedy_result.spec_dict[spec] = spec_alloc_result
         spec.realign(alignment)
-        spec2obj[spec] = pick_shared_obj(
-            shared_objects[spec.mem_id], spec, allow_overlapping_allocations
-        )
+        spec2obj[spec] = pick_shared_obj(shared_objects[spec_alloc_result.mem_id], spec, allow_overlapping_allocations)

     if len(shared_objects) == 0:
         # Cannot find any tensor in the graph that needs to be allocated.
@@ -787,24 +815,73 @@
         for sobj in shared_objects[mem_id]:
             for alloc in sobj.allocations:
                 spec = alloc.spec
-                alloc.spec.mem_obj_id = sobj.idx
-                alloc.spec.mem_offset = sobj.offset + alloc.offset
+                # Get the spec_alloc_result for this spec and update it with the
+                # mem_obj_id and mem_offset generated by this algorithm.
+                spec_alloc_result = greedy_result.spec_dict.get(spec, None)
+                assert spec_alloc_result is not None, f"Spec {spec} not found."
+                spec_alloc_result.mem_obj_id = sobj.idx
+                spec_alloc_result.mem_offset = sobj.offset + alloc.offset
                 num_specs_processed += 1
     assert (
         len(spec2obj) == num_specs_processed
     ), f"All specs should be processed but there were {len(spec2obj)} specs and processed {num_specs_processed} specs"

     logging.debug(f"greedy algorithm returns bufsizes: {total_sizes}")
-    return total_sizes
+    greedy_result.bufsizes = total_sizes
+    return greedy_result

+def memory_planning_algorithm_suite(
+    graph_module: torch.fx.GraphModule,
+    alignment: int,
+    graph_signature: Optional[ExportGraphSignature] = None,
+    alloc_graph_input: bool = True,
+    alloc_graph_output: bool = True,
+    allow_overlapping_allocations: bool = True,
+    algo_list: List[Callable[..., MemoryAlgoResult]] = [greedy],
+) -> List[int]:
+    r"""
+    Memory planning algorithm suite that runs a list of memory planning algorithms
+    and returns the result of the algorithm that minimizes the total memory usage.
+    """
+    mem_algo_results = {}
+    for algo in algo_list:
+        if isinstance(algo, functools.partial):
+            name = algo.func.__name__
+        else:
+            name = getattr(algo, "__name__", None)
+        # Run this memory planning algorithm and store the result in mem_algo_results
+        # with the name of the algorithm as the key.
+        mem_algo_results[name] = algo(
+            graph_module, alignment, graph_signature, alloc_graph_input, alloc_graph_output
+        )
+
+    # All the algorithms should have the same number of buffers allocated.
+    assert len({len(mem_algo_result.bufsizes) for mem_algo_result in mem_algo_results.values()}) == 1, "Different memory planning algorithms should have the same number of buffers allocated."
+
+    # Find the algorithm that minimizes the total memory usage.
+    best_algo = min(mem_algo_results, key=lambda k: sum(mem_algo_results[k].bufsizes))
+    logging.debug(f"Best memory planning algo for this model is {best_algo}")
+    bufsizes = mem_algo_results[best_algo].bufsizes
+
+    # Update the mem_id and mem_offset for each spec in the graph module based on the
+    # values provided by the best memory planning algorithm.
+    for spec in mem_algo_results[best_algo].spec_dict:
+        spec_alloc_result = mem_algo_results[best_algo].spec_dict[spec]
+        spec.mem_id = spec_alloc_result.mem_id
+        spec.mem_offset = spec_alloc_result.mem_offset
+        spec.mem_obj_id = spec_alloc_result.mem_obj_id
+
+    return bufsizes

 def naive(
     graph_module: torch.fx.GraphModule,
     alignment: int,
     graph_signature: Optional[ExportGraphSignature] = None,
     alloc_graph_input: bool = True,
     alloc_graph_output: bool = True,
-) -> List[int]:
+) -> MemoryAlgoResult:
+
+    naive_result = MemoryAlgoResult({}, [])

     # allocate 'allocated' bytes from buffer with id mem_id.
     # return the starting offset of the allocated buffer.
@@ -826,16 +903,22 @@ def _allocate_buf(bufsizes: List[int], mem_id: int, allocated: int) -> int:
         ignore_graph_input=not alloc_graph_input,
         ignore_graph_output=not alloc_graph_output,
     ):
+        spec_alloc_result = naive_result.spec_dict.get(spec, SpecAllocResult(0, 0, 0))
         # assume a single memory layer which has mem_id 1
         if spec.mem_id is None:
-            spec.mem_id = 1
+            spec_alloc_result.mem_id = 1
+        else:
+            spec_alloc_result.mem_id = spec.mem_id
+        naive_result.spec_dict[spec] = spec_alloc_result
+
         # allocate spec.allocated_memory bytes in the buffer
        # with the corresponding mem_id
         spec.realign(alignment)
-        spec.mem_offset = _allocate_buf(bufsizes, spec.mem_id, spec.allocated_memory)
+        spec_alloc_result.mem_offset = _allocate_buf(bufsizes, spec_alloc_result.mem_id, spec.allocated_memory)

     logging.debug(f"naive algorithm returns bufsizes: {bufsizes}")
-    return bufsizes
+    naive_result.bufsizes = bufsizes
+    return naive_result


 def get_cond_nodes(graph_module: torch.fx.GraphModule) -> Iterable[Node]:
@@ -980,5 +1063,4 @@ def handle_submodule(
     )

     graph_module.meta.update({"non_const_buffer_sizes": bufsizes})
-
     return bufsizes
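The selection step in `memory_planning_algorithm_suite` reduces to a `min` over summed buffer sizes. A self-contained sketch of that criterion, with made-up numbers:

```python
# Standalone sketch of the suite's selection logic, using hypothetical bufsizes.
mem_algo_results = {
    "naive": [0, 4096, 1024],   # made-up per-arena buffer sizes
    "greedy": [0, 2048, 1024],
}

# Mirrors: min(mem_algo_results, key=lambda k: sum(mem_algo_results[k].bufsizes))
best_algo = min(mem_algo_results, key=lambda k: sum(mem_algo_results[k]))
assert best_algo == "greedy"  # 3072 < 5120
```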

exir/passes/memory_planning_pass.py

Lines changed: 3 additions & 1 deletion

@@ -18,6 +18,8 @@
     apply_algo,
     get_node_tensor_specs,
     greedy,
+    memory_planning_algorithm_suite,
+    MemoryAlgoResult,
     Verifier,
 )
 from executorch.exir.operator.convert import get_out_args_from_opoverload
@@ -40,7 +42,7 @@ def _callable_name(any_callable: Callable[..., Any]) -> str:
 class MemoryPlanningPass(PassBase):
     def __init__(
         self,
-        memory_planning_algo: Callable[..., List[int]] = greedy,
+        memory_planning_algo: Callable[..., List[int]] = memory_planning_algorithm_suite,
         allow_lifetime_and_storage_overlap: bool = False,
         alloc_graph_input: bool = True,
         alloc_graph_output: bool = True,
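Because the default `memory_planning_algo` is now the suite rather than `greedy`, a bare algorithm, which now returns a `MemoryAlgoResult` instead of `List[int]`, should be wrapped before being handed to the pass, as the updated tests do. A sketch of both call patterns, assuming only names from this diff:

```python
from functools import partial

from executorch.exir.memory_planning import greedy, memory_planning_algorithm_suite
from executorch.exir.passes import MemoryPlanningPass

# Default: the pass runs the suite, which itself defaults to algo_list=[greedy].
default_pass = MemoryPlanningPass()

# Explicit: wrap individual algorithms so the suite can run them and write the
# winning allocations back into the TensorSpecs.
explicit_pass = MemoryPlanningPass(
    memory_planning_algo=partial(memory_planning_algorithm_suite, algo_list=[greedy]),
)
```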

exir/tests/test_memory_planning.py

Lines changed: 10 additions & 6 deletions

@@ -8,6 +8,7 @@

 import itertools
 import unittest
+from functools import partial
 from typing import Any, Callable, List, Optional, Tuple, Type

 import executorch.exir as exir
@@ -19,6 +20,8 @@
     filter_nodes,
     get_node_tensor_specs,
     greedy,
+    memory_planning_algorithm_suite,
+    MemoryAlgoResult,
     naive,
     Verifier,
 )
@@ -234,7 +237,7 @@ def forward(self, a: torch.Tensor) -> torch.Tensor:

 def maketest(
     module_cls: Type[torch.nn.Module],
-    criteria: Optional[List[Tuple[Callable[..., List[int]], bool]]] = None,
+    criteria: Optional[List[Tuple[Callable[..., MemoryAlgoResult], bool]]] = None,
     extra_check: Optional[Callable[..., None]] = None,
     use_functionalization: bool = True,
     alloc_graph_input: bool = True,
@@ -266,13 +269,13 @@ def wrapper(self: "TestMemoryPlanning") -> None:
             .exported_program()
             .graph_module
         )
-
+        mem_algo = partial(memory_planning_algorithm_suite, algo_list=[algo])
         graph_module = PassManager(
             passes=[
                 SpecPropPass(),
                 ToOutVarPass(),
                 MemoryPlanningPass(
-                    algo,
+                    mem_algo,
                     alloc_graph_input=alloc_graph_input,
                     alloc_graph_output=alloc_graph_output,
                 ),
@@ -519,10 +522,11 @@ def test_multiple_pools(
             export(MultiplePoolsToyModel(), (torch.ones(1),), strict=True)
         )

+        mem_algo = partial(memory_planning_algorithm_suite, algo_list=[algo])
         edge_program.to_executorch(
             exir.ExecutorchBackendConfig(
                 memory_planning_pass=CustomPoolMemoryPlanningPass(
-                    memory_planning_algo=algo,
+                    memory_planning_algo=mem_algo,
                     alignment=1,
                 ),
             )
@@ -708,10 +712,10 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         et_program = et.executorch_program
         inputs = et_program.execution_plan[0].inputs
         self.assertNotEqual(
-            et_program.execution_plan[0]  # pyre-ignore
+            et_program.execution_plan[0]
             .values[inputs[0]]
             .val.allocation_info.memory_offset_low,
-            et_program.execution_plan[0]  # pyre-ignore
+            et_program.execution_plan[0]
             .values[inputs[1]]
             .val.allocation_info.memory_offset_low,
         )

0 commit comments