diff --git a/exir/program/_program.py b/exir/program/_program.py index e0484f4f4ff..dc6a3f683ed 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -234,8 +234,6 @@ def _transform( isinstance(p, (list, Verifier)) for p in passes ), f"Expected all passes to be of PassType, not list or Verifier. Use override_verifiers kwarg instead. Got: {list(passes)}" - for p in list(passes): - print(type(p)) pm = PassManager(list(passes)) res = pm(self.graph_module) transformed_gm = res.graph_module if res is not None else self.graph_module @@ -1442,22 +1440,34 @@ def transform( """ compile_config = compile_config or self.compile_config new_programs: Dict[str, ExportedProgram] = {} + + def _transform_and_verify( + program: ExportedProgram, + passes: Sequence[PassType], + verifier: type[Verifier], + ) -> ExportedProgram: + # Overwrite the original verifier with the new one + # This should be a no-op for the most cases where compile_config is none. + new_program = _transform(program, *passes, override_verifiers=[verifier]) + # ExportedProgram constructor should call the verifier, but + # the validate() function in the constructor is marked for deprecation. + verifier()(new_program.graph_module) + return new_program + + verifier = EXIREdgeDialectVerifier( + edge_compile_config=compile_config, class_only=True + ) if isinstance(passes, dict): for name, program in self._edge_programs.items(): if name in passes.keys(): - new_programs[name] = _transform(program, *passes[name]) - EXIREdgeDialectVerifier(edge_compile_config=compile_config)( - new_programs[name].graph_module + new_programs[name] = _transform_and_verify( + program, passes[name], verifier ) else: new_programs[name] = copy.deepcopy(program) - else: # apply passes to every method for name, program in self._edge_programs.items(): - new_programs[name] = _transform(program, *passes) - EXIREdgeDialectVerifier(edge_compile_config=compile_config)( - new_programs[name].graph_module - ) + new_programs[name] = _transform_and_verify(program, passes, verifier) return EdgeProgramManager( new_programs, copy.deepcopy(self._config_methods), compile_config