10
10
11
11
import torch
12
12
from executorch .exir import to_edge
13
+ from executorch .exir .capture ._config import ExecutorchBackendConfig
13
14
from executorch .exir .passes .reinplace import reinplace_pass
15
+ from executorch .extension .pybindings .portable_lib import ( # @manual=//executorch/extension/pybindings:portable_lib
16
+ _load_for_executorch_from_buffer ,
17
+ )
14
18
from torch .export import export
15
19
16
20
@@ -35,8 +39,8 @@ def forward(
35
39
values = torch .tensor ([1.0 ])
36
40
37
41
exported_program = export (model , (indices , values ), strict = True )
38
- print (exported_program . graph )
39
- edge_program = to_edge ( exported_program ). exported_program ()
42
+ edge = to_edge (exported_program )
43
+ edge_program = edge . _edge_programs [ "forward" ]
40
44
41
45
# Find the index_put node
42
46
index_put_node = None
@@ -47,16 +51,23 @@ def forward(
47
51
48
52
self .assertIsNotNone (index_put_node , "Should find an index_put node" )
49
53
50
- ep = reinplace_pass ( edge_program )
54
+ et = edge . to_executorch ( ExecutorchBackendConfig ( run_reinplace_pass = True ) )
51
55
# Find the index_put node
52
56
index_put_node = None
53
- for node in ep .graph .nodes :
57
+ for node in et . exported_program () .graph .nodes :
54
58
if node .op == "call_function" and "index_put_" in str (node .target ):
55
59
index_put_node = node
56
60
break
57
61
58
62
self .assertIsNotNone (index_put_node , "Should find an index_put_ node" )
59
63
64
+ e = _load_for_executorch_from_buffer (et .buffer )
65
+ self .assertTrue (
66
+ torch .allclose (
67
+ e .forward ((indices , values ))[0 ], torch .tensor ([1.0 , 0.0 , 0.0 , 0.0 , 0.0 ])
68
+ )
69
+ )
70
+
60
71
def test_cant_reinplace (self ) -> None :
61
72
"""Test that index_put on a mutable buffer that is viewed later is not safe."""
62
73
0 commit comments