77from typing import Dict , List , Optional , Tuple
88
99import torch
10+ from executorch .exir .operator .convert import is_inplace_variant
1011
1112from torch .export .exported_program import (
1213 ExportedProgram ,
1718)
1819from torch .export .graph_signature import TensorArgument
1920from torch .utils import _pytree as pytree
21+ from torchgen .model import SchemaKind
2022
2123
2224def _insert_copy (
@@ -70,6 +72,44 @@ def _insert_copy(
7072 return buffer_output_nodes
7173
7274
75+ def _is_inplace_node (node : torch .fx .Node ) -> bool :
76+ """Check if a node is an inplace node."""
77+ return (
78+ node .op == "call_function"
79+ and isinstance (node .target , torch ._ops .OpOverload )
80+ and is_inplace_variant (
81+ node .target ._schema .name , node .target ._schema .overload_name
82+ )
83+ )
84+
85+
86+ def _inplace_lineage (
87+ output_arg : torch .fx .Node ,
88+ gm : torch .fx .GraphModule ,
89+ gs : ExportGraphSignature ,
90+ kind : SchemaKind ,
91+ ) -> bool :
92+ """
93+ Walk the graph backwards to see if output_arg is ultimately the same as an input.
94+ """
95+ if kind != OutputKind .BUFFER_MUTATION and kind != OutputKind .USER_INPUT_MUTATION :
96+ return False
97+
98+ while output_arg .op != "placeholder" :
99+ if _is_inplace_node (output_arg ):
100+ # From looking at native_functions.yaml, inplace ops always have self as the first arg
101+ output_arg = output_arg .args [0 ] # pyre-ignore
102+ else :
103+ return False
104+
105+ # If the output arg was a buffer then it needs to reach a buffer placeholder
106+ if kind == OutputKind .BUFFER_MUTATION :
107+ return output_arg .target in gs .inputs_to_buffers
108+ # If the output arg was a user input then it needs to reach a user input placeholder
109+ assert kind == OutputKind .USER_INPUT_MUTATION
110+ return output_arg .target in gs .user_inputs
111+
112+
73113def insert_write_back_for_buffers_pass (
74114 ep : ExportedProgram ,
75115) -> Tuple [torch .fx .GraphModule , ExportGraphSignature ]:
@@ -99,9 +139,15 @@ def insert_write_back_for_buffers_pass(
99139 if lifted_node is not None :
100140 input_name_to_node [lifted_node ] = input_node
101141
142+ output_node = None
143+ for node in gm .graph .nodes :
144+ if node .op == "output" :
145+ output_node = node
146+ break
147+
102148 # Grab the mutable buffer nodes in the outputs,
103149 mutated_outputs : List [Optional [str ]] = []
104- for out_spec in ep .graph_signature .output_specs :
150+ for i , out_spec in enumerate ( ep .graph_signature .output_specs ) :
105151 # if the output arg is the input value then all operations on it are in-place
106152 # so there's no need to add a copy_ node
107153 if (
@@ -112,7 +158,12 @@ def insert_write_back_for_buffers_pass(
112158 out_spec .target in input_name_to_node
113159 and
114160 # if the arg and target are not the same, we add a copy_ node.
115- out_spec .arg .name != input_name_to_node [out_spec .target ].name
161+ not _inplace_lineage (
162+ output_node .args [0 ][i ],
163+ gm ,
164+ ep .graph_signature ,
165+ ep .graph_signature .output_specs [i ].kind ,
166+ )
116167 ):
117168 mutated_outputs .append (out_spec .target )
118169 else :
0 commit comments