Skip to content

Commit 0cd1b2c

Browse files
JacobSzwejbkahinriksnaer
authored andcommitted
Reinplace.py
Differential Revision: D77204122 Pull Request resolved: pytorch#11918
1 parent ac81e53 commit 0cd1b2c

File tree

4 files changed

+231
-0
lines changed

4 files changed

+231
-0
lines changed

exir/passes/TARGETS

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ python_library(
3131
":sym_shape_eval_pass",
3232
":sym_to_tensor_pass",
3333
":weights_to_outputs_pass",
34+
":reinplace_pass",
3435
"//caffe2:torch",
3536
"//executorch/exir:common",
3637
"//executorch/exir:control_flow",
@@ -68,6 +69,17 @@ python_library(
6869
],
6970
)
7071

72+
python_library(
73+
name = "reinplace_pass",
74+
srcs = [
75+
"reinplace.py",
76+
],
77+
deps = [
78+
"//caffe2:torch",
79+
"//executorch/exir/dialects:lib",
80+
],
81+
)
82+
7183
python_library(
7284
name = "insert_write_back_for_buffers_pass",
7385
srcs = [

exir/passes/reinplace.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
from typing import Set
10+
11+
import torch
12+
from executorch.exir.dialects._ops import ops
13+
from torch.export import ExportedProgram
14+
15+
16+
def _is_index_put(node: torch.fx.Node) -> bool:
17+
"""Check if a node is an index_put operation."""
18+
return node.op == "call_function" and node.target in (
19+
torch.ops.aten.index_put.default,
20+
ops.edge.aten.index_put.default,
21+
)
22+
23+
24+
def _is_safe_to_reinplace(
25+
node: torch.fx.Node,
26+
later_nodes: Set[torch.fx.Node],
27+
inputs: Set[torch.fx.Node],
28+
mutable_inputs: Set[torch.fx.Node],
29+
) -> bool:
30+
# This node is used later in the graph so we can't reinplace it
31+
# There is probably a faster way to do this but this works for now.
32+
if node in later_nodes:
33+
return False
34+
# If its not an input then we can reinplace it
35+
if node not in inputs:
36+
return True
37+
# If its a mutable input then we can reinplace it
38+
elif node in mutable_inputs:
39+
return True
40+
else: # input but not mutable input
41+
return False
42+
43+
44+
def _is_mutable_user_input(
45+
node: torch.fx.Node, exported_program: ExportedProgram
46+
) -> bool:
47+
return (
48+
node.target in exported_program.graph_signature.user_inputs_to_mutate.values()
49+
)
50+
51+
52+
def _is_mutable_buffer(node: torch.fx.Node, exported_program: ExportedProgram) -> bool:
53+
if node.target not in exported_program.graph_signature.inputs_to_buffers:
54+
return False
55+
buf = exported_program.graph_signature.inputs_to_buffers[node.target]
56+
return buf in exported_program.graph_signature.buffers_to_mutate.values()
57+
58+
59+
def reinplace_pass(ep: ExportedProgram) -> ExportedProgram:
60+
"""
61+
Pass that loops over nodes in an exported program and collects the first argument
62+
of every call_function node that is a view_copy operation.
63+
64+
Args:
65+
exported_program: The ExportedProgram to analyze
66+
67+
Returns:
68+
Set of nodes that are first arguments to view_copy operations
69+
"""
70+
seen_nodes: Set[torch.fx.Node] = set()
71+
# Get all placeholders
72+
inputs = set()
73+
for node in ep.graph.nodes:
74+
if node.op == "placeholder":
75+
inputs.add(node)
76+
# Get all inputs that we could potentially mutate
77+
mutable_nodes = set(
78+
[
79+
node
80+
for node in inputs
81+
if _is_mutable_user_input(node, ep) or _is_mutable_buffer(node, ep)
82+
]
83+
)
84+
85+
results = set()
86+
for node in reversed(ep.graph.nodes):
87+
if _is_index_put(node):
88+
# Check if this index_put node is safe to inplace
89+
# The first argument is the base tensor being indexed into
90+
first_arg = node.args[0]
91+
if _is_safe_to_reinplace(first_arg, seen_nodes, inputs, mutable_nodes):
92+
# This index_put is safe to reinplace
93+
with ep.graph.inserting_before(node):
94+
new_node = ep.graph.call_function(
95+
ops.edge.aten.index_put_.default, args=node.args
96+
)
97+
new_node.meta["val"] = node.meta["val"]
98+
node.replace_all_uses_with(new_node)
99+
ep.graph.erase_node(node)
100+
results.add(first_arg)
101+
elif node.op == "call_function":
102+
seen_nodes.update(node.all_input_nodes)
103+
return ep

exir/tests/TARGETS

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,18 @@ python_unittest(
136136
],
137137
)
138138

139+
python_unittest(
140+
name = "reinplace_pass",
141+
srcs = [
142+
"test_reinplace_pass.py",
143+
],
144+
deps = [
145+
"//caffe2:torch",
146+
"//executorch/exir:lib",
147+
"//executorch/exir/passes:lib",
148+
],
149+
)
150+
139151
cpp_library(
140152
name = "test_lib",
141153
srcs = [

exir/tests/test_reinplace_pass.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
import unittest
10+
11+
import torch
12+
from executorch.exir import to_edge
13+
from executorch.exir.passes.reinplace import reinplace_pass
14+
from torch.export import export
15+
16+
17+
class TestReinplacePass(unittest.TestCase):
18+
def test_index_put_reinplace(self) -> None:
19+
"""Test that index_put on a mutable buffer can be reinplaced."""
20+
21+
class IndexPutModel(torch.nn.Module):
22+
def __init__(self):
23+
super().__init__()
24+
self.register_buffer("state", torch.zeros(5))
25+
26+
def forward(
27+
self, indices: torch.Tensor, values: torch.Tensor
28+
) -> torch.Tensor:
29+
# index_put on buffer (non-user input) should be safe
30+
self.state.index_put_((indices,), values)
31+
return self.state
32+
33+
model = IndexPutModel()
34+
indices = torch.tensor([0])
35+
values = torch.tensor([1.0])
36+
37+
exported_program = export(model, (indices, values), strict=True)
38+
print(exported_program.graph)
39+
edge_program = to_edge(exported_program).exported_program()
40+
41+
# Find the index_put node
42+
index_put_node = None
43+
for node in edge_program.graph.nodes:
44+
if node.op == "call_function" and "index_put" in str(node.target):
45+
index_put_node = node
46+
break
47+
48+
self.assertIsNotNone(index_put_node, "Should find an index_put node")
49+
50+
ep = reinplace_pass(edge_program)
51+
# Find the index_put node
52+
index_put_node = None
53+
for node in ep.graph.nodes:
54+
if node.op == "call_function" and "index_put_" in str(node.target):
55+
index_put_node = node
56+
break
57+
58+
self.assertIsNotNone(index_put_node, "Should find an index_put_ node")
59+
60+
def test_cant_reinplace(self) -> None:
61+
"""Test that index_put on a mutable buffer that is viewed later is not safe."""
62+
63+
class IndexPutModel(torch.nn.Module):
64+
def __init__(self):
65+
super().__init__()
66+
self.register_buffer("state", torch.zeros(5))
67+
68+
def forward(
69+
self, indices: torch.Tensor, values: torch.Tensor
70+
) -> torch.Tensor:
71+
# index_put on buffer (non-user input) should be safe
72+
x = self.state.index_put((indices,), values)
73+
self.state.add_(1)
74+
return x
75+
76+
model = IndexPutModel()
77+
indices = torch.tensor([0])
78+
values = torch.tensor([1.0])
79+
80+
exported_program = export(model, (indices, values), strict=True)
81+
edge_program = to_edge(exported_program).exported_program()
82+
83+
# Find the index_put node
84+
index_put_node = None
85+
for node in edge_program.graph.nodes:
86+
if node.op == "call_function" and "index_put" in str(node.target):
87+
index_put_node = node
88+
break
89+
90+
self.assertIsNotNone(index_put_node, "Should find an index_put node")
91+
92+
ep = reinplace_pass(edge_program)
93+
# Find the index_put node
94+
index_put_node = None
95+
for node in ep.graph.nodes:
96+
if (
97+
node.op == "call_function"
98+
and "index_put" in str(node.target)
99+
and "index_put_" not in str(node.target)
100+
):
101+
index_put_node = node
102+
break
103+
104+
self.assertIsNotNone(index_put_node, "Should still find an index_put node")

0 commit comments

Comments
 (0)