Skip to content

Commit 72b3bc9

Browse files
JacobSzwejbkafacebook-github-bot
authored andcommitted
expose reinplace pass in lowering
Summary: Not fully confidant in its stability so gate behind flag for now Reviewed By: lucylq Differential Revision: D77323681
1 parent b11075f commit 72b3bc9

File tree

3 files changed

+31
-18
lines changed

3 files changed

+31
-18
lines changed

exir/capture/_config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,3 +105,6 @@ class ExecutorchBackendConfig:
105105

106106
# If set to true, we run quant fusion and constant propagation passes
107107
do_quant_fusion_and_const_prop: bool = False
108+
109+
# Experimental: If set to true, we run a pass to reinplace ops in the graph.
110+
run_reinplace_pass: bool = False

exir/program/_program.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
NormalizeViewCopyBasePass,
5959
)
6060
from executorch.exir.passes.quant_fusion_pass import quant_fusion_and_const_prop_pass
61+
from executorch.exir.passes.reinplace import reinplace_pass
6162
from executorch.exir.passes.remove_graph_asserts_pass import (
6263
RemoveGraphAssertsPass,
6364
RemoveNonCoreAtenOpGraphAssertsPass,
@@ -193,7 +194,7 @@ def _get_updated_graph_signature(
193194
)
194195
i += 1
195196

196-
output_node = list(new_gm.graph.nodes)[-1]
197+
output_node = new_gm.graph.output_node()
197198
assert output_node.op == "output"
198199

199200
new_output_specs = []
@@ -1563,6 +1564,8 @@ def to_executorch(
15631564
" Please set do_quant_fusion_and_const_prop to False in the ExecutorchBackendConfig."
15641565
)
15651566
program = quant_fusion_and_const_prop_pass(program)
1567+
if config.run_reinplace_pass:
1568+
program = reinplace_pass(program)
15661569
program = weights_to_outputs_pass(program)
15671570
program = unsafe_remove_auto_functionalized_pass(program)
15681571
gm, new_signature = insert_write_back_for_buffers_pass(program)

exir/tests/test_reinplace_pass.py

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,53 +10,60 @@
1010

1111
import torch
1212
from executorch.exir import to_edge
13+
from executorch.exir.capture._config import ExecutorchBackendConfig
1314
from executorch.exir.passes.reinplace import reinplace_pass
15+
from executorch.extension.pybindings.portable_lib import ( # @manual=//executorch/extension/pybindings:portable_lib
16+
_load_for_executorch_from_buffer,
17+
)
1418
from torch.export import export
1519

1620

1721
class TestReinplacePass(unittest.TestCase):
1822
def test_index_put_reinplace(self) -> None:
1923
"""Test that index_put on a mutable buffer can be reinplaced."""
20-
24+
2125
class IndexPutModel(torch.nn.Module):
2226
def __init__(self):
2327
super().__init__()
2428
self.register_buffer("state", torch.zeros(5))
25-
26-
def forward(
27-
self, indices: torch.Tensor, values: torch.Tensor
28-
) -> torch.Tensor:
29+
30+
def forward(self, indices: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
2931
# index_put on buffer (non-user input) should be safe
3032
self.state.index_put_((indices,), values)
3133
return self.state
32-
34+
3335
model = IndexPutModel()
3436
indices = torch.tensor([0])
3537
values = torch.tensor([1.0])
36-
38+
3739
exported_program = export(model, (indices, values), strict=True)
38-
print(exported_program.graph)
39-
edge_program = to_edge(exported_program).exported_program()
40-
40+
edge = to_edge(exported_program)
41+
edge_program = edge._edge_programs["forward"]
42+
4143
# Find the index_put node
4244
index_put_node = None
4345
for node in edge_program.graph.nodes:
44-
if node.op == "call_function" and "index_put" in str(node.target):
46+
if (node.op == "call_function" and
47+
"index_put" in str(node.target)):
4548
index_put_node = node
4649
break
47-
50+
4851
self.assertIsNotNone(index_put_node, "Should find an index_put node")
49-
50-
ep = reinplace_pass(edge_program)
52+
53+
et = edge.to_executorch(ExecutorchBackendConfig(run_reinplace_pass=True))
5154
# Find the index_put node
5255
index_put_node = None
53-
for node in ep.graph.nodes:
54-
if node.op == "call_function" and "index_put_" in str(node.target):
56+
for node in et.exported_program().graph.nodes:
57+
if (node.op == "call_function" and
58+
"index_put_" in str(node.target)):
5559
index_put_node = node
5660
break
57-
61+
5862
self.assertIsNotNone(index_put_node, "Should find an index_put_ node")
5963

64+
e = _load_for_executorch_from_buffer(et.buffer)
65+
self.assertTrue(torch.allclose(e.forward((indices, values))[0], torch.tensor([1.0, 0.0, 0.0, 0.0, 0.0])))
66+
6067
def test_cant_reinplace(self) -> None:
6168
"""Test that index_put on a mutable buffer that is viewed later is not safe."""
6269

0 commit comments

Comments
 (0)