27 changes: 25 additions & 2 deletions exir/emit/_emitter.py
@@ -93,7 +93,8 @@
from executorch.exir.types import LeafValueSpec, ValueSpec
from torch._subclasses.fake_tensor import FakeTensor

from torch.export.exported_program import ExportedProgram
from torch.export.exported_program import ExportedProgram, ExportGraphSignature
from torch.fx.node import Node
from torch.utils import _pytree as pytree

from typing_extensions import TypeAlias
@@ -209,11 +210,11 @@ class _AbstractValue:
]


# pyre-ignore[13]: Attribute `node` is never initialized.
class _Emitter(torch.fx.Interpreter):
"""An abstract interpreter (https://wiki.mozilla.org/Abstract_Interpretation) used to emit the
given traced torch.fx.GraphModule to the flatbuffer schema."""

# pyre-ignore[13]: Attribute `node` is never initialized.
node: torch.fx.Node

def __init__(
@@ -1633,6 +1634,28 @@ def placeholder( # noqa: C901
if isinstance(target, str) and isinstance(spec, TensorSpec):
fqn, is_mutable_buffer = self._find_fqn_for_placeholder(target, spec)

def _is_buffer(node: Node, graph_signature: ExportGraphSignature) -> bool:
"""
Check if the node is a buffer according to the provided graph signature.
"""
if node.op == "placeholder":
if isinstance(node.target, str):
if node.target in graph_signature.inputs_to_buffers:
return True
return False

# If the spec does not appear in the mutable section of the graph signature, it might
# still be a mutable buffer if it has already been memory planned. That suggests the
# same underlying buffer is mutated in another entry point, so treat it as mutable in
# all entry points at emission, matching what the user already did with memory planning.
is_mutable_buffer |= (
_is_buffer(self.node, self.exported_program.graph_signature)
and spec.mem_id is not None
and spec.mem_offset is not None
)

# If the placeholder has a constant_tag, it is external to the PTE file
# and requires a fqn and location=TensorDataLocation.EXTERNAL
if constant_tag is not None:
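For orientation, the emitter hunk above promotes a placeholder to a mutable buffer whenever memory planning has already assigned it a mem_id and mem_offset, even if this entry point's graph signature does not list the buffer as mutated. Below is a minimal standalone sketch of that decision (an editor's illustration, not part of the diff); FakeSpec and FakeSignature are hypothetical stand-ins for TensorSpec and ExportGraphSignature, and the names "b_state"/"state" are example values.

from dataclasses import dataclass, field
from typing import Dict, Optional

@dataclass
class FakeSpec:
    # Hypothetical stand-in for TensorSpec; only the planned fields matter here.
    mem_id: Optional[int] = None
    mem_offset: Optional[int] = None

@dataclass
class FakeSignature:
    # Hypothetical stand-in for ExportGraphSignature.
    inputs_to_buffers: Dict[str, str] = field(default_factory=dict)
    buffers_to_mutate: Dict[str, str] = field(default_factory=dict)

def treat_as_mutable(target: str, spec: FakeSpec, sig: FakeSignature) -> bool:
    # Mutable because this entry point's signature says the buffer is mutated...
    explicitly_mutable = (
        target in sig.inputs_to_buffers
        and sig.inputs_to_buffers[target] in sig.buffers_to_mutate.values()
    )
    # ...or because memory planning already placed it, which implies another
    # entry point mutates the same underlying buffer.
    already_planned = (
        target in sig.inputs_to_buffers
        and spec.mem_id is not None
        and spec.mem_offset is not None
    )
    return explicitly_mutable or already_planned

sig = FakeSignature(inputs_to_buffers={"b_state": "state"})
print(treat_as_mutable("b_state", FakeSpec(mem_id=2, mem_offset=0), sig))  # True
print(treat_as_mutable("b_state", FakeSpec(), sig))                        # False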
3 changes: 3 additions & 0 deletions exir/memory_planning.py
@@ -245,6 +245,8 @@ def verify_graph_input_output(self) -> None:
assert len(specs) > 0, "Expect tensor specs"
specs = list(filter(lambda spec: not spec.const, specs))
if len(specs) == 0:
# All outputs are const, so there is no need to allocate memory; just report success.
graph_output_allocated = self.alloc_graph_output
continue
allocated = any(
spec is None or spec.mem_offset is not None for spec in specs
@@ -408,6 +410,7 @@ def collect_specs_from_nodes( # noqa: C901
ignore_graph_input: bool = False,
ignore_graph_output: bool = False,
ignore_mutable_buffers: bool = False,
share_mutable_buffers: bool = False,
ignore_const: bool = True,
ignore_out_var_node: bool = True,
dedup: bool = True,
155 changes: 152 additions & 3 deletions exir/passes/memory_planning_pass.py
@@ -4,10 +4,12 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import itertools
import logging
import warnings
from dataclasses import dataclass, field
from functools import partial
from typing import Any, Callable, List, Optional
from typing import Any, Callable, Dict, List, Optional, Set, Tuple

import torch
from executorch.exir._warnings import deprecated
@@ -16,14 +18,18 @@
from executorch.exir.memory_planning import (
_is_out_var_node,
apply_algo,
collect_specs_from_nodes,
filter_nodes,
get_node_tensor_specs,
MemoryPlanningAlgorithmSuite,
Verifier,
)
from executorch.exir.operator.convert import get_out_args_from_opoverload
from executorch.exir.pass_base import PassBase, PassResult
from executorch.exir.tensor import ALIGNMENT
from executorch.exir.tensor import ALIGNMENT, TensorSpec
from torch import fx
from torch.export.exported_program import ExportGraphSignature
from torch.fx import Node


# copied from https://stackoverflow.com/questions/75582932/python-how-can-i-print-the-function-name-of-a-partial-function
@@ -37,6 +43,106 @@ def _callable_name(any_callable: Callable[..., Any]) -> str:
return str(any_callable)


def _is_buffer(
node: Node, graph_signature: ExportGraphSignature
) -> Tuple[bool, Optional[str]]:
"""
Check if the node is a buffer according to the provided graph signature.
If it is, also return its fully qualified name (fqn).
"""
if node.op == "placeholder":
if isinstance(node.target, str):
if node.target in graph_signature.inputs_to_buffers:
fqn = graph_signature.inputs_to_buffers[node.target]
return (True, fqn)
return (False, None)


def _is_mutable_buffer(
node: Node, graph_signature: ExportGraphSignature
) -> Tuple[bool, Optional[str]]:
"""
Check if the node is a mutable buffer according to the provided graph signature.
If it is, also return its fully qualified name (fqn).
"""
if node.op == "placeholder":
if isinstance(node.target, str):
if node.target in graph_signature.inputs_to_buffers:
fqn = graph_signature.inputs_to_buffers[node.target]
# if the buffer is mutated then record that
if fqn in graph_signature.buffers_to_mutate.values():
return True, fqn
return False, None


def _get_spec_from_node(node: fx.Node) -> TensorSpec:
specs = get_node_tensor_specs(node)
return specs[0]


def _insert_mutable_buffer_specs(
state: "_MemoryPlanningState", gm: torch.fx.GraphModule, gs: ExportGraphSignature
):
for node in gm.graph.nodes:
[Contributor review comment] Would be good to have some docstring here

is_mutable, fqn = _is_mutable_buffer(node, gs)
if is_mutable:
assert fqn
spec = _get_spec_from_node(node)
if (
getattr(spec, "mem_id", None) is not None
or getattr(spec, "mem_offset", None) is not None
):
raise ValueError(
"Cannot share mutable buffers if they already have a mem_id or mem_offset assigned"
)
if fqn not in state.mutable_buffers.keys():
state.mutable_buffers[fqn] = set()
state.mutable_buffers[fqn].add(spec)
continue
is_buffer, fqn = _is_buffer(node, gs)
# If it is not a mutable buffer here, it might merely appear to be a read-only buffer in this
# entry point (think model.get_state()), so cache it and later double-check that this buffer
# never appears mutable elsewhere.
if is_buffer:
assert fqn
spec = _get_spec_from_node(node)
if (
getattr(spec, "mem_id", None) is not None
or getattr(spec, "mem_offset", None) is not None
):
raise ValueError(
"Cannot share mutable buffers if they already have a mem_id or mem_offset assigned"
)
if fqn not in state.maybe_mutable_buffers.keys():
state.maybe_mutable_buffers[fqn] = set()
state.maybe_mutable_buffers[fqn].add(spec)


def _check_default_mem_ids(gm: torch.fx.GraphModule):
for node in gm.graph.nodes:
for spec in collect_specs_from_nodes(
filter_nodes(itertools.chain([node], node.args, node.kwargs.values())),
None,
ignore_graph_input=False,
ignore_const=False,
ignore_out_var_node=False,
dedup=False,
do_assertion=False,
ignore_dynamic_unbound_tensor=False,
):
mem_id = getattr(spec, "mem_id", None)
if mem_id is not None and mem_id != 1:
raise ValueError(
"Cannot share mutable buffers if all other tensors are not on the default mem_id of 1"
)


@dataclass
class _MemoryPlanningState:
mutable_buffers: Dict[str, Set[TensorSpec]] = field(default_factory=dict)
maybe_mutable_buffers: Dict[str, Set[TensorSpec]] = field(default_factory=dict)
graph_modules: List[torch.fx.GraphModule] = field(default_factory=list)


class MemoryPlanningPass(PassBase):
def __init__(
self,
Expand All @@ -45,6 +151,7 @@ def __init__(
alloc_graph_input: bool = True,
alloc_graph_output: bool = True,
alloc_mutable_buffers: bool = True,
share_mutable_buffers: bool = False,
alignment: int = ALIGNMENT,
) -> None:
r"""
Expand All @@ -55,12 +162,18 @@ def __init__(
"""
if memory_planning_algo is None:
memory_planning_algo = MemoryPlanningAlgorithmSuite()
if share_mutable_buffers and not alloc_mutable_buffers:
raise ValueError(
"share_mutable_buffers is only meaningful when alloc_mutable_buffers is True"
)
self.memory_planning_algo: Callable[..., List[int]] = memory_planning_algo
self.allow_lifetime_and_storage_overlap = allow_lifetime_and_storage_overlap
self.alloc_graph_input = alloc_graph_input
self.alloc_graph_output = alloc_graph_output
self.alloc_mutable_buffers = alloc_mutable_buffers
self.share_mutable_buffers = share_mutable_buffers
self.alignment = alignment
self.state = _MemoryPlanningState()

def _set_alloc_node_spec(self, graph_module: torch.fx.GraphModule) -> None:
"""
Expand Down Expand Up @@ -134,9 +247,17 @@ def run(
graph_signature,
self.alloc_graph_input,
self.alloc_graph_output,
self.alloc_mutable_buffers,
# If we are sharing the mutable buffers, do not allocate them in the
# memory planning algo; instead collect all of their specs across all the
# entry points and then allocate them directly in the run_multimethod call.
self.alloc_mutable_buffers and not self.share_mutable_buffers,
)

if self.share_mutable_buffers and graph_signature is not None:
self.state.graph_modules.append(graph_module)
_check_default_mem_ids(graph_module)
_insert_mutable_buffer_specs(self.state, graph_module, graph_signature)

# TODO: make the verifier do the work recursively to handle
# control flow
verifier = Verifier(
Expand Down Expand Up @@ -164,3 +285,31 @@ def run(
# I don't know if that is a valid thing, but if it is we should adjust the verify_storage_reuse function
verifier.verify_storage_reuse()
return PassResult(graph_module, True)

def run_multimethod(self):
"Resolve any memory planning done across entry points"
[Contributor review comment] This should be a docstring right?
Suggested change:
"Resolve any memory planning done across entry points"
"""Resolve any memory planning done across entry points"""

if self.share_mutable_buffers:
arena: int = 0

# Every spec that shares an fqn is the same tensor! So we give it the same id and offset
# anywhere it appears.
for fqn, specs_set in self.state.mutable_buffers.items():
specs = list(specs_set)
# If the same buffer appears in mutable and maybe mutable then we know it is in fact mutable.
if fqn in self.state.maybe_mutable_buffers.keys():
specs.extend(self.state.maybe_mutable_buffers[fqn])
for spec in specs:
# Assume default memory planning placed all activations on mem_id 1; place shared state on mem_id 2.
spec.mem_id = 2
spec.realign(self.alignment)
# State is persistent, so the memory never overlaps.
spec.mem_offset = arena
# They should all be the same size since they are the same tensor, so just advance the arena by the first one's size.
arena += specs[0].allocated_memory

for graph_module in self.state.graph_modules:
if len(graph_module.meta["non_const_buffer_sizes"]) != 2:
raise ValueError(
"Cannot share mutable state if not using default memory ids"
)
graph_module.meta["non_const_buffer_sizes"].append(arena)
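To make the run_multimethod layout above concrete: every fqn is pinned to mem_id 2, offsets accumulate into one persistent arena (state outlives every call, so allocations never overlap), and the final arena size is what gets appended to non_const_buffer_sizes. The following is a small standalone sketch of the same arithmetic, assuming two hypothetical buffers ("cache", 64 bytes; "state", 16 bytes) and 16-byte alignment; it is an editor's illustration, not the PR's code.

ALIGNMENT = 16

def round_up(n: int, alignment: int) -> int:
    # Round n up to the nearest multiple of alignment.
    return (n + alignment - 1) // alignment * alignment

def plan_shared_arena(buffer_sizes: dict) -> tuple:
    """Return ({fqn: (mem_id, mem_offset)}, total_arena_size_in_bytes)."""
    arena = 0
    placements = {}
    for fqn, size in buffer_sizes.items():
        # Every entry point that references `fqn` sees the same placement,
        # since all of those specs describe the same underlying tensor.
        placements[fqn] = (2, arena)  # shared state lives on mem_id 2
        arena += round_up(size, ALIGNMENT)
    return placements, arena

placements, arena = plan_shared_arena({"cache": 64, "state": 16})
print(placements)  # {'cache': (2, 0), 'state': (2, 64)}
print(arena)       # 80 -> appended to each graph's non_const_buffer_sizes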
17 changes: 12 additions & 5 deletions exir/program/_program.py
@@ -1681,7 +1681,7 @@ def to_backend(
return epm

@et_logger("to_executorch")
def to_executorch(
def to_executorch( # noqa (FLAKE8) C901
self,
config: Optional[ExecutorchBackendConfig] = None,
) -> "ExecutorchProgramManager":
@@ -1745,11 +1745,9 @@ def to_executorch(
memory_planning_pass = config.memory_planning_pass
# TODO(jakeszwe): Follow up with compiler on if the deepcopy is necessary and if so how to make it work
if hasattr(memory_planning_pass, "run"):
new_gm_res = memory_planning_pass.run( # pyre-ignore[16]
new_gm, new_signature
)
new_gm_res = memory_planning_pass.run(new_gm, new_signature)
else:
new_gm_res = memory_planning_pass(new_gm) # pyre-ignore[29]
new_gm_res = memory_planning_pass(new_gm)

# WARNING: DO NOT ADD ANY MORE PASSES AFTER MEMORY PLANNING PASS.
# THERE ARE A LOT OF ASSUMPTIONS IN THE STACK THAT MEMORY PLANNING IS THE LAST PASS BEFORE THE EMITTER.
@@ -1758,6 +1756,15 @@

_copy_module(program.graph_module, new_gm)
execution_programs[name] = program
# After running memory planning on all entry points, run the cross-entry-point memory planning.
if isinstance(config.memory_planning_pass, dict):
for memory_planning_pass in config.memory_planning_pass.values():
if hasattr(memory_planning_pass, "run_multimethod"):
memory_planning_pass.run_multimethod()
else:
memory_planning_pass = config.memory_planning_pass
if hasattr(memory_planning_pass, "run_multimethod"):
memory_planning_pass.run_multimethod()

et_pm = ExecutorchProgramManager(
execution_programs,
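The two branches above boil down to one dispatch step: once every method has been planned individually, each configured memory planning pass that exposes run_multimethod gets a single cross-method resolution call. A simplified restatement follows (an editor's sketch, not the PR's exact code).

def finalize_memory_planning(memory_planning_pass) -> None:
    # Accept either a single pass or a per-method dict of passes, mirroring
    # the two shapes that config.memory_planning_pass can take in the diff.
    passes = (
        memory_planning_pass.values()
        if isinstance(memory_planning_pass, dict)
        else [memory_planning_pass]
    )
    for p in passes:
        if hasattr(p, "run_multimethod"):
            p.run_multimethod()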
52 changes: 52 additions & 0 deletions exir/tests/test_memory_planning.py
@@ -14,6 +14,7 @@

import torch
from executorch.exir import ExecutorchBackendConfig, to_edge
from executorch.exir.capture._capture import patch_forward
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.memory_planning import (
_do_user_inputs_exist,
@@ -93,6 +94,24 @@ def get_random_inputs(self) -> Tuple[torch.Tensor, ...]:
return (torch.randn(10), torch.randn(10))


class MultiEntryPointStatefulModel(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.register_buffer("state", torch.zeros(2, 2))

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.state.add_(x).view(-1) * 2

def set_state(self, state: torch.Tensor) -> None:
self.state.copy_(state)

def get_state(self) -> torch.Tensor:
return self.state

def get_example_inputs(self) -> Tuple[torch.Tensor, ...]:
return (torch.ones(1),)


class ModelWithDifferentTensorSizes(torch.nn.Module):
def __init__(self) -> None:
super(ModelWithDifferentTensorSizes, self).__init__()
@@ -1081,3 +1100,36 @@ def test_multi_map(self) -> None:
verifier.storage_overlap(outer_spec, inner_spec),
f"Outer spec {outer_spec.shape=} {outer_spec.dtype=} {outer_spec.lifetime=} and inner spec {inner_spec} have storage overlap",
)

def test_multi_state_plan(self) -> None:
eager_module = MultiEntryPointStatefulModel().eval()
forward = export(eager_module, eager_module.get_example_inputs())
with patch_forward(eager_module, eager_module.get_state):
get_state = export(eager_module, ())
with patch_forward(eager_module, eager_module.set_state):
set_state = export(eager_module, (torch.zeros(1),))
edge = to_edge(
{"forward": forward, "set_state": set_state, "get_state": get_state}
)
et = edge.to_executorch(
ExecutorchBackendConfig(
memory_planning_pass=MemoryPlanningPass(share_mutable_buffers=True),
emit_mutable_buffer_names=True,
)
)
et_prog = et.executorch_program
count = 0
for plan in et_prog.execution_plan:
for value in plan.values:
if (
hasattr(value.val, "allocation_info")
and value.val.allocation_info is not None
and value.val.allocation_info.memory_id == 2
):
count += 1
self.assertEqual(value.val.allocation_info.memory_offset_low, 0)
self.assertTrue(value.val.extra_tensor_info is not None)
self.assertEqual(
value.val.extra_tensor_info.fully_qualified_name, "state"
)
self.assertEqual(count, 3)