diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py
index 3cfcac13a8d..1c1c51bb58a 100644
--- a/backends/vulkan/vulkan_preprocess.py
+++ b/backends/vulkan/vulkan_preprocess.py
@@ -47,7 +47,7 @@
 )
 from executorch.exir.backend.utils import DelegateMappingBuilder
 
-from executorch.exir.memory_planning import greedy
+from executorch.exir.memory_planning import greedy, memory_planning_algorithm_suite
 from executorch.exir.pass_base import ExportPass, PassBase
 from executorch.exir.passes import MemoryPlanningPass, SpecPropPass
 
@@ -199,11 +199,14 @@ def preprocess(  # noqa: C901
         # Finally, apply dynamic shape passes and memory planning pass. These passes
         # must be applied only when the graph structure is finalized.
         greedy_memory_planning = partial(greedy, allow_overlapping_allocations=False)
+        mem_planning_suite = partial(
+            memory_planning_algorithm_suite, algo_list=[greedy_memory_planning]
+        )
         program = apply_passes(
             program,
             [
                 ConstraintBasedSymShapeEvalPass(),
-                MemoryPlanningPass(memory_planning_algo=greedy_memory_planning),
+                MemoryPlanningPass(memory_planning_algo=mem_planning_suite),
             ],
         )
 
diff --git a/exir/memory_planning.py b/exir/memory_planning.py
index 1d5d0868c50..3f45276c9e2 100644
--- a/exir/memory_planning.py
+++ b/exir/memory_planning.py
@@ -6,6 +6,7 @@
 
 # pyre-strict
 
+import functools
 import itertools
 import logging
 import operator
@@ -523,6 +524,31 @@ def __repr__(self) -> str:
         return f"SharedObject(idx={self.idx}, offset={self.offset}, size={self.size}, lifetime=[{self.first_used_index, self.last_used_index}])"
 
 
+@dataclass
+class SpecAllocResult:
+    """These are the values that a memory planning algorithm assigns to a spec.
+    These are not directly written back into the spec object, but are used to
+    track the allocation decisions and assigned back to the spec object in the
+    end, based on which algorithm is picked as the best performing one.
+    """
+
+    mem_id: int
+    mem_obj_id: int
+    mem_offset: int
+
+
+@dataclass
+class MemoryAlgoResult:
+    """This is the result returned by a memory planning algorithm that is
+    invoked by memory_planning_algorithm_suite. It contains the allocation
+    decisions of that algorithm for all the specs, and the size of the buffer
+    that was used for different memory hierarchies.
+    """
+
+    spec_dict: Dict[TensorSpec, SpecAllocResult]
+    bufsizes: List[int]
+
+
 def materialize_buffer(
     shared_objects: List[SharedObject], input_total_size: int = 0
 ) -> int:
@@ -711,7 +737,7 @@
     alloc_graph_input: bool = True,
     alloc_graph_output: bool = True,
     allow_overlapping_allocations: bool = True,
-) -> List[int]:
+) -> MemoryAlgoResult:
     r"""Greedy algorithm to allocate memory for tensors in the graph.
     alloc_graph_input: If set to true, the algorithm will allocate memory for graph input.
     alloc_graph_output: If set to true, the algorithm will allocate memory for graph output.
@@ -720,6 +746,7 @@
     This flag is added to allow for Vulkan to use MemoryPlanningPass with overlapping
     allocations disabled
     """
+    greedy_result = MemoryAlgoResult({}, [])
     # padding allocation with 64 bytes.
     # this requirement is really for XNNPACK backend which can read tensors
     # beyond the end of the tensor. This is done for performance
@@ -754,11 +781,19 @@
         sorted_specs.reverse()
 
     for spec in sorted_specs:
+        # Create an entry for this TensorSpec in the result object that we'll be
+        # returning from this algorithm.
+        spec_alloc_result = greedy_result.spec_dict.get(spec, SpecAllocResult(0, 0, 0))
         if spec.mem_id is None:
-            spec.mem_id = 1
+            spec_alloc_result.mem_id = 1
+        else:
+            spec_alloc_result.mem_id = spec.mem_id
+        greedy_result.spec_dict[spec] = spec_alloc_result
         spec.realign(alignment)
         spec2obj[spec] = pick_shared_obj(
-            shared_objects[spec.mem_id], spec, allow_overlapping_allocations
+            shared_objects[spec_alloc_result.mem_id],
+            spec,
+            allow_overlapping_allocations,
         )
 
     if len(shared_objects) == 0:
@@ -787,24 +822,89 @@
         for sobj in shared_objects[mem_id]:
             for alloc in sobj.allocations:
                 spec = alloc.spec
-                alloc.spec.mem_obj_id = sobj.idx
-                alloc.spec.mem_offset = sobj.offset + alloc.offset
+                # Get the spec_alloc_result for this spec and update it with the
+                # mem_obj_id and mem_offset generated by this algorithm.
+                spec_alloc_result = greedy_result.spec_dict.get(spec, None)
+                assert spec_alloc_result is not None, f"Spec {spec} not found."
+                spec_alloc_result.mem_obj_id = sobj.idx
+                spec_alloc_result.mem_offset = sobj.offset + alloc.offset
                 num_specs_processed += 1
     assert (
         len(spec2obj) == num_specs_processed
     ), f"All specs should be processed but there were {len(spec2obj)} specs and processed {num_specs_processed} specs"
 
     logging.debug(f"greedy algorithm returns bufsizes: {total_sizes}")
-    return total_sizes
+    greedy_result.bufsizes = total_sizes
+    return greedy_result
 
 
-def naive(
+def memory_planning_algorithm_suite(
     graph_module: torch.fx.GraphModule,
     alignment: int,
     graph_signature: Optional[ExportGraphSignature] = None,
     alloc_graph_input: bool = True,
     alloc_graph_output: bool = True,
+    allow_overlapping_allocations: bool = True,
+    algo_list: Optional[List[Callable[..., MemoryAlgoResult]]] = None,
 ) -> List[int]:
+    r"""
+    Memory planning algorithm suite that runs a list of memory planning algorithms
+    and returns the result of the algorithm that minimizes the total memory usage.
+    """
+    if algo_list is None:
+        algo_list = [greedy]
+    mem_algo_results = {}
+    for algo in algo_list:
+        if isinstance(algo, functools.partial):
+            name = algo.func.__name__
+        else:
+            name = getattr(algo, "__name__", None)
+        # Run this memory planning algorithm and store the result in mem_algo_results
+        # with the name of the algorithm as the key.
+        mem_algo_results[name] = algo(
+            graph_module,
+            alignment,
+            graph_signature,
+            alloc_graph_input,
+            alloc_graph_output,
+        )
+
+    # All the algorithms should have the same number of buffers allocated.
+    assert (
+        len(
+            {
+                len(mem_algo_result.bufsizes)
+                for mem_algo_result in mem_algo_results.values()
+            }
+        )
+        == 1
+    ), "Different memory planning algorithms should have the same number of buffers allocated."
+
+    # Find the algorithm that minimizes the total memory usage.
+    best_algo = min(mem_algo_results, key=lambda k: sum(mem_algo_results[k].bufsizes))
+    logging.debug(f"Best memory planning algo for this model is {best_algo}")
+    bufsizes = mem_algo_results[best_algo].bufsizes
+
+    # Update the mem_id and mem_offset for each spec in the graph module based on the
+    # values provided by the best memory planning algorithm.
+    for spec in mem_algo_results[best_algo].spec_dict:
+        spec_alloc_result = mem_algo_results[best_algo].spec_dict[spec]
+        spec.mem_id = spec_alloc_result.mem_id
+        spec.mem_offset = spec_alloc_result.mem_offset
+        spec.mem_obj_id = spec_alloc_result.mem_obj_id
+
+    return bufsizes
+
+
+def naive(
+    graph_module: torch.fx.GraphModule,
+    alignment: int,
+    graph_signature: Optional[ExportGraphSignature] = None,
+    alloc_graph_input: bool = True,
+    alloc_graph_output: bool = True,
+) -> MemoryAlgoResult:
+
+    naive_result = MemoryAlgoResult({}, [])
 
     # allocate 'allocated' bytes from buffer with id mem_id.
     # return the starting offset of the allocated buffer.
@@ -826,16 +926,24 @@ def _allocate_buf(bufsizes: List[int], mem_id: int, allocated: int) -> int:
         ignore_graph_input=not alloc_graph_input,
         ignore_graph_output=not alloc_graph_output,
     ):
+        spec_alloc_result = naive_result.spec_dict.get(spec, SpecAllocResult(0, 0, 0))
         # assume a single memory layer which has mem_id 1
         if spec.mem_id is None:
-            spec.mem_id = 1
+            spec_alloc_result.mem_id = 1
+        else:
+            spec_alloc_result.mem_id = spec.mem_id
+        naive_result.spec_dict[spec] = spec_alloc_result
+
         # allocate spec.allocated_memory bytes in the buffer
         # with the corresponding mem_id
         spec.realign(alignment)
-        spec.mem_offset = _allocate_buf(bufsizes, spec.mem_id, spec.allocated_memory)
+        spec_alloc_result.mem_offset = _allocate_buf(
+            bufsizes, spec_alloc_result.mem_id, spec.allocated_memory
+        )
 
     logging.debug(f"naive algorithm returns bufsizes: {bufsizes}")
-    return bufsizes
+    naive_result.bufsizes = bufsizes
+    return naive_result
 
 
 def get_cond_nodes(graph_module: torch.fx.GraphModule) -> Iterable[Node]:
@@ -980,5 +1088,4 @@ def handle_submodule(
     )
 
     graph_module.meta.update({"non_const_buffer_sizes": bufsizes})
-
     return bufsizes
diff --git a/exir/passes/memory_planning_pass.py b/exir/passes/memory_planning_pass.py
index f5431df431a..f4881e7ab71 100644
--- a/exir/passes/memory_planning_pass.py
+++ b/exir/passes/memory_planning_pass.py
@@ -17,7 +17,7 @@
     _is_out_var_node,
     apply_algo,
     get_node_tensor_specs,
-    greedy,
+    memory_planning_algorithm_suite,
     Verifier,
 )
 from executorch.exir.operator.convert import get_out_args_from_opoverload
@@ -40,7 +40,9 @@ def _callable_name(any_callable: Callable[..., Any]) -> str:
 class MemoryPlanningPass(PassBase):
     def __init__(
         self,
-        memory_planning_algo: Callable[..., List[int]] = greedy,
+        memory_planning_algo: Callable[
+            ..., List[int]
+        ] = memory_planning_algorithm_suite,
         allow_lifetime_and_storage_overlap: bool = False,
         alloc_graph_input: bool = True,
         alloc_graph_output: bool = True,
diff --git a/exir/tests/test_memory_planning.py b/exir/tests/test_memory_planning.py
index d885239acd8..8df0cfed0bf 100644
--- a/exir/tests/test_memory_planning.py
+++ b/exir/tests/test_memory_planning.py
@@ -8,6 +8,7 @@
 
 import itertools
 import unittest
+from functools import partial
 from typing import Any, Callable, List, Optional, Tuple, Type
 
 import executorch.exir as exir
@@ -19,6 +20,8 @@
     filter_nodes,
     get_node_tensor_specs,
     greedy,
+    memory_planning_algorithm_suite,
+    MemoryAlgoResult,
     naive,
     Verifier,
 )
@@ -234,7 +237,7 @@ def forward(self, a: torch.Tensor) -> torch.Tensor:
 
 def maketest(
     module_cls: Type[torch.nn.Module],
-    criteria: Optional[List[Tuple[Callable[..., List[int]], bool]]] = None,
+    criteria: Optional[List[Tuple[Callable[..., MemoryAlgoResult], bool]]] = None,
    extra_check: Optional[Callable[..., None]] = None,
     use_functionalization: bool = True,
     alloc_graph_input: bool = True,
@@ -266,13 +269,13 @@ def wrapper(self: "TestMemoryPlanning") -> None:
"TestMemoryPlanning") -> None: .exported_program() .graph_module ) - + mem_algo = partial(memory_planning_algorithm_suite, algo_list=[algo]) graph_module = PassManager( passes=[ SpecPropPass(), ToOutVarPass(), MemoryPlanningPass( - algo, + mem_algo, alloc_graph_input=alloc_graph_input, alloc_graph_output=alloc_graph_output, ), @@ -519,10 +522,11 @@ def test_multiple_pools( export(MultiplePoolsToyModel(), (torch.ones(1),), strict=True) ) + mem_algo = partial(memory_planning_algorithm_suite, algo_list=[algo]) edge_program.to_executorch( exir.ExecutorchBackendConfig( memory_planning_pass=CustomPoolMemoryPlanningPass( - memory_planning_algo=algo, + memory_planning_algo=mem_algo, alignment=1, ), )