Skip to content

Commit db8bec5

Browse files
JacobSzwejbkafacebook-github-bot
authored andcommitted
insert write back now tries to avoid inserting if all operations were inplace (#12008)
Summary: Before only considered if the placeholder was the output. Now allow if all operations are inplace. Reviewed By: angelayi Differential Revision: D77204116
1 parent d4cc258 commit db8bec5

File tree

2 files changed

+55
-2
lines changed

2 files changed

+55
-2
lines changed

exir/passes/insert_write_back_for_buffers_pass.py

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import Dict, List, Optional, Tuple
88

99
import torch
10+
from executorch.exir.operator.convert import is_inplace_variant
1011

1112
from torch.export.exported_program import (
1213
ExportedProgram,
@@ -17,6 +18,7 @@
1718
)
1819
from torch.export.graph_signature import TensorArgument
1920
from torch.utils import _pytree as pytree
21+
from torchgen.model import SchemaKind
2022

2123

2224
def _insert_copy(
@@ -70,6 +72,45 @@ def _insert_copy(
7072
return buffer_output_nodes
7173

7274

75+
def _is_inplace_node(node: torch.fx.Node) -> bool:
76+
"""Check if a node is an inplace node."""
77+
return (
78+
node.op == "call_function"
79+
and isinstance(node.target, torch._ops.OpOverload)
80+
and is_inplace_variant(
81+
node.target._schema.name, node.target._schema.overload_name
82+
)
83+
)
84+
85+
86+
def _inplace_lineage(
87+
output_arg: torch.fx.Node,
88+
gm: torch.fx.GraphModule,
89+
gs: ExportGraphSignature,
90+
kind: SchemaKind,
91+
) -> bool:
92+
"""
93+
Walk the graph backwards to see if output_arg is ultimately the same as an input.
94+
"""
95+
if kind != OutputKind.BUFFER_MUTATION and kind != OutputKind.USER_INPUT_MUTATION:
96+
return False
97+
98+
while output_arg.op != "placeholder":
99+
if _is_inplace_node(output_arg):
100+
# From looking at native_functions.yaml, inplace ops always have self as the first arg
101+
output_arg = output_arg.args[0] # pyre-ignore
102+
else:
103+
return False
104+
105+
# If the output arg was a buffer then it needs to reach a buffer placeholder
106+
if kind == OutputKind.BUFFER_MUTATION:
107+
assert output_arg.target in gs.inputs_to_buffers
108+
return True
109+
# If the output arg was a user input then it needs to reach a user input placeholder
110+
assert output_arg.target in gs.user_inputs
111+
return True
112+
113+
73114
def insert_write_back_for_buffers_pass(
74115
ep: ExportedProgram,
75116
) -> Tuple[torch.fx.GraphModule, ExportGraphSignature]:
@@ -99,9 +140,15 @@ def insert_write_back_for_buffers_pass(
99140
if lifted_node is not None:
100141
input_name_to_node[lifted_node] = input_node
101142

143+
output_node = None
144+
for node in gm.graph.nodes:
145+
if node.op == "output":
146+
output_node = node
147+
break
148+
102149
# Grab the mutable buffer nodes in the outputs,
103150
mutated_outputs: List[Optional[str]] = []
104-
for out_spec in ep.graph_signature.output_specs:
151+
for i, out_spec in enumerate(ep.graph_signature.output_specs):
105152
# if the output arg is the input value then all operations on it are in-place
106153
# so there's no need to add a copy_ node
107154
if (
@@ -112,7 +159,12 @@ def insert_write_back_for_buffers_pass(
112159
out_spec.target in input_name_to_node
113160
and
114161
# if the arg and target are not the same, we add a copy_ node.
115-
out_spec.arg.name != input_name_to_node[out_spec.target].name
162+
not _inplace_lineage(
163+
output_node.args[0][i],
164+
gm,
165+
ep.graph_signature,
166+
ep.graph_signature.output_specs[i].kind,
167+
)
116168
):
117169
mutated_outputs.append(out_spec.target)
118170
else:

exir/tests/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ python_unittest(
145145
"//caffe2:torch",
146146
"//executorch/exir:lib",
147147
"//executorch/exir/passes:lib",
148+
"//executorch/extension/pybindings:portable_lib",
148149
],
149150
)
150151

0 commit comments

Comments
 (0)