|
22 | 22 | from executorch.exir.pass_base import ExportPass |
23 | 23 | from executorch.exir.passes import MemoryPlanningPass |
24 | 24 | from executorch.exir.program._program import ( |
| 25 | + _transform, |
25 | 26 | EdgeProgramManager, |
26 | 27 | ExecutorchProgramManager, |
27 | 28 | to_edge, |
|
34 | 35 | from executorch.extension.pybindings.portable_lib import ( |
35 | 36 | _load_for_executorch_from_buffer, |
36 | 37 | ) |
| 38 | +from torch._export.verifier import Verifier |
37 | 39 | from torch.export import Dim, export, ExportedProgram |
38 | 40 | from torch.export._trace import _export |
39 | 41 |
|
@@ -273,7 +275,6 @@ def get_executorch_memory_planning_passes() -> Dict[str, MemoryPlanningPass]: |
273 | 275 | for output_val in method.outputs: |
274 | 276 | evalue = method.values[output_val] |
275 | 277 | self.assertNotEqual(evalue.val.allocation_info, None) |
276 | | - else: |
277 | 278 | for input_val in method.inputs: |
278 | 279 | evalue = method.values[input_val] |
279 | 280 | self.assertEqual(evalue.val.allocation_info, None) |
@@ -847,3 +848,19 @@ def test_save_fails(self): |
847 | 848 | et = edge.to_executorch() |
848 | 849 | with self.assertRaises(ValueError): |
849 | 850 | _ = et.save("/tmp/test_save.pt") |
| 851 | + |
| 852 | + def test__transform_override_verifiers(self): |
| 853 | + """Test that _transform can override verifiers in the exported program.""" |
| 854 | + class MyVerifier(Verifier): |
| 855 | + dialect: str = "MY_DIALECT" |
| 856 | + def __init__(self): |
| 857 | + super().__init__() |
| 858 | + |
| 859 | + model = TestLinear() |
| 860 | + program = torch.export.export(model, model._get_random_inputs(), strict=True) |
| 861 | + self.assertFalse(issubclass(program.verifiers[0], MyVerifier)) |
| 862 | + |
| 863 | + # Apply transformation with custom verifier |
| 864 | + transformed = _transform(program, AddToMulPassEdge(), override_verifiers=[MyVerifier]) |
| 865 | + self.assertTrue(issubclass(transformed.verifiers[0], MyVerifier)) |
| 866 | + self.assertFalse(issubclass(program.verifiers[0], MyVerifier)) |
0 commit comments