@@ -52,6 +52,24 @@ class PyPassManager {
5252 MlirPassManager passManager;
5353};
5454
55+ MlirExternalPassCallbacks createExternalPassCallbacksForPythonCallable () {
56+ MlirExternalPassCallbacks callbacks;
57+ callbacks.construct = [](void *obj) {
58+ (void )nb::handle (static_cast <PyObject *>(obj)).inc_ref ();
59+ };
60+ callbacks.destruct = [](void *obj) {
61+ (void )nb::handle (static_cast <PyObject *>(obj)).dec_ref ();
62+ };
63+ callbacks.initialize = nullptr ;
64+ callbacks.clone = [](void *) -> void * {
65+ throw std::runtime_error (" Cloning Python passes not supported" );
66+ };
67+ callbacks.run = [](MlirOperation op, MlirExternalPass pass, void *userData) {
68+ nb::handle (static_cast <PyObject *>(userData))(op, pass);
69+ };
70+ return callbacks;
71+ }
72+
5573} // namespace
5674
5775// / Create the `mlir.passmanager` here.
@@ -63,6 +81,33 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
6381 .def (" signal_pass_failure" ,
6482 [](MlirExternalPass pass) { mlirExternalPassSignalFailure (pass); });
6583
84+ // ----------------------------------------------------------------------------
85+ // Mapping of register_pass
86+ // ----------------------------------------------------------------------------
87+ m.def (
88+ " register_pass" ,
89+ [](const std::string &argument, const nb::callable &run,
90+ std::optional<std::string> &name, const std::string &description,
91+ const std::string &opName) {
92+ if (!name.has_value ()) {
93+ name =
94+ nb::cast<std::string>(nb::borrow<nb::str>(run.attr (" __name__" )));
95+ }
96+ MlirTypeIDAllocator typeIDAllocator = mlirTypeIDAllocatorCreate ();
97+ MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID (typeIDAllocator);
98+ auto callbacks = createExternalPassCallbacksForPythonCallable ();
99+ mlirRegisterExternalPass (
100+ passID, mlirStringRefCreate (name->data (), name->length ()),
101+ mlirStringRefCreate (argument.data (), argument.length ()),
102+ mlirStringRefCreate (description.data (), description.length ()),
103+ mlirStringRefCreate (opName.data (), opName.size ()),
104+ /* nDependentDialects*/ 0 , /* dependentDialects*/ nullptr , callbacks,
105+ /* userData*/ run.ptr ());
106+ },
107+ " argument" _a, " run" _a, " name" _a.none () = nb::none (),
108+ " description" _a.none () = " " , " op_name" _a.none () = " " ,
109+ " Register a python-defined pass." );
110+
66111 // ----------------------------------------------------------------------------
67112 // Mapping of the top-level PassManager
68113 // ----------------------------------------------------------------------------
@@ -178,21 +223,7 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
178223 MlirTypeIDAllocator typeIDAllocator = mlirTypeIDAllocatorCreate ();
179224 MlirTypeID passID =
180225 mlirTypeIDAllocatorAllocateTypeID (typeIDAllocator);
181- MlirExternalPassCallbacks callbacks;
182- callbacks.construct = [](void *obj) {
183- (void )nb::handle (static_cast <PyObject *>(obj)).inc_ref ();
184- };
185- callbacks.destruct = [](void *obj) {
186- (void )nb::handle (static_cast <PyObject *>(obj)).dec_ref ();
187- };
188- callbacks.initialize = nullptr ;
189- callbacks.clone = [](void *) -> void * {
190- throw std::runtime_error (" Cloning Python passes not supported" );
191- };
192- callbacks.run = [](MlirOperation op, MlirExternalPass pass,
193- void *userData) {
194- nb::handle (static_cast <PyObject *>(userData))(op, pass);
195- };
226+ auto callbacks = createExternalPassCallbacksForPythonCallable ();
196227 auto externalPass = mlirCreateExternalPass (
197228 passID, mlirStringRefCreate (name->data (), name->length ()),
198229 mlirStringRefCreate (argument.data (), argument.length ()),
0 commit comments