| 
22 | 22 | from executorch.exir.pass_base import ExportPass  | 
23 | 23 | from executorch.exir.passes import MemoryPlanningPass  | 
24 | 24 | from executorch.exir.program._program import (  | 
 | 25 | +    _transform,  | 
25 | 26 |     EdgeProgramManager,  | 
26 | 27 |     ExecutorchProgramManager,  | 
27 | 28 |     to_edge,  | 
 | 
34 | 35 | from executorch.extension.pybindings.portable_lib import (  | 
35 | 36 |     _load_for_executorch_from_buffer,  | 
36 | 37 | )  | 
 | 38 | +from torch._export.verifier import Verifier  | 
37 | 39 | from torch.export import Dim, export, ExportedProgram  | 
38 | 40 | from torch.export._trace import _export  | 
39 | 41 | 
 
  | 
@@ -273,7 +275,6 @@ def get_executorch_memory_planning_passes() -> Dict[str, MemoryPlanningPass]:  | 
273 | 275 |             for output_val in method.outputs:  | 
274 | 276 |                 evalue = method.values[output_val]  | 
275 | 277 |                 self.assertNotEqual(evalue.val.allocation_info, None)  | 
276 |  | -        else:  | 
277 | 278 |             for input_val in method.inputs:  | 
278 | 279 |                 evalue = method.values[input_val]  | 
279 | 280 |                 self.assertEqual(evalue.val.allocation_info, None)  | 
@@ -847,3 +848,19 @@ def test_save_fails(self):  | 
847 | 848 |         et = edge.to_executorch()  | 
848 | 849 |         with self.assertRaises(ValueError):  | 
849 | 850 |             _ = et.save("/tmp/test_save.pt")  | 
 | 851 | + | 
 | 852 | +    def test__transform_override_verifiers(self):  | 
 | 853 | +        """Test that _transform can override verifiers in the exported program."""  | 
 | 854 | +        class MyVerifier(Verifier):  | 
 | 855 | +            dialect: str = "MY_DIALECT"  | 
 | 856 | +            def __init__(self):  | 
 | 857 | +                super().__init__()  | 
 | 858 | + | 
 | 859 | +        model = TestLinear()  | 
 | 860 | +        program = torch.export.export(model, model._get_random_inputs(), strict=True)  | 
 | 861 | +        self.assertFalse(issubclass(program.verifiers[0], MyVerifier))  | 
 | 862 | + | 
 | 863 | +        # Apply transformation with custom verifier  | 
 | 864 | +        transformed = _transform(program, AddToMulPassEdge(), override_verifiers=[MyVerifier])  | 
 | 865 | +        self.assertTrue(issubclass(transformed.verifiers[0], MyVerifier))  | 
 | 866 | +        self.assertFalse(issubclass(program.verifiers[0], MyVerifier))  | 
0 commit comments