Skip to content

Commit d1d63b4

Browse files
JacobSzwejbkafacebook-github-bot
authored andcommitted
Reinplace.py
Summary: Pass attempts to reinplace index_put if it is safe to do so. Differential Revision: D77204122
1 parent 752f6a7 commit d1d63b4

File tree

4 files changed

+235
-0
lines changed

4 files changed

+235
-0
lines changed

exir/passes/TARGETS

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ python_library(
3131
":sym_shape_eval_pass",
3232
":sym_to_tensor_pass",
3333
":weights_to_outputs_pass",
34+
":reinplace_pass",
3435
"//caffe2:torch",
3536
"//executorch/exir:common",
3637
"//executorch/exir:control_flow",
@@ -68,6 +69,17 @@ python_library(
6869
],
6970
)
7071

72+
python_library(
73+
name = "reinplace_pass",
74+
srcs = [
75+
"reinplace.py",
76+
],
77+
deps = [
78+
"//caffe2:torch",
79+
"//executorch/exir/dialects:lib",
80+
],
81+
)
82+
7183
python_library(
7284
name = "insert_write_back_for_buffers_pass",
7385
srcs = [

exir/passes/reinplace.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
from typing import Set
10+
11+
import torch
12+
from executorch.exir.dialects._ops import ops
13+
from torch.export import ExportedProgram
14+
15+
16+
def _is_view_copy(node: torch.fx.Node) -> bool:
17+
"""Check if a node is a view_copy operation."""
18+
return node.op == "call_function" and node.target in (
19+
torch.ops.aten.view_copy.default,
20+
ops.edge.aten.view_copy.default,
21+
)
22+
23+
24+
def _is_index_put(node: torch.fx.Node) -> bool:
25+
"""Check if a node is an index_put operation."""
26+
return node.op == "call_function" and node.target in (
27+
torch.ops.aten.index_put.default,
28+
ops.edge.aten.index_put.default,
29+
)
30+
31+
32+
def _is_safe_to_reinplace(
33+
node: torch.fx.Node,
34+
later_nodes: Set[torch.fx.Node],
35+
inputs: Set[torch.fx.Node],
36+
mutable_inputs: Set[torch.fx.Node],
37+
) -> bool:
38+
# This node is used later in the graph so we can't reinplace it
39+
# There is probably a faster way to do this but this works for now.
40+
if node in later_nodes:
41+
return False
42+
43+
# If its not an input then we can reinplace it
44+
if node not in inputs:
45+
return True
46+
# If its a mutable input then we can reinplace it
47+
elif node in mutable_inputs:
48+
return True
49+
else: # input but not mutable input
50+
return False
51+
52+
53+
54+
def _is_user_input(node: torch.fx.Node, exported_program: ExportedProgram) -> bool:
55+
return node.target in exported_program.graph_signature.user_inputs
56+
57+
def _is_buffer(node: torch.fx.Node, exported_program: ExportedProgram) -> bool:
58+
return node.target in exported_program.graph_signature.inputs_to_buffers
59+
60+
def _is_mutable_user_input(node: torch.fx.Node, exported_program: ExportedProgram) -> bool:
61+
return node.target in exported_program.graph_signature.user_inputs_to_mutate.values()
62+
63+
def _is_mutable_buffer(node: torch.fx.Node, exported_program: ExportedProgram) -> bool:
64+
if node.target not in exported_program.graph_signature.inputs_to_buffers:
65+
return False
66+
buf = exported_program.graph_signature.inputs_to_buffers[node.target]
67+
return buf in exported_program.graph_signature.buffers_to_mutate.values()
68+
69+
def reinplace_pass(ep: ExportedProgram) -> ExportedProgram:
70+
"""
71+
Pass that loops over nodes in an exported program and collects the first argument
72+
of every call_function node that is a view_copy operation.
73+
74+
Args:
75+
exported_program: The ExportedProgram to analyze
76+
77+
Returns:
78+
Set of nodes that are first arguments to view_copy operations
79+
"""
80+
seen_nodes: Set[torch.fx.Node] = set()
81+
# Get all placeholders
82+
placeholders = set()
83+
for node in ep.graph.nodes:
84+
if node.op == "placeholder":
85+
placeholders.add(node)
86+
# Get all inputs that we could potentially mutate
87+
inputs = set([node for node in placeholders if _is_user_input(node, ep) or _is_buffer(node, ep)])
88+
mutable_nodes = set([node for node in placeholders if _is_mutable_user_input(node, ep) or _is_mutable_buffer(node, ep)])
89+
90+
results = set()
91+
for node in reversed(ep.graph.nodes):
92+
if _is_index_put(node):
93+
# Check if this index_put node is safe to inplace
94+
# The first argument is the base tensor being indexed into
95+
first_arg = node.args[0]
96+
if _is_safe_to_reinplace(first_arg, seen_nodes, inputs, mutable_nodes):
97+
# This index_put is safe to reinplace
98+
with ep.graph.inserting_before(node):
99+
new_node = ep.graph.call_function(
100+
ops.edge.aten.index_put_.default, args=node.args
101+
)
102+
new_node.meta["val"] = node.meta["val"]
103+
node.replace_all_uses_with(new_node)
104+
ep.graph.erase_node(node)
105+
results.add(first_arg)
106+
elif node.op == "call_function":
107+
seen_nodes.update(node.all_input_nodes)
108+
109+
return ep

exir/tests/TARGETS

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,18 @@ python_unittest(
136136
],
137137
)
138138

139+
python_unittest(
140+
name = "reinplace_pass",
141+
srcs = [
142+
"test_reinplace_pass.py",
143+
],
144+
deps = [
145+
"//caffe2:torch",
146+
"//executorch/exir:lib",
147+
"//executorch/exir/passes:lib",
148+
],
149+
)
150+
139151
cpp_library(
140152
name = "test_lib",
141153
srcs = [

exir/tests/test_reinplace_pass.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
import unittest
10+
11+
import torch
12+
from executorch.exir import to_edge
13+
from executorch.exir.passes.reinplace import (
14+
reinplace_pass,
15+
)
16+
from torch.export import export
17+
from torch.export.graph_signature import InputKind, OutputKind
18+
19+
20+
class TestReinplacePass(unittest.TestCase):
21+
def test_index_put_reinplace(self) -> None:
22+
"""Test that index_put on a mutable buffer can be reinplaced."""
23+
24+
class IndexPutModel(torch.nn.Module):
25+
def __init__(self):
26+
super().__init__()
27+
self.register_buffer("state", torch.zeros(5))
28+
29+
def forward(self, indices: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
30+
# index_put on buffer (non-user input) should be safe
31+
self.state.index_put_((indices,), values)
32+
return self.state
33+
34+
model = IndexPutModel()
35+
indices = torch.tensor([0])
36+
values = torch.tensor([1.0])
37+
38+
exported_program = export(model, (indices, values), strict=True)
39+
print(exported_program.graph)
40+
edge_program = to_edge(exported_program).exported_program()
41+
42+
# Find the index_put node
43+
index_put_node = None
44+
for node in edge_program.graph.nodes:
45+
if (node.op == "call_function" and
46+
"index_put" in str(node.target)):
47+
index_put_node = node
48+
break
49+
50+
self.assertIsNotNone(index_put_node, "Should find an index_put node")
51+
52+
ep = reinplace_pass(edge_program)
53+
# Find the index_put node
54+
index_put_node = None
55+
for node in ep.graph.nodes:
56+
if (node.op == "call_function" and
57+
"index_put_" in str(node.target)):
58+
index_put_node = node
59+
break
60+
61+
self.assertIsNotNone(index_put_node, "Should find an index_put_ node")
62+
def test_cant_reinplace(self) -> None:
63+
"""Test that index_put on a mutable buffer that is viewed later is not safe."""
64+
65+
class IndexPutModel(torch.nn.Module):
66+
def __init__(self):
67+
super().__init__()
68+
self.register_buffer("state", torch.zeros(5))
69+
70+
def forward(self, indices: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
71+
# index_put on buffer (non-user input) should be safe
72+
x = self.state.index_put((indices,), values)
73+
self.state.add_(1)
74+
return x
75+
76+
model = IndexPutModel()
77+
indices = torch.tensor([0])
78+
values = torch.tensor([1.0])
79+
80+
exported_program = export(model, (indices, values), strict=True)
81+
edge_program = to_edge(exported_program).exported_program()
82+
83+
# Find the index_put node
84+
index_put_node = None
85+
for node in edge_program.graph.nodes:
86+
if (node.op == "call_function" and
87+
"index_put" in str(node.target)):
88+
index_put_node = node
89+
break
90+
91+
self.assertIsNotNone(index_put_node, "Should find an index_put node")
92+
93+
ep = reinplace_pass(edge_program)
94+
# Find the index_put node
95+
index_put_node = None
96+
for node in ep.graph.nodes:
97+
if (node.op == "call_function" and
98+
"index_put" in str(node.target) and "index_put_" not in str(node.target)):
99+
index_put_node = node
100+
break
101+
102+
self.assertIsNotNone(index_put_node, "Should still find an index_put node")

0 commit comments

Comments
 (0)