Skip to content

Commit 69f5bf4

Browse files
JacobSzwejbkafacebook-github-bot
authored andcommitted
insert write back now tries to avoid inserting if all operations were inplace (#12008)
Summary: Before only considered if the placeholder was the output. Now allow if all operations are inplace. Differential Revision: D77204116
1 parent d4cc258 commit 69f5bf4

File tree

2 files changed

+36
-2
lines changed

2 files changed

+36
-2
lines changed

exir/passes/insert_write_back_for_buffers_pass.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,12 @@
1515
OutputKind,
1616
OutputSpec,
1717
)
18+
from executorch.exir.operator.convert import (
19+
_get_overload_schema,
20+
)
1821
from torch.export.graph_signature import TensorArgument
1922
from torch.utils import _pytree as pytree
23+
from torchgen.model import SchemaKind
2024

2125

2226
def _insert_copy(
@@ -69,6 +73,28 @@ def _insert_copy(
6973
gm.graph.erase_node(output_node)
7074
return buffer_output_nodes
7175

76+
def _inplace_lineage(output_arg: torch.fx.Node, gm: torch.fx.GraphModule, gs: ExportGraphSignature, kind: SchemaKind) -> bool:
77+
"""
78+
Walk the graph backwards to see if output_arg is ultimately the same as an input.
79+
"""
80+
if kind != OutputKind.BUFFER_MUTATION and kind != OutputKind.USER_INPUT_MUTATION:
81+
return False
82+
83+
while output_arg.op != "placeholder":
84+
if output_arg.op == "call_function" and _get_overload_schema(output_arg.target).kind() == SchemaKind.inplace: #pyre-ignore
85+
# From looking at native_functions.yaml, inplace ops always have self as the first arg
86+
output_arg = output_arg.args[0] #pyre-ignore
87+
else:
88+
return False
89+
90+
# If the output arg was a buffer then it needs to reach a buffer placeholder
91+
if kind == OutputKind.BUFFER_MUTATION:
92+
assert output_arg.target in gs.inputs_to_buffers
93+
return True
94+
# If the output arg was a user input then it needs to reach a user input placeholder
95+
assert output_arg.target in gs.user_inputs
96+
return True
97+
7298

7399
def insert_write_back_for_buffers_pass(
74100
ep: ExportedProgram,
@@ -99,9 +125,16 @@ def insert_write_back_for_buffers_pass(
99125
if lifted_node is not None:
100126
input_name_to_node[lifted_node] = input_node
101127

128+
129+
output_node = None
130+
for node in gm.graph.nodes:
131+
if node.op == "output":
132+
output_node = node
133+
break
134+
102135
# Grab the mutable buffer nodes in the outputs,
103136
mutated_outputs: List[Optional[str]] = []
104-
for out_spec in ep.graph_signature.output_specs:
137+
for i, out_spec in enumerate(ep.graph_signature.output_specs):
105138
# if the output arg is the input value then all operations on it are in-place
106139
# so there's no need to add a copy_ node
107140
if (
@@ -112,7 +145,7 @@ def insert_write_back_for_buffers_pass(
112145
out_spec.target in input_name_to_node
113146
and
114147
# if the arg and target are not the same, we add a copy_ node.
115-
out_spec.arg.name != input_name_to_node[out_spec.target].name
148+
not _inplace_lineage(output_node.args[0][i], gm, ep.graph_signature, ep.graph_signature.output_specs[i].kind)
116149
):
117150
mutated_outputs.append(out_spec.target)
118151
else:

exir/tests/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ python_unittest(
145145
"//caffe2:torch",
146146
"//executorch/exir:lib",
147147
"//executorch/exir/passes:lib",
148+
"//executorch/extension/pybindings:portable_lib",
148149
],
149150
)
150151

0 commit comments

Comments
 (0)