Skip to content

Commit 7554e85

Browse files
JacobSzwejbkafacebook-github-bot
authored andcommitted
Reinplace.py (pytorch#11918)
Summary: Pass attempts to reinplace index_put if it is safe to do so. Differential Revision: D77204122
1 parent 752f6a7 commit 7554e85

File tree

4 files changed

+226
-0
lines changed

4 files changed

+226
-0
lines changed

exir/passes/TARGETS

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ python_library(
3131
":sym_shape_eval_pass",
3232
":sym_to_tensor_pass",
3333
":weights_to_outputs_pass",
34+
":reinplace_pass",
3435
"//caffe2:torch",
3536
"//executorch/exir:common",
3637
"//executorch/exir:control_flow",
@@ -68,6 +69,17 @@ python_library(
6869
],
6970
)
7071

72+
python_library(
73+
name = "reinplace_pass",
74+
srcs = [
75+
"reinplace.py",
76+
],
77+
deps = [
78+
"//caffe2:torch",
79+
"//executorch/exir/dialects:lib",
80+
],
81+
)
82+
7183
python_library(
7284
name = "insert_write_back_for_buffers_pass",
7385
srcs = [

exir/passes/reinplace.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
from typing import Set
10+
11+
import torch
12+
from executorch.exir.dialects._ops import ops
13+
from torch.export import ExportedProgram
14+
15+
def _is_index_put(node: torch.fx.Node) -> bool:
16+
"""Check if a node is an index_put operation."""
17+
return node.op == "call_function" and node.target in (
18+
torch.ops.aten.index_put.default,
19+
ops.edge.aten.index_put.default,
20+
)
21+
22+
23+
def _is_safe_to_reinplace(
24+
node: torch.fx.Node,
25+
later_nodes: Set[torch.fx.Node],
26+
inputs: Set[torch.fx.Node],
27+
mutable_inputs: Set[torch.fx.Node],
28+
) -> bool:
29+
# This node is used later in the graph so we can't reinplace it
30+
# There is probably a faster way to do this but this works for now.
31+
if node in later_nodes:
32+
return False
33+
34+
# If its not an input then we can reinplace it
35+
if node not in inputs:
36+
return True
37+
# If its a mutable input then we can reinplace it
38+
elif node in mutable_inputs:
39+
return True
40+
else: # input but not mutable input
41+
return False
42+
43+
44+
45+
def _is_user_input(node: torch.fx.Node, exported_program: ExportedProgram) -> bool:
46+
return node.target in exported_program.graph_signature.user_inputs
47+
48+
def _is_buffer(node: torch.fx.Node, exported_program: ExportedProgram) -> bool:
49+
return node.target in exported_program.graph_signature.inputs_to_buffers
50+
51+
def _is_mutable_user_input(node: torch.fx.Node, exported_program: ExportedProgram) -> bool:
52+
return node.target in exported_program.graph_signature.user_inputs_to_mutate.values()
53+
54+
def _is_mutable_buffer(node: torch.fx.Node, exported_program: ExportedProgram) -> bool:
55+
if node.target not in exported_program.graph_signature.inputs_to_buffers:
56+
return False
57+
buf = exported_program.graph_signature.inputs_to_buffers[node.target]
58+
return buf in exported_program.graph_signature.buffers_to_mutate.values()
59+
60+
def reinplace_pass(ep: ExportedProgram) -> ExportedProgram:
61+
"""
62+
Pass that loops over nodes in an exported program and collects the first argument
63+
of every call_function node that is a view_copy operation.
64+
65+
Args:
66+
exported_program: The ExportedProgram to analyze
67+
68+
Returns:
69+
Set of nodes that are first arguments to view_copy operations
70+
"""
71+
seen_nodes: Set[torch.fx.Node] = set()
72+
# Get all placeholders
73+
placeholders = set()
74+
for node in ep.graph.nodes:
75+
if node.op == "placeholder":
76+
placeholders.add(node)
77+
# Get all inputs that we could potentially mutate
78+
inputs = set([node for node in placeholders if _is_user_input(node, ep) or _is_buffer(node, ep)])
79+
mutable_nodes = set([node for node in placeholders if _is_mutable_user_input(node, ep) or _is_mutable_buffer(node, ep)])
80+
81+
results = set()
82+
for node in reversed(ep.graph.nodes):
83+
if _is_index_put(node):
84+
# Check if this index_put node is safe to inplace
85+
# The first argument is the base tensor being indexed into
86+
first_arg = node.args[0]
87+
if _is_safe_to_reinplace(first_arg, seen_nodes, inputs, mutable_nodes):
88+
# This index_put is safe to reinplace
89+
with ep.graph.inserting_before(node):
90+
new_node = ep.graph.call_function(
91+
ops.edge.aten.index_put_.default, args=node.args
92+
)
93+
new_node.meta["val"] = node.meta["val"]
94+
node.replace_all_uses_with(new_node)
95+
ep.graph.erase_node(node)
96+
results.add(first_arg)
97+
elif node.op == "call_function":
98+
seen_nodes.update(node.all_input_nodes)
99+
100+
return ep

exir/tests/TARGETS

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,18 @@ python_unittest(
136136
],
137137
)
138138

139+
python_unittest(
140+
name = "reinplace_pass",
141+
srcs = [
142+
"test_reinplace_pass.py",
143+
],
144+
deps = [
145+
"//caffe2:torch",
146+
"//executorch/exir:lib",
147+
"//executorch/exir/passes:lib",
148+
],
149+
)
150+
139151
cpp_library(
140152
name = "test_lib",
141153
srcs = [

exir/tests/test_reinplace_pass.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
import unittest
10+
11+
import torch
12+
from executorch.exir import to_edge
13+
from executorch.exir.passes.reinplace import (
14+
reinplace_pass,
15+
)
16+
from torch.export import export
17+
from torch.export.graph_signature import InputKind, OutputKind
18+
19+
20+
class TestReinplacePass(unittest.TestCase):
21+
def test_index_put_reinplace(self) -> None:
22+
"""Test that index_put on a mutable buffer can be reinplaced."""
23+
24+
class IndexPutModel(torch.nn.Module):
25+
def __init__(self):
26+
super().__init__()
27+
self.register_buffer("state", torch.zeros(5))
28+
29+
def forward(self, indices: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
30+
# index_put on buffer (non-user input) should be safe
31+
self.state.index_put_((indices,), values)
32+
return self.state
33+
34+
model = IndexPutModel()
35+
indices = torch.tensor([0])
36+
values = torch.tensor([1.0])
37+
38+
exported_program = export(model, (indices, values), strict=True)
39+
print(exported_program.graph)
40+
edge_program = to_edge(exported_program).exported_program()
41+
42+
# Find the index_put node
43+
index_put_node = None
44+
for node in edge_program.graph.nodes:
45+
if (node.op == "call_function" and
46+
"index_put" in str(node.target)):
47+
index_put_node = node
48+
break
49+
50+
self.assertIsNotNone(index_put_node, "Should find an index_put node")
51+
52+
ep = reinplace_pass(edge_program)
53+
# Find the index_put node
54+
index_put_node = None
55+
for node in ep.graph.nodes:
56+
if (node.op == "call_function" and
57+
"index_put_" in str(node.target)):
58+
index_put_node = node
59+
break
60+
61+
self.assertIsNotNone(index_put_node, "Should find an index_put_ node")
62+
def test_cant_reinplace(self) -> None:
63+
"""Test that index_put on a mutable buffer that is viewed later is not safe."""
64+
65+
class IndexPutModel(torch.nn.Module):
66+
def __init__(self):
67+
super().__init__()
68+
self.register_buffer("state", torch.zeros(5))
69+
70+
def forward(self, indices: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
71+
# index_put on buffer (non-user input) should be safe
72+
x = self.state.index_put((indices,), values)
73+
self.state.add_(1)
74+
return x
75+
76+
model = IndexPutModel()
77+
indices = torch.tensor([0])
78+
values = torch.tensor([1.0])
79+
80+
exported_program = export(model, (indices, values), strict=True)
81+
edge_program = to_edge(exported_program).exported_program()
82+
83+
# Find the index_put node
84+
index_put_node = None
85+
for node in edge_program.graph.nodes:
86+
if (node.op == "call_function" and
87+
"index_put" in str(node.target)):
88+
index_put_node = node
89+
break
90+
91+
self.assertIsNotNone(index_put_node, "Should find an index_put node")
92+
93+
ep = reinplace_pass(edge_program)
94+
# Find the index_put node
95+
index_put_node = None
96+
for node in ep.graph.nodes:
97+
if (node.op == "call_function" and
98+
"index_put" in str(node.target) and "index_put_" not in str(node.target)):
99+
index_put_node = node
100+
break
101+
102+
self.assertIsNotNone(index_put_node, "Should still find an index_put node")

0 commit comments

Comments
 (0)