Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions exir/capture/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,6 @@ class ExecutorchBackendConfig:

# If set to true, we run quant fusion and constant propagation passes
do_quant_fusion_and_const_prop: bool = False

# Experimental: If set to true, we run a pass to reinplace ops in the graph.
run_reinplace_pass: bool = False
5 changes: 4 additions & 1 deletion exir/program/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
NormalizeViewCopyBasePass,
)
from executorch.exir.passes.quant_fusion_pass import quant_fusion_and_const_prop_pass
from executorch.exir.passes.reinplace import reinplace_pass
from executorch.exir.passes.remove_graph_asserts_pass import (
RemoveGraphAssertsPass,
RemoveNonCoreAtenOpGraphAssertsPass,
Expand Down Expand Up @@ -193,7 +194,7 @@ def _get_updated_graph_signature(
)
i += 1

output_node = list(new_gm.graph.nodes)[-1]
output_node = new_gm.graph.output_node()
assert output_node.op == "output"

new_output_specs = []
Expand Down Expand Up @@ -1563,6 +1564,8 @@ def to_executorch(
" Please set do_quant_fusion_and_const_prop to False in the ExecutorchBackendConfig."
)
program = quant_fusion_and_const_prop_pass(program)
if config.run_reinplace_pass:
program = reinplace_pass(program)
program = weights_to_outputs_pass(program)
program = unsafe_remove_auto_functionalized_pass(program)
gm, new_signature = insert_write_back_for_buffers_pass(program)
Expand Down
19 changes: 15 additions & 4 deletions exir/tests/test_reinplace_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@

import torch
from executorch.exir import to_edge
from executorch.exir.capture._config import ExecutorchBackendConfig
from executorch.exir.passes.reinplace import reinplace_pass
from executorch.extension.pybindings.portable_lib import ( # @manual=//executorch/extension/pybindings:portable_lib
_load_for_executorch_from_buffer,
)
from torch.export import export


Expand All @@ -35,8 +39,8 @@ def forward(
values = torch.tensor([1.0])

exported_program = export(model, (indices, values), strict=True)
print(exported_program.graph)
edge_program = to_edge(exported_program).exported_program()
edge = to_edge(exported_program)
edge_program = edge._edge_programs["forward"]

# Find the index_put node
index_put_node = None
Expand All @@ -47,16 +51,23 @@ def forward(

self.assertIsNotNone(index_put_node, "Should find an index_put node")

ep = reinplace_pass(edge_program)
et = edge.to_executorch(ExecutorchBackendConfig(run_reinplace_pass=True))
# Find the index_put node
index_put_node = None
for node in ep.graph.nodes:
for node in et.exported_program().graph.nodes:
if node.op == "call_function" and "index_put_" in str(node.target):
index_put_node = node
break

self.assertIsNotNone(index_put_node, "Should find an index_put_ node")

e = _load_for_executorch_from_buffer(et.buffer)
self.assertTrue(
torch.allclose(
e.forward((indices, values))[0], torch.tensor([1.0, 0.0, 0.0, 0.0, 0.0])
)
)

def test_cant_reinplace(self) -> None:
"""Test that index_put on a mutable buffer that is viewed later is not safe."""

Expand Down
Loading