@@ -56,7 +56,12 @@ class PyPassManager {
5656
5757// / Create the `mlir.passmanager` here.
5858void mlir::python::populatePassManagerSubmodule (nb::module_ &m) {
59- constexpr const char *mlirExternalPassAttr = " __mlir_external_pass__" ;
59+ // ----------------------------------------------------------------------------
60+ // Mapping of MlirExternalPass
61+ // ----------------------------------------------------------------------------
62+ nb::class_<MlirExternalPass>(m, " ExternalPass" )
63+ .def (" signal_failure" ,
64+ [](MlirExternalPass pass) { mlirExternalPassSignalFailure (pass); });
6065
6166 // ----------------------------------------------------------------------------
6267 // Mapping of the top-level PassManager
@@ -186,27 +191,8 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
186191 };
187192 callbacks.run = [](MlirOperation op, MlirExternalPass pass,
188193 void *userData) {
189- auto callable =
190- nb::borrow<nb::callable>(static_cast <PyObject *>(userData));
191- nb::setattr (callable, mlirExternalPassAttr,
192- nb::capsule (pass.ptr ));
193- callable (op);
194- // delete it to avoid that it is used after
195- // the external pass is freed by the pass manager
196- nb::delattr (callable, mlirExternalPassAttr);
194+ nb::handle (static_cast <PyObject *>(userData))(op, pass);
197195 };
198- nb::setattr (run, " signal_pass_failure" , nb::cpp_function ([run]() {
199- nb::capsule cap;
200- try {
201- cap = run.attr (mlirExternalPassAttr);
202- } catch (nb::python_error &e) {
203- throw std::runtime_error (
204- " signal_pass_failure() should always be called "
205- " from the __call__ method" );
206- }
207- mlirExternalPassSignalFailure (
208- MlirExternalPass{cap.data ()});
209- }));
210196 auto externalPass = mlirCreateExternalPass (
211197 passID, mlirStringRefCreate (name->data (), name->length ()),
212198 mlirStringRefCreate (argument.data (), argument.length ()),
0 commit comments