178 changes: 127 additions & 51 deletions exir/memory_planning.py
@@ -10,10 +10,20 @@
import itertools
import logging
import operator
import typing
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
from typing import (
Any,
Callable,
cast,
Dict,
Iterable,
List,
Optional,
Set,
Tuple,
Union,
)

import torch
from executorch.exir import memory
@@ -960,7 +970,7 @@ def _allocate_buf(bufsizes: List[int], mem_id: int, allocated: int) -> int:
bufsizes = getattr(graph_module, "input_mem_buffer_sizes", None)
if bufsizes is None:
bufsizes = [0, 0]
bufsizes = typing.cast(List[int], bufsizes)
bufsizes = cast(List[int], bufsizes)

for spec in specs:
spec_alloc_result = naive_result.spec_dict.get(spec, SpecAllocResult(0, 0, 0))
@@ -1062,33 +1072,119 @@ def insert_calls_to_free(
graph_module.recompile()


def _merge_bufsizes(bufsizes: list[int], new_bufsizes: list[int]) -> list[int]:
"""Combine two buffer size lists."""
if len(bufsizes) < len(new_bufsizes):
bufsizes.extend([0] * (len(new_bufsizes) - len(bufsizes)))
for i in range(len(new_bufsizes)):
bufsizes[i] = max(bufsizes[i], new_bufsizes[i])
return bufsizes
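
# --- Illustrative sketch (not part of this diff) -----------------------------
# A hedged example of how the merge above behaves when several submodules
# report their planned arena sizes: each slot keeps the largest request, so a
# single reservation can host whichever submodule actually runs. The sizes
# below are made up for illustration.
#
#   merged: list[int] = []
#   for sizes in ([0, 1024, 256], [0, 512, 512, 128]):
#       merged = _merge_bufsizes(merged, sizes)
#   # merged == [0, 1024, 512, 128]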


def _handle_submodule(
algo: Callable[..., list[int]],
parent_graph_module: torch.fx.GraphModule,
alignment: int,
submodule_node: torch.fx.Node,
graph_signature: Optional[ExportGraphSignature] = None,
alloc_graph_input: bool = False,
) -> list[int]:
"""Apply algo to nodes in a submodule of the graph module."""
assert submodule_node.op == "get_attr"
submodule = getattr(parent_graph_module, submodule_node.target)

logging.debug(f"Planning memory for submodule {submodule_node.name}...")
bufsizes = apply_algo(
algo,
submodule,
alignment,
graph_signature,
alloc_graph_input=alloc_graph_input,
alloc_graph_output=True,
)
submodule.meta.update({"non_const_buffer_sizes": bufsizes})
logging.debug(f"Buffer sizes for submodule {submodule_node.name}: {bufsizes}")
return bufsizes
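
# --- Illustrative note (not part of this diff) --------------------------------
# The traversal below relies on the argument layout of the exported
# control-flow higher-order ops; my understanding (an assumption worth checking
# against the exporter in use) is:
#   cond(pred, true_graph, false_graph, operands)    -> submodules at args[1:3]
#   while_loop(cond_graph, body_graph, carried_args) -> submodules at args[0:2]
#   map(body_graph, mapped_args, operands)           -> submodule  at args[0]
# where each submodule reference is a get_attr node on the parent GraphModule.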


def _apply_algo_to_submodules(
algo: Callable[..., list[int]],
graph_module: torch.fx.GraphModule,
alignment: int,
graph_signature: Optional[ExportGraphSignature] = None,
) -> list[int]:
"""Apply algo to map/cond/while nodes in the graph module.

This method will populate meta["non_const_buffer_sizes"] on each submodule
and return a bufsizes list holding, for each buffer, the maximum size
needed across all submodules.
"""

# Bufsizes for submodules.
bufsizes: list[int] = []

def _handle(
submodule_node: torch.fx.Node,
alloc_graph_input: bool = False,
) -> None:
current_bufsizes = _handle_submodule(
algo,
graph_module,
alignment,
submodule_node,
graph_signature,
alloc_graph_input=alloc_graph_input,
)
nonlocal bufsizes
_merge_bufsizes(bufsizes, current_bufsizes)

for cond_node in get_cond_nodes(graph_module):
_handle(cast(torch.fx.Node, cond_node.args[1]))
_handle(cast(torch.fx.Node, cond_node.args[2]))

for while_node in get_while_nodes(graph_module):
_handle(cast(torch.fx.Node, while_node.args[0]))
_handle(cast(torch.fx.Node, while_node.args[1]))

for map_node in get_map_nodes(graph_module):
_handle(cast(torch.fx.Node, map_node.args[0]), alloc_graph_input=True)

# TODO: We can handle delegates the same way as map/cond/while.
# Maybe populate the graph_module.meta["non_const_buffer_sizes"] for delegates.

return bufsizes
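
# --- Illustrative sketch (not part of this diff) ------------------------------
# A minimal, hypothetical helper (the name is mine, not an ExecuTorch API) that
# shows where the per-submodule plans written above end up: each control-flow
# submodule carries its own "non_const_buffer_sizes" entry in its meta dict.
def _debug_dump_submodule_plans(graph_module: torch.fx.GraphModule) -> None:
    for name, submodule in graph_module.named_children():
        if not isinstance(submodule, torch.fx.GraphModule):
            continue
        sizes = submodule.meta.get("non_const_buffer_sizes")
        if sizes is not None:
            logging.debug(f"Submodule {name}: non_const_buffer_sizes={sizes}")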


def apply_algo(
algo: Callable[
...,
List[int],
],
algo: Callable[..., list[int]],
graph_module: torch.fx.GraphModule,
alignment: int,
graph_signature: Optional[ExportGraphSignature] = None,
alloc_graph_input: bool = True,
alloc_graph_output: bool = True,
alloc_mutable_buffers: bool = True,
) -> List[int]:
) -> list[int]:
"""
Recursively apply algo to graph_module and its submodules for control flow.

Quite naive right now since it does not take the following optimizations
into consideration:
1. for a conditional structure, the true branch and false branch do not
overlap in lifetime and can share tensor storage
2. tensors inside a submodule (e.g. the true branch) have opportunities to
share storage with tensors in the outer module.
TODO: make these optimizations once we have some baseline working.
Algo implementation should handle one of two meta entries for submodules:
1. input_mem_buffer_sizes: List of int offset bytes. Memory allocated by
`algo` should start at the offset specified by this list;
OR
2. non_const_buffer_sizes: List of bufsizes for planned memory in submodule.
`algo` should reserve the space specified by this list for the lifetime
of the submodule node (e.g. cond, while, map).

TODO: Missing optimizations:
1. To handle maps, we set `alloc_graph_input=True`, which allocates
appropriate space for mapped arg but ends up allocating extra space for
`operand` arg. The memory for operands is unused.
"""
# Extract the nodes and their lifespans from the graph_module
# Difficult to just filter the list of specs returned by this due to
# how we flag trainable weights.
_ = update_all_tensors_lifetime(graph_module, graph_signature)

# Filter specs based on alloc_graph_input and alloc_graph_output
specs = collect_specs_from_nodes(
graph_module.graph.nodes,
@@ -1099,55 +1195,35 @@ def apply_algo(
ignore_mutable_buffers=not alloc_mutable_buffers,
)

# Plan memory for submodules first so that space can be set aside for them
# while their control-flow nodes execute.
submodule_bufsizes = _apply_algo_to_submodules(
algo, graph_module, alignment, graph_signature
)

# Update `input_mem_buffer_sizes` in graph_module. This will allow existing
# algos to work using `input_mem_buffer_sizes` or use
# `non_const_buffer_sizes` directly.
# pyre-ignore[16]: `torch.fx.GraphModule` has no attribute `input_mem_buffer_sizes`.
graph_module.input_mem_buffer_sizes = submodule_bufsizes

# Get extra padding for XNNPACK if needed
extra_padding = 0
if _contains_xnnpack_delegate(graph_module):
extra_padding = 64

# Pass the filtered specs to the algorithm
bufsizes: List[int] = algo(
bufsizes: list[int] = algo(
alignment,
specs,
graph_module,
graph_signature,
extra_padding,
)

insert_calls_to_free(graph_module, set(specs))

def handle_submodule(
submodule_nd: torch.fx.Node, alloc_graph_input: bool = False
) -> None:
nonlocal bufsizes
assert submodule_nd.op == "get_attr"
submodule = getattr(graph_module, submodule_nd.target)
# memory planning for the submodule needs to be aware of the amount of
# buffer already allocated.
submodule.input_mem_buffer_sizes = bufsizes

bufsizes = apply_algo(
algo,
submodule,
alignment,
graph_signature,
alloc_graph_input=alloc_graph_input,
alloc_graph_output=True,
)
submodule.meta.update({"non_const_buffer_sizes": bufsizes})

for cond_node in get_cond_nodes(graph_module):
handle_submodule(typing.cast(torch.fx.Node, cond_node.args[1]))
handle_submodule(typing.cast(torch.fx.Node, cond_node.args[2]))

for while_node in get_while_nodes(graph_module):
handle_submodule(typing.cast(torch.fx.Node, while_node.args[0]))
handle_submodule(typing.cast(torch.fx.Node, while_node.args[1]))
# TODO: Add test coverage for map operator once dynamo tracing is
# fully supported for this. T142287208
for map_node in get_map_nodes(graph_module):
handle_submodule(
typing.cast(torch.fx.Node, map_node.args[0]), alloc_graph_input=True
)
# pyre-ignore[6]: Incompatible parameter type [6]
# In call `insert_calls_to_free`, for 2nd positional argument, expected `Set[TensorSpec]` but got `Iterable[TensorSpec]`
insert_calls_to_free(graph_module, specs)

graph_module.meta.update({"non_const_buffer_sizes": bufsizes})
return bufsizes
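
# --- Illustrative sketch (not part of this diff) ------------------------------
# A deliberately naive example of an `algo` callable that honors the
# `input_mem_buffer_sizes` convention described in the docstring above: planning
# for each mem_id starts at the offset already reserved for submodules. This is
# a bump-allocator sketch under assumed TensorSpec semantics (mem_id,
# mem_offset, allocated_memory), not an algorithm shipped with ExecuTorch.
def _naive_offset_algo(
    alignment: int,
    specs: Iterable[TensorSpec],
    graph_module: torch.fx.GraphModule,
    graph_signature: Optional[ExportGraphSignature],
    extra_padding: int,
) -> list[int]:
    reserved = getattr(graph_module, "input_mem_buffer_sizes", None) or [0, 0]
    bufsizes = list(reserved)
    for spec in specs:
        mem_id = spec.mem_id if spec.mem_id is not None else 1
        if len(bufsizes) <= mem_id:
            bufsizes.extend([0] * (mem_id + 1 - len(bufsizes)))
        # Round the running offset up to `alignment`, place the tensor there,
        # then bump past its size plus any delegate padding.
        offset = ((bufsizes[mem_id] + alignment - 1) // alignment) * alignment
        spec.mem_id = mem_id
        spec.mem_offset = offset
        bufsizes[mem_id] = offset + spec.allocated_memory + extra_padding
    return bufsizes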