Skip to content

Commit ad73cbd

Browse files
insert write back now tries to avoid inserting if all operations were inplace
Differential Revision: D77204116 Pull Request resolved: #12008
1 parent b11075f commit ad73cbd

File tree

2 files changed

+54
-2
lines changed

2 files changed

+54
-2
lines changed

exir/passes/insert_write_back_for_buffers_pass.py

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import Dict, List, Optional, Tuple
88

99
import torch
10+
from executorch.exir.operator.convert import is_inplace_variant
1011

1112
from torch.export.exported_program import (
1213
ExportedProgram,
@@ -17,6 +18,7 @@
1718
)
1819
from torch.export.graph_signature import TensorArgument
1920
from torch.utils import _pytree as pytree
21+
from torchgen.model import SchemaKind
2022

2123

2224
def _insert_copy(
@@ -70,6 +72,44 @@ def _insert_copy(
7072
return buffer_output_nodes
7173

7274

75+
def _is_inplace_node(node: torch.fx.Node) -> bool:
76+
"""Check if a node is an inplace node."""
77+
return (
78+
node.op == "call_function"
79+
and isinstance(node.target, torch._ops.OpOverload)
80+
and is_inplace_variant(
81+
node.target._schema.name, node.target._schema.overload_name
82+
)
83+
)
84+
85+
86+
def _inplace_lineage(
87+
output_arg: torch.fx.Node,
88+
gm: torch.fx.GraphModule,
89+
gs: ExportGraphSignature,
90+
kind: SchemaKind,
91+
) -> bool:
92+
"""
93+
Walk the graph backwards to see if output_arg is ultimately the same as an input.
94+
"""
95+
if kind != OutputKind.BUFFER_MUTATION and kind != OutputKind.USER_INPUT_MUTATION:
96+
return False
97+
98+
while output_arg.op != "placeholder":
99+
if _is_inplace_node(output_arg):
100+
# From looking at native_functions.yaml, inplace ops always have self as the first arg
101+
output_arg = output_arg.args[0] # pyre-ignore
102+
else:
103+
return False
104+
105+
# If the output arg was a buffer then it needs to reach a buffer placeholder
106+
if kind == OutputKind.BUFFER_MUTATION:
107+
return output_arg.target in gs.inputs_to_buffers
108+
# If the output arg was a user input then it needs to reach a user input placeholder
109+
assert kind == OutputKind.USER_INPUT_MUTATION
110+
return output_arg.target in gs.user_inputs
111+
112+
73113
def insert_write_back_for_buffers_pass(
74114
ep: ExportedProgram,
75115
) -> Tuple[torch.fx.GraphModule, ExportGraphSignature]:
@@ -99,9 +139,15 @@ def insert_write_back_for_buffers_pass(
99139
if lifted_node is not None:
100140
input_name_to_node[lifted_node] = input_node
101141

142+
output_node = None
143+
for node in gm.graph.nodes:
144+
if node.op == "output":
145+
output_node = node
146+
break
147+
102148
# Grab the mutable buffer nodes in the outputs,
103149
mutated_outputs: List[Optional[str]] = []
104-
for out_spec in ep.graph_signature.output_specs:
150+
for i, out_spec in enumerate(ep.graph_signature.output_specs):
105151
# if the output arg is the input value then all operations on it are in-place
106152
# so there's no need to add a copy_ node
107153
if (
@@ -112,7 +158,12 @@ def insert_write_back_for_buffers_pass(
112158
out_spec.target in input_name_to_node
113159
and
114160
# if the arg and target are not the same, we add a copy_ node.
115-
out_spec.arg.name != input_name_to_node[out_spec.target].name
161+
not _inplace_lineage(
162+
output_node.args[0][i],
163+
gm,
164+
ep.graph_signature,
165+
ep.graph_signature.output_specs[i].kind,
166+
)
116167
):
117168
mutated_outputs.append(out_spec.target)
118169
else:

exir/tests/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ python_unittest(
145145
"//caffe2:torch",
146146
"//executorch/exir:lib",
147147
"//executorch/exir/passes:lib",
148+
"//executorch/extension/pybindings:portable_lib",
148149
],
149150
)
150151

0 commit comments

Comments
 (0)