7 changes: 5 additions & 2 deletions backends/vulkan/vulkan_preprocess.py
@@ -47,7 +47,7 @@
)
from executorch.exir.backend.utils import DelegateMappingBuilder

from executorch.exir.memory_planning import greedy
from executorch.exir.memory_planning import greedy, memory_planning_algorithm_suite
from executorch.exir.pass_base import ExportPass, PassBase

from executorch.exir.passes import MemoryPlanningPass, SpecPropPass
@@ -199,11 +199,14 @@ def preprocess( # noqa: C901
# Finally, apply dynamic shape passes and memory planning pass. These passes
# must be applied only when the graph structure is finalized.
greedy_memory_planning = partial(greedy, allow_overlapping_allocations=False)
mem_planning_suite = partial(
memory_planning_algorithm_suite, algo_list=[greedy_memory_planning]
)
program = apply_passes(
program,
[
ConstraintBasedSymShapeEvalPass(),
MemoryPlanningPass(memory_planning_algo=greedy_memory_planning),
MemoryPlanningPass(memory_planning_algo=mem_planning_suite),
],
)

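The change above swaps the raw `greedy` callable for the new `memory_planning_algorithm_suite`, pinning the algorithm list with `functools.partial` so that `MemoryPlanningPass` still receives a single callable. A minimal sketch of that wiring outside of the Vulkan preprocess flow (same imports as in the diff):

```python
from functools import partial

from executorch.exir.memory_planning import greedy, memory_planning_algorithm_suite
from executorch.exir.passes import MemoryPlanningPass

# Vulkan requires non-overlapping allocations, so pin that option on greedy...
greedy_memory_planning = partial(greedy, allow_overlapping_allocations=False)

# ...and pin the single-algorithm list on the suite. MemoryPlanningPass still
# sees one callable with the usual (graph_module, alignment, ...) signature.
mem_planning_suite = partial(
    memory_planning_algorithm_suite, algo_list=[greedy_memory_planning]
)

memory_planning_pass = MemoryPlanningPass(memory_planning_algo=mem_planning_suite)
```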
129 changes: 118 additions & 11 deletions exir/memory_planning.py
@@ -6,6 +6,7 @@

# pyre-strict

import functools
import itertools
import logging
import operator
@@ -523,6 +524,31 @@ def __repr__(self) -> str:
return f"SharedObject(idx={self.idx}, offset={self.offset}, size={self.size}, lifetime=[{self.first_used_index, self.last_used_index}])"


@dataclass
class SpecAllocResult:
"""These are the values that a memory plannig algorithm assigns to a spec.
These are not directly written back into the spec object, but are used to
track the allocation decisions and assigned back to the spec object in the
end, based on which algorithm is picked as the best performing one.
"""

mem_id: int
mem_obj_id: int
mem_offset: int


@dataclass
class MemoryAlgoResult:
"""This is the result returned by a memory planning algorithm that is
invoked by memory_planning_algorithm_suite. It contains the allocation
decisions of that algorithm for all the specs, and the size of the buffer
that was used for different memory hierarchies.
"""

spec_dict: Dict[TensorSpec, SpecAllocResult]
bufsizes: List[int]


def materialize_buffer(
shared_objects: List[SharedObject], input_total_size: int = 0
) -> int:
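The two dataclasses added above define the contract for the rewritten algorithms: decisions are recorded per spec in a `SpecAllocResult` and returned in a `MemoryAlgoResult`, instead of being written onto the `TensorSpec` objects directly. A toy sketch of an algorithm honoring that contract (illustrative only; the `specs` iterable and the single-arena layout are placeholders, not part of this PR):

```python
def toy_algo(specs, alignment: int) -> MemoryAlgoResult:
    """Illustrative only: place every spec back-to-back in mem_id 1."""
    result = MemoryAlgoResult({}, [])
    offset = 0
    for spec in specs:
        spec.realign(alignment)
        # Record the placement on the result object rather than on the spec;
        # the suite copies the winning algorithm's values back onto the specs.
        alloc = result.spec_dict.get(spec, SpecAllocResult(0, 0, 0))
        alloc.mem_id = 1 if spec.mem_id is None else spec.mem_id
        alloc.mem_offset = offset
        result.spec_dict[spec] = alloc
        offset += spec.allocated_memory
    # One buffer size per memory hierarchy level, indexed by mem_id.
    result.bufsizes = [0, offset]
    return result
```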
@@ -711,7 +737,7 @@ def greedy(
alloc_graph_input: bool = True,
alloc_graph_output: bool = True,
allow_overlapping_allocations: bool = True,
) -> List[int]:
) -> MemoryAlgoResult:
r"""Greedy algorithm to allocate memory for tensors in the graph.
alloc_graph_input: If set to true, the algorithm will allocate memory for graph input.
alloc_graph_output: If set to true, the algorithm will allocate memory for graph output.
@@ -720,6 +746,7 @@ def greedy(
This flag is added to allow for Vulkan to use MemoryPlanningPass with overlapping
allocations disabled
"""
greedy_result = MemoryAlgoResult({}, [])
# padding allocation with 64 bytes.
# this requirement is really for XNNPACK backend which can read tensors
# beyond the end of the tensor. This is done for performance
@@ -754,11 +781,19 @@ def greedy(
sorted_specs.reverse()

for spec in sorted_specs:
# Create an entry for this TensorSpec in the result object that we'll be
# returning from this algorithm.
spec_alloc_result = greedy_result.spec_dict.get(spec, SpecAllocResult(0, 0, 0))
if spec.mem_id is None:
spec.mem_id = 1
spec_alloc_result.mem_id = 1
else:
spec_alloc_result.mem_id = spec.mem_id
greedy_result.spec_dict[spec] = spec_alloc_result
spec.realign(alignment)
spec2obj[spec] = pick_shared_obj(
shared_objects[spec.mem_id], spec, allow_overlapping_allocations
shared_objects[spec_alloc_result.mem_id],
spec,
allow_overlapping_allocations,
)

if len(shared_objects) == 0:
@@ -787,24 +822,89 @@
for sobj in shared_objects[mem_id]:
for alloc in sobj.allocations:
spec = alloc.spec
alloc.spec.mem_obj_id = sobj.idx
alloc.spec.mem_offset = sobj.offset + alloc.offset
# Get the spec_alloc_result for this spec and update it with the
# mem_obj_id and mem_offset generated by this algorithm.
spec_alloc_result = greedy_result.spec_dict.get(spec, None)
assert spec_alloc_result is not None, f"Spec {spec} not found."
spec_alloc_result.mem_obj_id = sobj.idx
spec_alloc_result.mem_offset = sobj.offset + alloc.offset
num_specs_processed += 1
assert (
len(spec2obj) == num_specs_processed
), f"All specs should be processed but there were {len(spec2obj)} specs and processed {num_specs_processed} specs"

logging.debug(f"greedy algorithm returns bufsizes: {total_sizes}")
return total_sizes
greedy_result.bufsizes = total_sizes
return greedy_result


def naive(
def memory_planning_algorithm_suite(
graph_module: torch.fx.GraphModule,
alignment: int,
graph_signature: Optional[ExportGraphSignature] = None,
alloc_graph_input: bool = True,
alloc_graph_output: bool = True,
allow_overlapping_allocations: bool = True,
algo_list: Optional[List[Callable[..., MemoryAlgoResult]]] = None,
) -> List[int]:
r"""
Memory planning algorithm suite that runs a list of memory planning algorithms
and returns the result of the algorithm that minimizes the total memory usage.
"""
if algo_list is None:
algo_list = [greedy]
mem_algo_results = {}
for algo in algo_list:
if isinstance(algo, functools.partial):
name = algo.func.__name__
else:
name = getattr(algo, "__name__", None)
# Run this memory planning algorithm and store the result in mem_algo_results
# with the name of the algorithm as the key.
mem_algo_results[name] = algo(
graph_module,
alignment,
graph_signature,
alloc_graph_input,
alloc_graph_output,
)

# All the algorithms should have the same number of buffers allocated.
assert (
len(
{
len(mem_algo_result.bufsizes)
for mem_algo_result in mem_algo_results.values()
}
)
== 1
), "Different memory planning algorithms should have the same number of buffers allocated."

# Find the algorithm that minimizes the total memory usage.
best_algo = min(mem_algo_results, key=lambda k: sum(mem_algo_results[k].bufsizes))
logging.debug(f"Best memory planning algo for this model is {best_algo}")
bufsizes = mem_algo_results[best_algo].bufsizes

# Update the mem_id and mem_offset for each spec in the graph module based on the
# values provided by the best memory planning algorithm.
for spec in mem_algo_results[best_algo].spec_dict:
spec_alloc_result = mem_algo_results[best_algo].spec_dict[spec]
spec.mem_id = spec_alloc_result.mem_id
spec.mem_offset = spec_alloc_result.mem_offset
spec.mem_obj_id = spec_alloc_result.mem_obj_id

return bufsizes


def naive(
graph_module: torch.fx.GraphModule,
alignment: int,
graph_signature: Optional[ExportGraphSignature] = None,
alloc_graph_input: bool = True,
alloc_graph_output: bool = True,
) -> MemoryAlgoResult:

naive_result = MemoryAlgoResult({}, [])

# allocate 'allocated' bytes from buffer with id mem_id.
# return the starting offset of the allocated buffer.
@@ -826,16 +926,24 @@ def _allocate_buf(bufsizes: List[int], mem_id: int, allocated: int) -> int:
ignore_graph_input=not alloc_graph_input,
ignore_graph_output=not alloc_graph_output,
):
spec_alloc_result = naive_result.spec_dict.get(spec, SpecAllocResult(0, 0, 0))
# assume a single memory layer which has mem_id 1
if spec.mem_id is None:
spec.mem_id = 1
spec_alloc_result.mem_id = 1
else:
spec_alloc_result.mem_id = spec.mem_id
naive_result.spec_dict[spec] = spec_alloc_result

# allocate spec.allocated_memory bytes in the buffer
# with the corresponding mem_id
spec.realign(alignment)
spec.mem_offset = _allocate_buf(bufsizes, spec.mem_id, spec.allocated_memory)
spec_alloc_result.mem_offset = _allocate_buf(
bufsizes, spec_alloc_result.mem_id, spec.allocated_memory
)

logging.debug(f"naive algorithm returns bufsizes: {bufsizes}")
return bufsizes
naive_result.bufsizes = bufsizes
return naive_result


def get_cond_nodes(graph_module: torch.fx.GraphModule) -> Iterable[Node]:
@@ -980,5 +1088,4 @@ def handle_submodule(
)

graph_module.meta.update({"non_const_buffer_sizes": bufsizes})

return bufsizes
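With both `greedy` and `naive` now returning `MemoryAlgoResult`, the suite can be handed more than one candidate algorithm and will keep whichever result minimizes the summed buffer sizes (it also asserts that every candidate allocated the same number of buffers). A hedged sketch of wiring that into the pass, following the pattern used in the updated tests:

```python
from functools import partial

from executorch.exir.memory_planning import (
    greedy,
    memory_planning_algorithm_suite,
    naive,
)
from executorch.exir.passes import MemoryPlanningPass

# Run both algorithms on the graph; the suite keeps the smaller total footprint.
mem_algo = partial(memory_planning_algorithm_suite, algo_list=[greedy, naive])

memory_planning_pass = MemoryPlanningPass(
    memory_planning_algo=mem_algo,
    alloc_graph_input=True,
    alloc_graph_output=True,
)
```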
6 changes: 4 additions & 2 deletions exir/passes/memory_planning_pass.py
@@ -17,7 +17,7 @@
_is_out_var_node,
apply_algo,
get_node_tensor_specs,
greedy,
memory_planning_algorithm_suite,
Verifier,
)
from executorch.exir.operator.convert import get_out_args_from_opoverload
@@ -40,7 +40,9 @@ def _callable_name(any_callable: Callable[..., Any]) -> str:
class MemoryPlanningPass(PassBase):
def __init__(
self,
memory_planning_algo: Callable[..., List[int]] = greedy,
memory_planning_algo: Callable[
..., List[int]
] = memory_planning_algorithm_suite,
allow_lifetime_and_storage_overlap: bool = False,
alloc_graph_input: bool = True,
alloc_graph_output: bool = True,
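Since the default for `memory_planning_algo` is now the suite rather than `greedy`, constructing the pass with no arguments routes through `memory_planning_algorithm_suite`, which in turn defaults to `algo_list=[greedy]`. A quick illustration of the equivalence:

```python
from executorch.exir.memory_planning import memory_planning_algorithm_suite
from executorch.exir.passes import MemoryPlanningPass

# Equivalent after this change: the suite is the default algorithm entry point,
# and without an algo_list it falls back to [greedy].
default_pass = MemoryPlanningPass()
explicit_pass = MemoryPlanningPass(memory_planning_algo=memory_planning_algorithm_suite)
```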
12 changes: 8 additions & 4 deletions exir/tests/test_memory_planning.py
@@ -8,6 +8,7 @@

import itertools
import unittest
from functools import partial
from typing import Any, Callable, List, Optional, Tuple, Type

import executorch.exir as exir
@@ -19,6 +20,8 @@
filter_nodes,
get_node_tensor_specs,
greedy,
memory_planning_algorithm_suite,
MemoryAlgoResult,
naive,
Verifier,
)
@@ -234,7 +237,7 @@ def forward(self, a: torch.Tensor) -> torch.Tensor:

def maketest(
module_cls: Type[torch.nn.Module],
criteria: Optional[List[Tuple[Callable[..., List[int]], bool]]] = None,
criteria: Optional[List[Tuple[Callable[..., MemoryAlgoResult], bool]]] = None,
extra_check: Optional[Callable[..., None]] = None,
use_functionalization: bool = True,
alloc_graph_input: bool = True,
@@ -266,13 +269,13 @@ def wrapper(self: "TestMemoryPlanning") -> None:
.exported_program()
.graph_module
)

mem_algo = partial(memory_planning_algorithm_suite, algo_list=[algo])
graph_module = PassManager(
passes=[
SpecPropPass(),
ToOutVarPass(),
MemoryPlanningPass(
algo,
mem_algo,
alloc_graph_input=alloc_graph_input,
alloc_graph_output=alloc_graph_output,
),
@@ -519,10 +522,11 @@ def test_multiple_pools(
export(MultiplePoolsToyModel(), (torch.ones(1),), strict=True)
)

mem_algo = partial(memory_planning_algorithm_suite, algo_list=[algo])
edge_program.to_executorch(
exir.ExecutorchBackendConfig(
memory_planning_pass=CustomPoolMemoryPlanningPass(
memory_planning_algo=algo,
memory_planning_algo=mem_algo,
alignment=1,
),
)