7
7
from typing import Dict , List , Optional , Tuple
8
8
9
9
import torch
10
+ from executorch .exir .operator .convert import is_inplace_variant
10
11
11
12
from torch .export .exported_program import (
12
13
ExportedProgram ,
17
18
)
18
19
from torch .export .graph_signature import TensorArgument
19
20
from torch .utils import _pytree as pytree
21
+ from torchgen .model import SchemaKind
20
22
21
23
22
24
def _insert_copy (
@@ -70,6 +72,44 @@ def _insert_copy(
70
72
return buffer_output_nodes
71
73
72
74
75
+ def _is_inplace_node (node : torch .fx .Node ) -> bool :
76
+ """Check if a node is an inplace node."""
77
+ return (
78
+ node .op == "call_function"
79
+ and isinstance (node .target , torch ._ops .OpOverload )
80
+ and is_inplace_variant (
81
+ node .target ._schema .name , node .target ._schema .overload_name
82
+ )
83
+ )
84
+
85
+
86
+ def _inplace_lineage (
87
+ output_arg : torch .fx .Node ,
88
+ gm : torch .fx .GraphModule ,
89
+ gs : ExportGraphSignature ,
90
+ kind : SchemaKind ,
91
+ ) -> bool :
92
+ """
93
+ Walk the graph backwards to see if output_arg is ultimately the same as an input.
94
+ """
95
+ if kind != OutputKind .BUFFER_MUTATION and kind != OutputKind .USER_INPUT_MUTATION :
96
+ return False
97
+
98
+ while output_arg .op != "placeholder" :
99
+ if _is_inplace_node (output_arg ):
100
+ # From looking at native_functions.yaml, inplace ops always have self as the first arg
101
+ output_arg = output_arg .args [0 ] # pyre-ignore
102
+ else :
103
+ return False
104
+
105
+ # If the output arg was a buffer then it needs to reach a buffer placeholder
106
+ if kind == OutputKind .BUFFER_MUTATION :
107
+ return output_arg .target in gs .inputs_to_buffers
108
+ # If the output arg was a user input then it needs to reach a user input placeholder
109
+ assert kind == OutputKind .USER_INPUT_MUTATION
110
+ return output_arg .target in gs .user_inputs
111
+
112
+
73
113
def insert_write_back_for_buffers_pass (
74
114
ep : ExportedProgram ,
75
115
) -> Tuple [torch .fx .GraphModule , ExportGraphSignature ]:
@@ -99,9 +139,15 @@ def insert_write_back_for_buffers_pass(
99
139
if lifted_node is not None :
100
140
input_name_to_node [lifted_node ] = input_node
101
141
142
+ output_node = None
143
+ for node in gm .graph .nodes :
144
+ if node .op == "output" :
145
+ output_node = node
146
+ break
147
+
102
148
# Grab the mutable buffer nodes in the outputs,
103
149
mutated_outputs : List [Optional [str ]] = []
104
- for out_spec in ep .graph_signature .output_specs :
150
+ for i , out_spec in enumerate ( ep .graph_signature .output_specs ) :
105
151
# if the output arg is the input value then all operations on it are in-place
106
152
# so there's no need to add a copy_ node
107
153
if (
@@ -112,7 +158,12 @@ def insert_write_back_for_buffers_pass(
112
158
out_spec .target in input_name_to_node
113
159
and
114
160
# if the arg and target are not the same, we add a copy_ node.
115
- out_spec .arg .name != input_name_to_node [out_spec .target ].name
161
+ not _inplace_lineage (
162
+ output_node .args [0 ][i ],
163
+ gm ,
164
+ ep .graph_signature ,
165
+ ep .graph_signature .output_specs [i ].kind ,
166
+ )
116
167
):
117
168
mutated_outputs .append (out_spec .target )
118
169
else :
0 commit comments