Commit c1b32f9

JacobSzwejbka authored and facebook-github-bot committed
insert write back now tries to avoid inserting if all operations were inplace (pytorch#12008)
Summary: Before, the pass only skipped the write-back copy when the output arg was the input placeholder itself. Now it also skips it when every operation applied to the mutated buffer or user input is in-place.

Reviewed By: angelayi

Differential Revision: D77204116
1 parent d4cc258 commit c1b32f9
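
To make the new behavior concrete, here is a minimal, hypothetical sketch (not part of this commit) of the graph shape the pass now recognizes: a mutated output whose lineage runs back to its placeholder through nothing but in-place call_function nodes. The handcrafted graph and the has_inplace_lineage helper below are illustrative only; the real pass checks _get_overload_schema(...).kind() == SchemaKind.inplace rather than the simplified name-based test used here.

    import torch

    # Build a tiny FX graph by hand: a buffer placeholder mutated in place by a user input.
    graph = torch.fx.Graph()
    buf = graph.placeholder("buf")  # lifted buffer
    x = graph.placeholder("x")      # user input
    mutated = graph.call_function(torch.ops.aten.add_.Tensor, (buf, x))
    graph.output((mutated,))

    def has_inplace_lineage(node: torch.fx.Node) -> bool:
        # Walk backwards from the mutated output until a placeholder is reached,
        # bailing out as soon as a hop is not an in-place op (schema name ending in "_").
        while node.op != "placeholder":
            if node.op == "call_function" and node.target._schema.name.endswith("_"):
                node = node.args[0]  # in-place aten ops take `self` as the first argument
            else:
                return False
        return True

    output_node = next(n for n in graph.nodes if n.op == "output")
    print(has_inplace_lineage(output_node.args[0][0]))  # True -> no write-back copy_ needed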

File tree: 2 files changed, +48 -2 lines

exir/passes/insert_write_back_for_buffers_pass.py

Lines changed: 47 additions & 2 deletions
@@ -7,6 +7,7 @@
 from typing import Dict, List, Optional, Tuple
 
 import torch
+from executorch.exir.operator.convert import _get_overload_schema
 
 from torch.export.exported_program import (
     ExportedProgram,
@@ -17,6 +18,7 @@
 )
 from torch.export.graph_signature import TensorArgument
 from torch.utils import _pytree as pytree
+from torchgen.model import SchemaKind
 
 
 def _insert_copy(
@@ -70,6 +72,38 @@ def _insert_copy(
     return buffer_output_nodes
 
 
+def _inplace_lineage(
+    output_arg: torch.fx.Node,
+    gm: torch.fx.GraphModule,
+    gs: ExportGraphSignature,
+    kind: OutputKind,
+) -> bool:
+    """
+    Walk the graph backwards to see if output_arg is ultimately the same as an input.
+    """
+    if kind != OutputKind.BUFFER_MUTATION and kind != OutputKind.USER_INPUT_MUTATION:
+        return False
+
+    while output_arg.op != "placeholder":
+        if (
+            output_arg.op == "call_function"
+            and _get_overload_schema(output_arg.target).kind()  # pyre-ignore
+            == SchemaKind.inplace
+        ):
+            # From looking at native_functions.yaml, inplace ops always have self as the first arg
+            output_arg = output_arg.args[0]  # pyre-ignore
+        else:
+            return False
+
+    # If the output arg was a buffer then it needs to reach a buffer placeholder
+    if kind == OutputKind.BUFFER_MUTATION:
+        assert output_arg.target in gs.inputs_to_buffers
+        return True
+    # If the output arg was a user input then it needs to reach a user input placeholder
+    assert output_arg.target in gs.user_inputs
+    return True
+
+
 def insert_write_back_for_buffers_pass(
     ep: ExportedProgram,
 ) -> Tuple[torch.fx.GraphModule, ExportGraphSignature]:
@@ -99,9 +133,15 @@ def insert_write_back_for_buffers_pass(
         if lifted_node is not None:
             input_name_to_node[lifted_node] = input_node
 
+    output_node = None
+    for node in gm.graph.nodes:
+        if node.op == "output":
+            output_node = node
+            break
+
     # Grab the mutable buffer nodes in the outputs,
     mutated_outputs: List[Optional[str]] = []
-    for out_spec in ep.graph_signature.output_specs:
+    for i, out_spec in enumerate(ep.graph_signature.output_specs):
         # if the output arg is the input value then all operations on it are in-place
         # so there's no need to add a copy_ node
         if (
@@ -112,7 +152,12 @@ def insert_write_back_for_buffers_pass(
             out_spec.target in input_name_to_node
             and
             # if the arg and target are not the same, we add a copy_ node.
-            out_spec.arg.name != input_name_to_node[out_spec.target].name
+            not _inplace_lineage(
+                output_node.args[0][i],
+                gm,
+                ep.graph_signature,
+                ep.graph_signature.output_specs[i].kind,
+            )
         ):
             mutated_outputs.append(out_spec.target)
         else:
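
The check above hinges on torchgen's schema classification: an operator that mutates and returns self is SchemaKind.inplace, while its functional counterpart is SchemaKind.functional. As a hedged, standalone illustration (the schema strings are written in the native_functions.yaml style and may drift from the exact upstream entries), the classification can be inspected directly:

    from torchgen.model import FunctionSchema, SchemaKind

    # Parse yaml-style schema strings and compare their kinds.
    inplace_schema = FunctionSchema.parse(
        "add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)"
    )
    functional_schema = FunctionSchema.parse(
        "add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"
    )

    print(inplace_schema.kind() == SchemaKind.inplace)        # True
    print(functional_schema.kind() == SchemaKind.functional)  # True

This is presumably why the diff imports _get_overload_schema: it turns a node's OpOverload target into a torchgen FunctionSchema so the same kind() test can be applied while walking the lineage.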

exir/tests/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -145,6 +145,7 @@ python_unittest(
         "//caffe2:torch",
         "//executorch/exir:lib",
         "//executorch/exir/passes:lib",
+        "//executorch/extension/pybindings:portable_lib",
     ],
 )
 
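Finally, a hedged end-to-end sketch of how the updated pass might be exercised. The import path follows the file touched in this diff; the toy Accumulator module, the use of aten.copy_.default for counting write-backs, and the point in the export pipeline at which the pass is applied are assumptions for illustration, not assertions about this commit.

    import torch
    from torch.export import export

    from executorch.exir.passes.insert_write_back_for_buffers_pass import (
        insert_write_back_for_buffers_pass,
    )

    class Accumulator(torch.nn.Module):
        """Toy module with a mutated buffer, used only for illustration."""

        def __init__(self) -> None:
            super().__init__()
            self.register_buffer("state", torch.zeros(4))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            self.state.add_(x)
            return self.state + 1

    ep = export(Accumulator(), (torch.ones(4),))
    gm, new_signature = insert_write_back_for_buffers_pass(ep)

    # Count the write-back copies the pass inserted. With this change, a mutated
    # output whose lineage consists only of in-place ops no longer needs one;
    # whether that applies here depends on how export represents the mutation.
    copies = [n for n in gm.graph.nodes if n.target == torch.ops.aten.copy_.default]
    print(f"write-back copies inserted: {len(copies)}")
    print(new_signature.buffers_to_mutate)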