|
10 | 10 |
|
11 | 11 | import torch |
12 | 12 | from executorch.exir import to_edge |
| 13 | +from executorch.exir.capture._config import ExecutorchBackendConfig |
13 | 14 | from executorch.exir.passes.reinplace import reinplace_pass |
| 15 | +from executorch.extension.pybindings.portable_lib import ( # @manual=//executorch/extension/pybindings:portable_lib |
| 16 | + _load_for_executorch_from_buffer, |
| 17 | +) |
14 | 18 | from torch.export import export |
15 | 19 |
|
16 | 20 |
|
17 | 21 | class TestReinplacePass(unittest.TestCase): |
18 | 22 | def test_index_put_reinplace(self) -> None: |
19 | 23 | """Test that index_put on a mutable buffer can be reinplaced.""" |
20 | | - |
| 24 | + |
21 | 25 | class IndexPutModel(torch.nn.Module): |
22 | 26 | def __init__(self): |
23 | 27 | super().__init__() |
24 | 28 | self.register_buffer("state", torch.zeros(5)) |
25 | | - |
26 | | - def forward( |
27 | | - self, indices: torch.Tensor, values: torch.Tensor |
28 | | - ) -> torch.Tensor: |
| 29 | + |
| 30 | + def forward(self, indices: torch.Tensor, values: torch.Tensor) -> torch.Tensor: |
29 | 31 | # index_put on buffer (non-user input) should be safe |
30 | 32 | self.state.index_put_((indices,), values) |
31 | 33 | return self.state |
32 | | - |
| 34 | + |
33 | 35 | model = IndexPutModel() |
34 | 36 | indices = torch.tensor([0]) |
35 | 37 | values = torch.tensor([1.0]) |
36 | | - |
| 38 | + |
37 | 39 | exported_program = export(model, (indices, values), strict=True) |
38 | | - print(exported_program.graph) |
39 | | - edge_program = to_edge(exported_program).exported_program() |
40 | | - |
| 40 | + edge = to_edge(exported_program) |
| 41 | + edge_program = edge._edge_programs["forward"] |
| 42 | + |
41 | 43 | # Find the index_put node |
42 | 44 | index_put_node = None |
43 | 45 | for node in edge_program.graph.nodes: |
44 | | - if node.op == "call_function" and "index_put" in str(node.target): |
| 46 | + if (node.op == "call_function" and |
| 47 | + "index_put" in str(node.target)): |
45 | 48 | index_put_node = node |
46 | 49 | break |
47 | | - |
| 50 | + |
48 | 51 | self.assertIsNotNone(index_put_node, "Should find an index_put node") |
49 | | - |
50 | | - ep = reinplace_pass(edge_program) |
| 52 | + |
| 53 | + et = edge.to_executorch(ExecutorchBackendConfig(run_reinplace_pass=True)) |
51 | 54 | # Find the index_put node |
52 | 55 | index_put_node = None |
53 | | - for node in ep.graph.nodes: |
54 | | - if node.op == "call_function" and "index_put_" in str(node.target): |
| 56 | + for node in et.exported_program().graph.nodes: |
| 57 | + if (node.op == "call_function" and |
| 58 | + "index_put_" in str(node.target)): |
55 | 59 | index_put_node = node |
56 | 60 | break |
57 | | - |
| 61 | + |
58 | 62 | self.assertIsNotNone(index_put_node, "Should find an index_put_ node") |
59 | 63 |
|
| 64 | + e = _load_for_executorch_from_buffer(et.buffer) |
| 65 | + self.assertTrue(torch.allclose(e.forward((indices, values))[0], torch.tensor([1.0, 0.0, 0.0, 0.0, 0.0]))) |
| 66 | + |
60 | 67 | def test_cant_reinplace(self) -> None: |
61 | 68 | """Test that index_put on a mutable buffer that is viewed later is not safe.""" |
62 | 69 |
|
|
0 commit comments