77from typing import Dict , List , Optional , Tuple
88
99import torch
10+ from executorch .exir .operator .convert import _get_overload_schema
1011
1112from torch .export .exported_program import (
1213 ExportedProgram ,
1718)
1819from torch .export .graph_signature import TensorArgument
1920from torch .utils import _pytree as pytree
21+ from torchgen .model import SchemaKind
2022
2123
2224def _insert_copy (
@@ -70,6 +72,38 @@ def _insert_copy(
7072 return buffer_output_nodes
7173
7274
75+ def _inplace_lineage (
76+ output_arg : torch .fx .Node ,
77+ gm : torch .fx .GraphModule ,
78+ gs : ExportGraphSignature ,
79+ kind : SchemaKind ,
80+ ) -> bool :
81+ """
82+ Walk the graph backwards to see if output_arg is ultimately the same as an input.
83+ """
84+ if kind != OutputKind .BUFFER_MUTATION and kind != OutputKind .USER_INPUT_MUTATION :
85+ return False
86+
87+ while output_arg .op != "placeholder" :
88+ if (
89+ output_arg .op == "call_function"
90+ and _get_overload_schema (output_arg .target ).kind () # pyre-ignore
91+ == SchemaKind .inplace
92+ ):
93+ # From looking at native_functions.yaml, inplace ops always have self as the first arg
94+ output_arg = output_arg .args [0 ] # pyre-ignore
95+ else :
96+ return False
97+
98+ # If the output arg was a buffer then it needs to reach a buffer placeholder
99+ if kind == OutputKind .BUFFER_MUTATION :
100+ assert output_arg .target in gs .inputs_to_buffers
101+ return True
102+ # If the output arg was a user input then it needs to reach a user input placeholder
103+ assert output_arg .target in gs .user_inputs
104+ return True
105+
106+
73107def insert_write_back_for_buffers_pass (
74108 ep : ExportedProgram ,
75109) -> Tuple [torch .fx .GraphModule , ExportGraphSignature ]:
@@ -99,9 +133,15 @@ def insert_write_back_for_buffers_pass(
99133 if lifted_node is not None :
100134 input_name_to_node [lifted_node ] = input_node
101135
136+ output_node = None
137+ for node in gm .graph .nodes :
138+ if node .op == "output" :
139+ output_node = node
140+ break
141+
102142 # Grab the mutable buffer nodes in the outputs,
103143 mutated_outputs : List [Optional [str ]] = []
104- for out_spec in ep .graph_signature .output_specs :
144+ for i , out_spec in enumerate ( ep .graph_signature .output_specs ) :
105145 # if the output arg is the input value then all operations on it are in-place
106146 # so there's no need to add a copy_ node
107147 if (
@@ -112,7 +152,12 @@ def insert_write_back_for_buffers_pass(
112152 out_spec .target in input_name_to_node
113153 and
114154 # if the arg and target are not the same, we add a copy_ node.
115- out_spec .arg .name != input_name_to_node [out_spec .target ].name
155+ not _inplace_lineage (
156+ output_node .args [0 ][i ],
157+ gm ,
158+ ep .graph_signature ,
159+ ep .graph_signature .output_specs [i ].kind ,
160+ )
116161 ):
117162 mutated_outputs .append (out_spec .target )
118163 else :
0 commit comments