1515 OutputKind ,
1616 OutputSpec ,
1717)
18+ from executorch .exir .operator .convert import (
19+ _get_overload_schema ,
20+ )
1821from torch .export .graph_signature import TensorArgument
1922from torch .utils import _pytree as pytree
23+ from torchgen .model import SchemaKind
2024
2125
2226def _insert_copy (
@@ -69,6 +73,28 @@ def _insert_copy(
6973 gm .graph .erase_node (output_node )
7074 return buffer_output_nodes
7175
76+ def _inplace_lineage (output_arg : torch .fx .Node , gm : torch .fx .GraphModule , gs : ExportGraphSignature , kind : SchemaKind ) -> bool :
77+ """
78+ Walk the graph backwards to see if output_arg is ultimately the same as an input.
79+ """
80+ if kind != OutputKind .BUFFER_MUTATION and kind != OutputKind .USER_INPUT_MUTATION :
81+ return False
82+
83+ while output_arg .op != "placeholder" :
84+ if output_arg .op == "call_function" and _get_overload_schema (output_arg .target ).kind () == SchemaKind .inplace : #pyre-ignore
85+ # From looking at native_functions.yaml, inplace ops always have self as the first arg
86+ output_arg = output_arg .args [0 ] #pyre-ignore
87+ else :
88+ return False
89+
90+ # If the output arg was a buffer then it needs to reach a buffer placeholder
91+ if kind == OutputKind .BUFFER_MUTATION :
92+ assert output_arg .target in gs .inputs_to_buffers
93+ return True
94+ # If the output arg was a user input then it needs to reach a user input placeholder
95+ assert output_arg .target in gs .user_inputs
96+ return True
97+
7298
7399def insert_write_back_for_buffers_pass (
74100 ep : ExportedProgram ,
@@ -99,9 +125,16 @@ def insert_write_back_for_buffers_pass(
99125 if lifted_node is not None :
100126 input_name_to_node [lifted_node ] = input_node
101127
128+
129+ output_node = None
130+ for node in gm .graph .nodes :
131+ if node .op == "output" :
132+ output_node = node
133+ break
134+
102135 # Grab the mutable buffer nodes in the outputs,
103136 mutated_outputs : List [Optional [str ]] = []
104- for out_spec in ep .graph_signature .output_specs :
137+ for i , out_spec in enumerate ( ep .graph_signature .output_specs ) :
105138 # if the output arg is the input value then all operations on it are in-place
106139 # so there's no need to add a copy_ node
107140 if (
@@ -112,7 +145,7 @@ def insert_write_back_for_buffers_pass(
112145 out_spec .target in input_name_to_node
113146 and
114147 # if the arg and target are not the same, we add a copy_ node.
115- out_spec . arg . name != input_name_to_node [ out_spec . target ]. name
148+ not _inplace_lineage ( output_node . args [ 0 ][ i ], gm , ep . graph_signature , ep . graph_signature . output_specs [ i ]. kind )
116149 ):
117150 mutated_outputs .append (out_spec .target )
118151 else :
0 commit comments