diff --git a/exir/capture/_config.py b/exir/capture/_config.py index a0757cfeb33..80a838737fc 100644 --- a/exir/capture/_config.py +++ b/exir/capture/_config.py @@ -105,3 +105,6 @@ class ExecutorchBackendConfig: # If set to true, we run quant fusion and constant propagation passes do_quant_fusion_and_const_prop: bool = False + + # Experimental: If set to true, we run a pass to reinplace ops in the graph. + run_reinplace_pass: bool = False diff --git a/exir/program/_program.py b/exir/program/_program.py index b9fa83a668f..0c4469c96de 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -58,6 +58,7 @@ NormalizeViewCopyBasePass, ) from executorch.exir.passes.quant_fusion_pass import quant_fusion_and_const_prop_pass +from executorch.exir.passes.reinplace import reinplace_pass from executorch.exir.passes.remove_graph_asserts_pass import ( RemoveGraphAssertsPass, RemoveNonCoreAtenOpGraphAssertsPass, @@ -193,7 +194,7 @@ def _get_updated_graph_signature( ) i += 1 - output_node = list(new_gm.graph.nodes)[-1] + output_node = new_gm.graph.output_node() assert output_node.op == "output" new_output_specs = [] @@ -1563,6 +1564,8 @@ def to_executorch( " Please set do_quant_fusion_and_const_prop to False in the ExecutorchBackendConfig." ) program = quant_fusion_and_const_prop_pass(program) + if config.run_reinplace_pass: + program = reinplace_pass(program) program = weights_to_outputs_pass(program) program = unsafe_remove_auto_functionalized_pass(program) gm, new_signature = insert_write_back_for_buffers_pass(program) diff --git a/exir/tests/test_reinplace_pass.py b/exir/tests/test_reinplace_pass.py index 2f4538770d6..8488b152398 100644 --- a/exir/tests/test_reinplace_pass.py +++ b/exir/tests/test_reinplace_pass.py @@ -10,7 +10,11 @@ import torch from executorch.exir import to_edge +from executorch.exir.capture._config import ExecutorchBackendConfig from executorch.exir.passes.reinplace import reinplace_pass +from executorch.extension.pybindings.portable_lib import ( # @manual=//executorch/extension/pybindings:portable_lib + _load_for_executorch_from_buffer, +) from torch.export import export @@ -35,8 +39,8 @@ def forward( values = torch.tensor([1.0]) exported_program = export(model, (indices, values), strict=True) - print(exported_program.graph) - edge_program = to_edge(exported_program).exported_program() + edge = to_edge(exported_program) + edge_program = edge._edge_programs["forward"] # Find the index_put node index_put_node = None @@ -47,16 +51,23 @@ def forward( self.assertIsNotNone(index_put_node, "Should find an index_put node") - ep = reinplace_pass(edge_program) + et = edge.to_executorch(ExecutorchBackendConfig(run_reinplace_pass=True)) # Find the index_put node index_put_node = None - for node in ep.graph.nodes: + for node in et.exported_program().graph.nodes: if node.op == "call_function" and "index_put_" in str(node.target): index_put_node = node break self.assertIsNotNone(index_put_node, "Should find an index_put_ node") + e = _load_for_executorch_from_buffer(et.buffer) + self.assertTrue( + torch.allclose( + e.forward((indices, values))[0], torch.tensor([1.0, 0.0, 0.0, 0.0, 0.0]) + ) + ) + def test_cant_reinplace(self) -> None: """Test that index_put on a mutable buffer that is viewed later is not safe."""