# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

+import itertools
import logging
import warnings
+from dataclasses import dataclass, field
from functools import partial
-from typing import Any, Callable, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple

import torch
from executorch.exir._warnings import deprecated
from executorch.exir.memory_planning import (
    _is_out_var_node,
    apply_algo,
+    collect_specs_from_nodes,
+    filter_nodes,
    get_node_tensor_specs,
    MemoryPlanningAlgorithmSuite,
    Verifier,
)
from executorch.exir.operator.convert import get_out_args_from_opoverload
from executorch.exir.pass_base import PassBase, PassResult
-from executorch.exir.tensor import ALIGNMENT
+from executorch.exir.tensor import ALIGNMENT, TensorSpec
+from torch import fx
from torch.export.exported_program import ExportGraphSignature
+from torch.fx import Node


# copied from https://stackoverflow.com/questions/75582932/python-how-can-i-print-the-function-name-of-a-partial-function
@@ -37,6 +44,106 @@ def _callable_name(any_callable: Callable[..., Any]) -> str:
    return str(any_callable)


+def _is_buffer(
+    node: Node, graph_signature: ExportGraphSignature
+) -> Tuple[bool, Optional[str]]:
+    """
+    Check whether the node is a buffer according to the provided graph signature.
+    If it is one, return its FQN as well.
+    """
+    if node.op == "placeholder":
+        if isinstance(node.target, str):
+            if node.target in graph_signature.inputs_to_buffers:
+                fqn = graph_signature.inputs_to_buffers[node.target]
+                return (True, fqn)
+    return (False, None)
+
+
+def _is_mutable_buffer(
+    node: Node, graph_signature: ExportGraphSignature
+) -> Tuple[bool, Optional[str]]:
+    """
+    Check whether the node is a mutable buffer according to the provided graph signature.
+    If it is one, return its FQN as well.
+    """
+    if node.op == "placeholder":
+        if isinstance(node.target, str):
+            if node.target in graph_signature.inputs_to_buffers:
+                fqn = graph_signature.inputs_to_buffers[node.target]
+                # if the buffer is mutated then record that
+                if fqn in graph_signature.buffers_to_mutate.values():
+                    return True, fqn
+    return False, None
+
+
+def _get_spec_from_node(node: fx.Node) -> TensorSpec:
+    specs = get_node_tensor_specs(node)
+    return specs[0]
+
+
+def _insert_mutable_buffer_specs(
+    state: "_MemoryPlanningState", gm: torch.fx.GraphModule, gs: ExportGraphSignature
+):
+    for node in gm.graph.nodes:
+        is_mutable, fqn = _is_mutable_buffer(node, gs)
+        if is_mutable:
+            assert fqn
+            spec = _get_spec_from_node(node)
+            if (
+                getattr(spec, "mem_id", None) is not None
+                or getattr(spec, "mem_offset", None) is not None
+            ):
+                raise ValueError(
+                    "Cannot share mutable buffers if they already have a mem_id or mem_offset assigned"
+                )
+            if fqn not in state.mutable_buffers.keys():
+                state.mutable_buffers[fqn] = set()
+            state.mutable_buffers[fqn].add(spec)
+            continue
+        is_buffer, fqn = _is_buffer(node, gs)
+        # If it is not a mutable buffer it might just appear to be a buffer in this entry point. Think model.get_state().
+        # So cache it and later double-check that this buffer never appears mutable.
+        if is_buffer:
+            assert fqn
+            spec = _get_spec_from_node(node)
+            if (
+                getattr(spec, "mem_id", None) is not None
+                or getattr(spec, "mem_offset", None) is not None
+            ):
+                raise ValueError(
+                    "Cannot share mutable buffers if they already have a mem_id or mem_offset assigned"
+                )
+            if fqn not in state.maybe_mutable_buffers.keys():
+                state.maybe_mutable_buffers[fqn] = set()
+            state.maybe_mutable_buffers[fqn].add(spec)
+
+
+def _check_default_mem_ids(gm: torch.fx.GraphModule):
+    for node in gm.graph.nodes:
+        for spec in collect_specs_from_nodes(
+            filter_nodes(itertools.chain([node], node.args, node.kwargs.values())),
+            None,
+            ignore_graph_input=False,
+            ignore_const=False,
+            ignore_out_var_node=False,
+            dedup=False,
+            do_assertion=False,
+            ignore_dynamic_unbound_tensor=False,
+        ):
+            mem_id = getattr(spec, "mem_id", None)
+            if mem_id is not None and mem_id != 1:
+                raise ValueError(
+                    "Cannot share mutable buffers if all other tensors are not on the default mem_id of 1"
+                )
+
+
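+# Bookkeeping shared across entry points: `mutable_buffers` maps a buffer's FQN to
+# every TensorSpec observed for it, `maybe_mutable_buffers` holds buffers that only
+# appeared read-only in a given entry point (their specs are folded in later only if
+# the same FQN is also mutated somewhere), and `graph_modules` records every planned
+# module so run_multimethod() can append the shared arena size to each one.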
+@dataclass
+class _MemoryPlanningState:
+    mutable_buffers: Dict[str, Set[TensorSpec]] = field(default_factory=dict)
+    maybe_mutable_buffers: Dict[str, Set[TensorSpec]] = field(default_factory=dict)
+    graph_modules: List[torch.fx.GraphModule] = field(default_factory=list)
+
+
class MemoryPlanningPass(PassBase):
    def __init__(
        self,
@@ -45,6 +152,7 @@ def __init__(
        alloc_graph_input: bool = True,
        alloc_graph_output: bool = True,
        alloc_mutable_buffers: bool = True,
+        share_mutable_buffers: bool = False,
        alignment: int = ALIGNMENT,
    ) -> None:
        r"""
@@ -55,12 +163,18 @@ def __init__(
        """
        if memory_planning_algo is None:
            memory_planning_algo = MemoryPlanningAlgorithmSuite()
+        if share_mutable_buffers and not alloc_mutable_buffers:
+            raise ValueError(
+                "share_mutable_buffers is only meaningful when alloc_mutable_buffers is True"
+            )
        self.memory_planning_algo: Callable[..., List[int]] = memory_planning_algo
        self.allow_lifetime_and_storage_overlap = allow_lifetime_and_storage_overlap
        self.alloc_graph_input = alloc_graph_input
        self.alloc_graph_output = alloc_graph_output
        self.alloc_mutable_buffers = alloc_mutable_buffers
+        self.share_mutable_buffers = share_mutable_buffers
        self.alignment = alignment
+        self.state = _MemoryPlanningState()

    def _set_alloc_node_spec(self, graph_module: torch.fx.GraphModule) -> None:
        """
@@ -134,9 +248,17 @@ def run(
            graph_signature,
            self.alloc_graph_input,
            self.alloc_graph_output,
-            self.alloc_mutable_buffers,
+            # If we are sharing the mutable buffers then do not allocate them in the
+            # memory planning algo; instead collect all of the specs over all the entry
+            # points and then allocate them directly in the run_multimethod call.
+            self.alloc_mutable_buffers and not self.share_mutable_buffers,
        )

+        if self.share_mutable_buffers and graph_signature is not None:
+            self.state.graph_modules.append(graph_module)
+            _check_default_mem_ids(graph_module)
+            _insert_mutable_buffer_specs(self.state, graph_module, graph_signature)
+
        # TODO: make the verifier do the work recursively to handle
        # control flow
        verifier = Verifier(
@@ -164,3 +286,31 @@ def run(
        # I don't know if that is a valid thing but if it is we should adjust the verify_storage_reuse function
        verifier.verify_storage_reuse()
        return PassResult(graph_module, True)
+
+    def run_multimethod(self):
+        """Resolve any memory planning done across entry points."""
+        if self.share_mutable_buffers:
+            arena: int = 0
+
+            # Every spec that shares an fqn is the same tensor! So we give it the same id and offset
+            # anywhere it appears.
+            for fqn, specs_set in self.state.mutable_buffers.items():
+                specs = list(specs_set)
+                # If the same buffer appears in mutable and maybe-mutable then we know it is in fact mutable.
+                if fqn in self.state.maybe_mutable_buffers.keys():
+                    specs.extend(self.state.maybe_mutable_buffers[fqn])
+                for spec in specs:
+                    # Assume a default memory planning placed all activations on mem_id 1; place shared state on 2.
+                    spec.mem_id = 2
+                    spec.realign(self.alignment)
+                    # State is persistent, so the memory never overlaps.
+                    spec.mem_offset = arena
+                # They should all be the same size since they are the same tensor, so just bump off the first.
+                arena += specs[0].allocated_memory
+
+            for graph_module in self.state.graph_modules:
+                if len(graph_module.meta["non_const_buffer_sizes"]) != 2:
+                    raise ValueError(
+                        "Cannot share mutable state if not using default memory ids"
+                    )
+                graph_module.meta["non_const_buffer_sizes"].append(arena)
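
A minimal usage sketch, not part of the diff above, of how a multi-method export flow might drive this pass when share_mutable_buffers=True. It assumes run() accepts the graph module and its export graph signature as shown in the hunks above; `edge_programs`, a hypothetical mapping of method name to torch.export.ExportedProgram, stands in for whatever your export pipeline produces.

# Assumption: edge_programs is a Dict[str, torch.export.ExportedProgram], one per entry point.
mp_pass = MemoryPlanningPass(share_mutable_buffers=True)

# Phase 1: plan each entry point individually. With share_mutable_buffers=True the
# per-method algorithm skips mutable buffers; their specs are collected in mp_pass.state.
for method_name, program in edge_programs.items():
    mp_pass.run(program.graph_module, program.graph_signature)

# Phase 2: resolve the shared state across all entry points. Specs that refer to the
# same buffer FQN receive mem_id 2 and a common offset, and the resulting arena size
# is appended to each graph module's meta["non_const_buffer_sizes"].
mp_pass.run_multimethod()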