Add option in memory planning to put shared state on same location across entry points #14230
Merged: facebook-github-bot merged 1 commit into pytorch:main from JacobSzwejbka:export-D82250153 on Sep 18, 2025.
Changes from all commits
@@ -4,10 +4,12 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import itertools
import logging
import warnings
from dataclasses import dataclass, field
from functools import partial
-from typing import Any, Callable, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple

import torch
from executorch.exir._warnings import deprecated

@@ -16,14 +18,18 @@
from executorch.exir.memory_planning import (
    _is_out_var_node,
    apply_algo,
    collect_specs_from_nodes,
    filter_nodes,
    get_node_tensor_specs,
    MemoryPlanningAlgorithmSuite,
    Verifier,
)
from executorch.exir.operator.convert import get_out_args_from_opoverload
from executorch.exir.pass_base import PassBase, PassResult
-from executorch.exir.tensor import ALIGNMENT
+from executorch.exir.tensor import ALIGNMENT, TensorSpec
from torch import fx
from torch.export.exported_program import ExportGraphSignature
from torch.fx import Node


# copied from https://stackoverflow.com/questions/75582932/python-how-can-i-print-the-function-name-of-a-partial-function

@@ -37,6 +43,106 @@ def _callable_name(any_callable: Callable[..., Any]) -> str:
    return str(any_callable)


def _is_buffer(
    node: Node, graph_signature: ExportGraphSignature
) -> Tuple[bool, Optional[str]]:
    """
    Check if the node is buffer according to the provided graph signature.
    If it is one return its fqn as well
    """
    if node.op == "placeholder":
        if isinstance(node.target, str):
            if node.target in graph_signature.inputs_to_buffers:
                fqn = graph_signature.inputs_to_buffers[node.target]
                return (True, fqn)
    return (False, None)


def _is_mutable_buffer(
    node: Node, graph_signature: ExportGraphSignature
) -> Tuple[bool, Optional[str]]:
    """
    Check if the node is mutable buffer according to the provided graph signature.
    If it is one return its fqn as well
    """
    if node.op == "placeholder":
        if isinstance(node.target, str):
            if node.target in graph_signature.inputs_to_buffers:
                fqn = graph_signature.inputs_to_buffers[node.target]
                # if the buffer is mutated then record that
                if fqn in graph_signature.buffers_to_mutate.values():
                    return True, fqn
    return False, None


def _get_spec_from_node(node: fx.Node) -> TensorSpec:
    specs = get_node_tensor_specs(node)
    return specs[0]


def _insert_mutable_buffer_specs(
    state: "_MemoryPlanningState", gm: torch.fx.GraphModule, gs: ExportGraphSignature
):
    for node in gm.graph.nodes:
        is_mutable, fqn = _is_mutable_buffer(node, gs)
        if is_mutable:
            assert fqn
            spec = _get_spec_from_node(node)
            if (
                getattr(spec, "mem_id", None) is not None
                or getattr(spec, "mem_offset", None) is not None
            ):
                raise ValueError(
                    "Cannot share mutable buffers if they already have a mem_id or mem_offset assigned"
                )
            if fqn not in state.mutable_buffers.keys():
                state.mutable_buffers[fqn] = set()
            state.mutable_buffers[fqn].add(spec)
            continue
        is_buffer, fqn = _is_buffer(node, gs)
        # If it is not a mutable buffer it might just appear to be a buffer in this entry point. Think model.get_state()
        # So cache it and later double check that this buffer never appears mutable
        if is_buffer:
            assert fqn
            spec = _get_spec_from_node(node)
            if (
                getattr(spec, "mem_id", None) is not None
                or getattr(spec, "mem_offset", None) is not None
            ):
                raise ValueError(
                    "Cannot share mutable buffers if they already have a mem_id or mem_offset assigned"
                )
            if fqn not in state.maybe_mutable_buffers.keys():
                state.maybe_mutable_buffers[fqn] = set()
            state.maybe_mutable_buffers[fqn].add(spec)


def _check_default_mem_ids(gm: torch.fx.GraphModule):
    for node in gm.graph.nodes:
        for spec in collect_specs_from_nodes(
            filter_nodes(itertools.chain([node], node.args, node.kwargs.values())),
            None,
            ignore_graph_input=False,
            ignore_const=False,
            ignore_out_var_node=False,
            dedup=False,
            do_assertion=False,
            ignore_dynamic_unbound_tensor=False,
        ):
            mem_id = getattr(spec, "mem_id", None)
            if mem_id is not None and mem_id != 1:
                raise ValueError(
                    "Cannot share mutable buffers if all other tensors are not on the default mem_id of 1"
                )


@dataclass
class _MemoryPlanningState:
    mutable_buffers: Dict[str, Set[TensorSpec]] = field(default_factory=dict)
    maybe_mutable_buffers: Dict[str, Set[TensorSpec]] = field(default_factory=dict)
    graph_modules: List[torch.fx.GraphModule] = field(default_factory=list)


class MemoryPlanningPass(PassBase):
    def __init__(
        self,

@@ -45,6 +151,7 @@ def __init__(
        alloc_graph_input: bool = True,
        alloc_graph_output: bool = True,
        alloc_mutable_buffers: bool = True,
        share_mutable_buffers: bool = False,
        alignment: int = ALIGNMENT,
    ) -> None:
        r"""

@@ -55,12 +162,18 @@
        """
        if memory_planning_algo is None:
            memory_planning_algo = MemoryPlanningAlgorithmSuite()
        if share_mutable_buffers and not alloc_mutable_buffers:
            raise ValueError(
                "share_mutable_buffers is only meaningful when alloc_mutable_buffers is True"
            )
        self.memory_planning_algo: Callable[..., List[int]] = memory_planning_algo
        self.allow_lifetime_and_storage_overlap = allow_lifetime_and_storage_overlap
        self.alloc_graph_input = alloc_graph_input
        self.alloc_graph_output = alloc_graph_output
        self.alloc_mutable_buffers = alloc_mutable_buffers
        self.share_mutable_buffers = share_mutable_buffers
        self.alignment = alignment
        self.state = _MemoryPlanningState()

    def _set_alloc_node_spec(self, graph_module: torch.fx.GraphModule) -> None:
        """

@@ -134,9 +247,17 @@ def run(
            graph_signature,
            self.alloc_graph_input,
            self.alloc_graph_output,
-            self.alloc_mutable_buffers,
+            # If we are sharing the mutable buffers then do not allocate them in the
+            # memory planning algo; instead collect all of the specs over all the entry
+            # points and then allocate them directly in the run_multimethod call.
+            self.alloc_mutable_buffers and not self.share_mutable_buffers,
        )

        if self.share_mutable_buffers and graph_signature is not None:
            self.state.graph_modules.append(graph_module)
            _check_default_mem_ids(graph_module)
            _insert_mutable_buffer_specs(self.state, graph_module, graph_signature)

        # TODO: make the verifier do the work recursively to handle
        # control flow
        verifier = Verifier(

@@ -164,3 +285,31 @@ def run(
        # I dont know if that is a valid thing but if it is we should adjust verify_storage_reuse function
        verifier.verify_storage_reuse()
        return PassResult(graph_module, True)

    def run_multimethod(self):
        "Resolve any memory planning done across entry points"
Contributor review comment on the line above: "This should be a docstring right?" (A suggested change was attached.)
        if self.share_mutable_buffers:
            arena: int = 0

            # Every spec that shares an fqn is the same tensor! So we give it the same id and offset
            # anywhere it appears.
            for fqn, specs_set in self.state.mutable_buffers.items():
                specs = list(specs_set)
                # If the same buffer appears in mutable and maybe mutable then we know it is in fact mutable.
                if fqn in self.state.maybe_mutable_buffers.keys():
                    specs.extend(self.state.maybe_mutable_buffers[fqn])
                for spec in specs:
                    # Assume a default memory planning placed all activations on 1, place shared state on 2.
                    spec.mem_id = 2
                    spec.realign(self.alignment)
                    # State is persistent, so the memory never overlaps.
                    spec.mem_offset = arena
                # They should all be the same size since they are the same tensor, so just bump off the first.
                arena += specs[0].allocated_memory

            for graph_module in self.state.graph_modules:
                if len(graph_module.meta["non_const_buffer_sizes"]) != 2:
                    raise ValueError(
                        "Cannot share mutable state if not using default memory ids"
                    )
                graph_module.meta["non_const_buffer_sizes"].append(arena)
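
To make the intended flow concrete, here is a minimal usage sketch, not taken from this PR: one MemoryPlanningPass instance is run over each entry point, and run_multimethod() is then called once so that every shared mutable buffer ends up with the same mem_id and mem_offset in every method. The import path, the plan_shared_state helper, and the entry_points mapping are illustrative assumptions rather than code from this diff.

# Hypothetical sketch (not from this PR): share mutable buffer storage across
# several entry points of the same model, then resolve the shared placements.
from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass


def plan_shared_state(entry_points):
    """entry_points: dict mapping method name -> (graph_module, graph_signature),
    assumed to come from an upstream export/to_edge step."""
    mp_pass = MemoryPlanningPass(
        alloc_mutable_buffers=True,   # required: sharing builds on mutable buffer planning
        share_mutable_buffers=True,   # the new flag added in this diff
    )
    for _name, (gm, signature) in entry_points.items():
        # Per-method planning: activations are planned as usual, while mutable
        # buffer specs are only collected because share_mutable_buffers is set.
        mp_pass.run(gm, signature)
    # Resolve the collected specs across all entry points: every spec that
    # shares an fqn gets mem_id 2 and the same mem_offset in every method.
    mp_pass.run_multimethod()
    return mp_pass

In a real lowering flow this pass would normally be supplied through the export pipeline's memory planning configuration rather than driven by hand; the loop above only spells out the collect-then-resolve ordering that run() and run_multimethod() rely on.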
Reviewer comment: "Would be good to have some docstring here."
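
As background on the mutable versus maybe-mutable split used by _insert_mutable_buffer_specs (the model.get_state() case mentioned in its comments), the toy below, independent of this PR, shows how torch.export records buffer usage in the graph signature that _is_buffer and _is_mutable_buffer inspect. The Counter module and the printed dictionary contents are illustrative assumptions.

# Illustrative toy (independent of this PR): how buffer usage shows up in an
# ExportGraphSignature.
import torch
from torch.export import export


class Counter(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.register_buffer("count", torch.zeros(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.count.add_(1)  # in-place mutation of a registered buffer
        return x + self.count


ep = export(Counter(), (torch.ones(1),))
sig = ep.graph_signature

# The buffer is lifted to a placeholder input of the exported graph...
print(sig.inputs_to_buffers)   # e.g. {'b_count': 'count'}
# ...and, because forward mutates it, it is also recorded as mutated.
print(sig.buffers_to_mutate)   # e.g. {'add': 'count'} (exact key may differ)

# An entry point that only reads `count` (a get_state()-style accessor) would
# list the buffer in inputs_to_buffers but not in buffers_to_mutate, which is
# exactly the case that lands in `maybe_mutable_buffers` until another entry
# point shows the buffer being mutated.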