1010
1111import torch
1212from executorch .exir import to_edge
13+ from executorch .exir .capture ._config import ExecutorchBackendConfig
1314from executorch .exir .passes .reinplace import reinplace_pass
15+ from executorch .extension .pybindings .portable_lib import ( # @manual=//executorch/extension/pybindings:portable_lib
16+ _load_for_executorch_from_buffer ,
17+ )
1418from torch .export import export
1519
1620
@@ -35,8 +39,8 @@ def forward(
3539 values = torch .tensor ([1.0 ])
3640
3741 exported_program = export (model , (indices , values ), strict = True )
38- print (exported_program . graph )
39- edge_program = to_edge ( exported_program ). exported_program ()
42+ edge = to_edge (exported_program )
43+ edge_program = edge . _edge_programs [ "forward" ]
4044
4145 # Find the index_put node
4246 index_put_node = None
@@ -47,16 +51,23 @@ def forward(
4751
4852 self .assertIsNotNone (index_put_node , "Should find an index_put node" )
4953
50- ep = reinplace_pass ( edge_program )
54+ et = edge . to_executorch ( ExecutorchBackendConfig ( run_reinplace_pass = True ) )
5155 # Find the index_put node
5256 index_put_node = None
53- for node in ep .graph .nodes :
57+ for node in et . exported_program () .graph .nodes :
5458 if node .op == "call_function" and "index_put_" in str (node .target ):
5559 index_put_node = node
5660 break
5761
5862 self .assertIsNotNone (index_put_node , "Should find an index_put_ node" )
5963
64+ e = _load_for_executorch_from_buffer (et .buffer )
65+ self .assertTrue (
66+ torch .allclose (
67+ e .forward ((indices , values ))[0 ], torch .tensor ([1.0 , 0.0 , 0.0 , 0.0 , 0.0 ])
68+ )
69+ )
70+
6071 def test_cant_reinplace (self ) -> None :
6172 """Test that index_put on a mutable buffer that is viewed later is not safe."""
6273
0 commit comments