# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

+import itertools
import logging
import warnings
+from dataclasses import dataclass, field
from functools import partial
-from typing import Any, Callable, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple

import torch
from executorch.exir._warnings import deprecated
from executorch.exir.memory_planning import (
    _is_out_var_node,
    apply_algo,
+    collect_specs_from_nodes,
+    filter_nodes,
    get_node_tensor_specs,
    MemoryPlanningAlgorithmSuite,
+    naive,
    Verifier,
)
from executorch.exir.operator.convert import get_out_args_from_opoverload
from executorch.exir.pass_base import PassBase, PassResult
-from executorch.exir.tensor import ALIGNMENT
+from executorch.exir.tensor import ALIGNMENT, TensorSpec
+from torch import fx
from torch.export.exported_program import ExportGraphSignature
+from torch.fx import Node


# copied from https://stackoverflow.com/questions/75582932/python-how-can-i-print-the-function-name-of-a-partial-function
@@ -36,6 +44,84 @@ def _callable_name(any_callable: Callable[..., Any]) -> str:
    except AttributeError:
        return str(any_callable)

+def _is_buffer(
+    node: Node, graph_signature: ExportGraphSignature
+) -> Tuple[bool, Optional[str]]:
+    """
+    Check whether the node is a buffer according to the provided graph signature.
+    If it is, also return its fully qualified name (FQN).
+    """
+    if node.op == "placeholder":
+        if isinstance(node.target, str):
+            if node.target in graph_signature.inputs_to_buffers:
+                fqn = graph_signature.inputs_to_buffers[node.target]
+                return True, fqn
+    return False, None
+
+def _is_mutable_buffer(
+    node: Node, graph_signature: ExportGraphSignature
+) -> Tuple[bool, Optional[str]]:
+    """
+    Check whether the node is a mutable buffer according to the provided graph signature.
+    If it is, also return its fully qualified name (FQN).
+    """
+    if node.op == "placeholder":
+        if isinstance(node.target, str):
+            if node.target in graph_signature.inputs_to_buffers:
+                fqn = graph_signature.inputs_to_buffers[node.target]
+                # if the buffer is mutated then record that
+                if fqn in graph_signature.buffers_to_mutate.values():
+                    return True, fqn
+    return False, None
+
+def _get_spec_from_node(node: fx.Node) -> TensorSpec:
+    specs = get_node_tensor_specs(node)
+    assert len(specs) == 1
+    return specs[0]
+
+def _insert_mutable_buffer_specs(
+    state: "_MemoryPlanningState", gm: torch.fx.GraphModule, gs: ExportGraphSignature
+) -> None:
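+    """
+    Record the TensorSpec of every mutable (or possibly mutable) buffer placeholder in
+    ``gm``, keyed by its fully qualified name, so that run_multimethod can later assign
+    shared storage across entry points.
+    """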
+    for node in gm.graph.nodes:
+        is_mutable, fqn = _is_mutable_buffer(node, gs)
+        if is_mutable:
+            assert fqn
+            spec = _get_spec_from_node(node)
+            if getattr(spec, "mem_id", None) is not None or getattr(spec, "mem_offset", None) is not None:
+                raise ValueError("Cannot share mutable buffers if they already have a mem_id or mem_offset assigned")
+            if fqn not in state.mutable_buffers:
+                state.mutable_buffers[fqn] = set()
+            state.mutable_buffers[fqn].add(spec)
+            continue
+        is_buffer, fqn = _is_buffer(node, gs)
+        # If it is not a mutable buffer, it might merely appear to be a plain buffer in this
+        # entry point (think model.get_state()), so cache it and later double check that the
+        # buffer never appears mutable elsewhere.
+        if is_buffer:
+            assert fqn
+            spec = _get_spec_from_node(node)
+            if getattr(spec, "mem_id", None) is not None or getattr(spec, "mem_offset", None) is not None:
+                raise ValueError("Cannot share mutable buffers if they already have a mem_id or mem_offset assigned")
+            if fqn not in state.maybe_mutable_buffers:
+                state.maybe_mutable_buffers[fqn] = set()
+            state.maybe_mutable_buffers[fqn].add(spec)
+
+def _check_default_mem_ids(gm: torch.fx.GraphModule) -> None:
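+    """
+    Raise if any already-planned spec in ``gm`` sits on a mem_id other than the default
+    of 1, since shared mutable buffers are placed on mem_id 2.
+    """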
+    for node in gm.graph.nodes:
+        for spec in collect_specs_from_nodes(
+            filter_nodes(itertools.chain([node], node.args, node.kwargs.values())),
+            None,
+            ignore_graph_input=False,
+            ignore_const=False,
+            ignore_out_var_node=False,
+            dedup=False,
+            do_assertion=False,
+            ignore_dynamic_unbound_tensor=False,
+        ):
+            mem_id = getattr(spec, "mem_id", None)
+            if mem_id is not None and mem_id != 1:
+                raise ValueError("Cannot share mutable buffers if all other tensors are not on the default mem_id of 1")
+
+@dataclass
+class _MemoryPlanningState:
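+    """Mutable-buffer specs and graph modules accumulated across entry points for shared planning."""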
+    mutable_buffers: Dict[str, Set[TensorSpec]] = field(default_factory=dict)
+    maybe_mutable_buffers: Dict[str, Set[TensorSpec]] = field(default_factory=dict)
+    graph_modules: List[torch.fx.GraphModule] = field(default_factory=list)
+

class MemoryPlanningPass(PassBase):
    def __init__(
@@ -45,6 +131,7 @@ def __init__(
        alloc_graph_input: bool = True,
        alloc_graph_output: bool = True,
        alloc_mutable_buffers: bool = True,
+        share_mutable_buffers: bool = False,
        alignment: int = ALIGNMENT,
    ) -> None:
        r"""
@@ -55,12 +142,18 @@ def __init__(
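+            share_mutable_buffers: when True, mutable buffers are not allocated per entry
+                point; run_multimethod later assigns them a single shared arena (mem_id 2).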
55142 """
56143 if memory_planning_algo is None :
57144 memory_planning_algo = MemoryPlanningAlgorithmSuite ()
+        if share_mutable_buffers and not alloc_mutable_buffers:
+            raise ValueError(
+                "share_mutable_buffers is only meaningful when alloc_mutable_buffers is True"
+            )
        self.memory_planning_algo: Callable[..., List[int]] = memory_planning_algo
        self.allow_lifetime_and_storage_overlap = allow_lifetime_and_storage_overlap
        self.alloc_graph_input = alloc_graph_input
        self.alloc_graph_output = alloc_graph_output
        self.alloc_mutable_buffers = alloc_mutable_buffers
+        self.share_mutable_buffers = share_mutable_buffers
        self.alignment = alignment
+        self.state = _MemoryPlanningState()

    def _set_alloc_node_spec(self, graph_module: torch.fx.GraphModule) -> None:
        """
@@ -134,9 +227,17 @@ def run(
            graph_signature,
            self.alloc_graph_input,
            self.alloc_graph_output,
-            self.alloc_mutable_buffers,
+            # If we are sharing the mutable buffers, do not allocate them in the memory
+            # planning algo; instead collect all of their specs across the entry points
+            # and allocate them directly in the run_multimethod call.
+            self.alloc_mutable_buffers and not self.share_mutable_buffers,
        )

+        if self.share_mutable_buffers and graph_signature is not None:
+            self.state.graph_modules.append(graph_module)
+            _check_default_mem_ids(graph_module)
+            _insert_mutable_buffer_specs(self.state, graph_module, graph_signature)
+
        # TODO: make the verifier do the work recursively to handle
        # control flow
        verifier = Verifier(
@@ -164,3 +265,31 @@ def run(
        # I don't know if that is a valid thing, but if it is we should adjust the verify_storage_reuse function
        verifier.verify_storage_reuse()
        return PassResult(graph_module, True)
+
+    def run_multimethod(self) -> None:
+        """Resolve any memory planning that must be coordinated across entry points."""
+        if self.share_mutable_buffers:
+            arena: int = 0
+
+            # Every spec that shares an FQN is the same tensor, so give it the same
+            # mem_id and mem_offset everywhere it appears.
+            for fqn, specs_set in self.state.mutable_buffers.items():
+                specs = list(specs_set)
+                # If the same buffer appears in both mutable and maybe-mutable, we know it is in fact mutable.
+                if fqn in self.state.maybe_mutable_buffers:
+                    specs.extend(self.state.maybe_mutable_buffers[fqn])
+                for spec in specs:
+                    # Assume default memory planning placed all activations on mem_id 1; place shared state on 2.
+                    spec.mem_id = 2
+                    spec.realign(self.alignment)
+                    # State is persistent, so the memory never overlaps.
+                    spec.mem_offset = arena
+                # The specs are all the same size since they refer to the same tensor, so just bump by the first.
+                arena += specs[0].allocated_memory
+
+            for graph_module in self.state.graph_modules:
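+                # Each entry point's meta["non_const_buffer_sizes"] is expected to hold
+                # exactly two entries here (presumably the placeholder entry for mem_id 0
+                # and the default mem_id 1 arena), so appending makes the shared arena mem_id 2.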
+                if len(graph_module.meta["non_const_buffer_sizes"]) != 2:
+                    raise ValueError("Cannot share mutable state if not using default memory ids")
+                graph_module.meta["non_const_buffer_sizes"].append(arena)
+