Skip to content

Commit b3c9bf7

Browse files
expose reinplace pass in lowering
Differential Revision: D77323681 Pull Request resolved: #12044
1 parent 8dde918 commit b3c9bf7

File tree

3 files changed

+22
-5
lines changed

3 files changed

+22
-5
lines changed

exir/capture/_config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,3 +105,6 @@ class ExecutorchBackendConfig:
105105

106106
# If set to true, we run quant fusion and constant propagation passes
107107
do_quant_fusion_and_const_prop: bool = False
108+
109+
# Experimental: If set to true, we run a pass to reinplace ops in the graph.
110+
run_reinplace_pass: bool = False

exir/program/_program.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
NormalizeViewCopyBasePass,
5959
)
6060
from executorch.exir.passes.quant_fusion_pass import quant_fusion_and_const_prop_pass
61+
from executorch.exir.passes.reinplace import reinplace_pass
6162
from executorch.exir.passes.remove_graph_asserts_pass import (
6263
RemoveGraphAssertsPass,
6364
RemoveNonCoreAtenOpGraphAssertsPass,
@@ -193,7 +194,7 @@ def _get_updated_graph_signature(
193194
)
194195
i += 1
195196

196-
output_node = list(new_gm.graph.nodes)[-1]
197+
output_node = new_gm.graph.output_node()
197198
assert output_node.op == "output"
198199

199200
new_output_specs = []
@@ -1563,6 +1564,8 @@ def to_executorch(
15631564
" Please set do_quant_fusion_and_const_prop to False in the ExecutorchBackendConfig."
15641565
)
15651566
program = quant_fusion_and_const_prop_pass(program)
1567+
if config.run_reinplace_pass:
1568+
program = reinplace_pass(program)
15661569
program = weights_to_outputs_pass(program)
15671570
program = unsafe_remove_auto_functionalized_pass(program)
15681571
gm, new_signature = insert_write_back_for_buffers_pass(program)

exir/tests/test_reinplace_pass.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,11 @@
1010

1111
import torch
1212
from executorch.exir import to_edge
13+
from executorch.exir.capture._config import ExecutorchBackendConfig
1314
from executorch.exir.passes.reinplace import reinplace_pass
15+
from executorch.extension.pybindings.portable_lib import ( # @manual=//executorch/extension/pybindings:portable_lib
16+
_load_for_executorch_from_buffer,
17+
)
1418
from torch.export import export
1519

1620

@@ -35,8 +39,8 @@ def forward(
3539
values = torch.tensor([1.0])
3640

3741
exported_program = export(model, (indices, values), strict=True)
38-
print(exported_program.graph)
39-
edge_program = to_edge(exported_program).exported_program()
42+
edge = to_edge(exported_program)
43+
edge_program = edge._edge_programs["forward"]
4044

4145
# Find the index_put node
4246
index_put_node = None
@@ -47,16 +51,23 @@ def forward(
4751

4852
self.assertIsNotNone(index_put_node, "Should find an index_put node")
4953

50-
ep = reinplace_pass(edge_program)
54+
et = edge.to_executorch(ExecutorchBackendConfig(run_reinplace_pass=True))
5155
# Find the index_put node
5256
index_put_node = None
53-
for node in ep.graph.nodes:
57+
for node in et.exported_program().graph.nodes:
5458
if node.op == "call_function" and "index_put_" in str(node.target):
5559
index_put_node = node
5660
break
5761

5862
self.assertIsNotNone(index_put_node, "Should find an index_put_ node")
5963

64+
e = _load_for_executorch_from_buffer(et.buffer)
65+
self.assertTrue(
66+
torch.allclose(
67+
e.forward((indices, values))[0], torch.tensor([1.0, 0.0, 0.0, 0.0, 0.0])
68+
)
69+
)
70+
6071
def test_cant_reinplace(self) -> None:
6172
"""Test that index_put on a mutable buffer that is viewed later is not safe."""
6273

0 commit comments

Comments
 (0)