Skip to content
This repository was archived by the owner on Oct 11, 2025. It is now read-only.

Commit 1c09d21

Browse files
PragmaTwicecnb.bsD2OPwAgEAmakslevental
authored
[MLIR][Python] Support Python-defined passes in MLIR (#156000)
It closes #155996. This PR added a method `add(callable, ..)` to `mlir.passmanager.PassManager` to accept a callable object for defining passes in the Python side. This is a simple example of a Python-defined pass. ```python from mlir.passmanager import PassManager def demo_pass_1(op): # do something with op pass class DemoPass: def __init__(self, ...): pass def __call__(op): # do something pass demo_pass_2 = DemoPass(..) pm = PassManager('any', ctx) pm.add(demo_pass_1) pm.add(demo_pass_2) pm.add("registered-passes") pm.run(..) ``` --------- Co-authored-by: cnb.bsD2OPwAgEA <[email protected]> Co-authored-by: Maksim Levental <[email protected]>
1 parent 68f631a commit 1c09d21

File tree

3 files changed

+49
-4
lines changed

3 files changed

+49
-4
lines changed

mlir/lib/Bindings/Python/MainModule.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ NB_MODULE(_mlir, m) {
136136
populateRewriteSubmodule(rewriteModule);
137137

138138
// Define and populate PassManager submodule.
139-
auto passModule =
139+
auto passManagerModule =
140140
m.def_submodule("passmanager", "MLIR Pass Management Bindings");
141-
populatePassManagerSubmodule(passModule);
141+
populatePassManagerSubmodule(passManagerModule);
142142
}

mlir/lib/Bindings/Python/Pass.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@
1010

1111
#include "IRModule.h"
1212
#include "mlir-c/Pass.h"
13+
// clang-format off
1314
#include "mlir/Bindings/Python/Nanobind.h"
1415
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
16+
// clang-format on
1517

1618
namespace nb = nanobind;
1719
using namespace nb::literals;
@@ -157,6 +159,45 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
157159
"pipeline"_a,
158160
"Add textual pipeline elements to the pass manager. Throws a "
159161
"ValueError if the pipeline can't be parsed.")
162+
.def(
163+
"add",
164+
[](PyPassManager &passManager, const nb::callable &run,
165+
std::optional<std::string> &name, const std::string &argument,
166+
const std::string &description, const std::string &opName) {
167+
if (!name.has_value()) {
168+
name = nb::cast<std::string>(
169+
nb::borrow<nb::str>(run.attr("__name__")));
170+
}
171+
MlirTypeIDAllocator typeIDAllocator = mlirTypeIDAllocatorCreate();
172+
MlirTypeID passID =
173+
mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
174+
MlirExternalPassCallbacks callbacks;
175+
callbacks.construct = [](void *obj) {
176+
(void)nb::handle(static_cast<PyObject *>(obj)).inc_ref();
177+
};
178+
callbacks.destruct = [](void *obj) {
179+
(void)nb::handle(static_cast<PyObject *>(obj)).dec_ref();
180+
};
181+
callbacks.initialize = nullptr;
182+
callbacks.clone = [](void *) -> void * {
183+
throw std::runtime_error("Cloning Python passes not supported");
184+
};
185+
callbacks.run = [](MlirOperation op, MlirExternalPass,
186+
void *userData) {
187+
nb::borrow<nb::callable>(static_cast<PyObject *>(userData))(op);
188+
};
189+
auto externalPass = mlirCreateExternalPass(
190+
passID, mlirStringRefCreate(name->data(), name->length()),
191+
mlirStringRefCreate(argument.data(), argument.length()),
192+
mlirStringRefCreate(description.data(), description.length()),
193+
mlirStringRefCreate(opName.data(), opName.size()),
194+
/*nDependentDialects*/ 0, /*dependentDialects*/ nullptr,
195+
callbacks, /*userData*/ run.ptr());
196+
mlirPassManagerAddOwnedPass(passManager.get(), externalPass);
197+
},
198+
"run"_a, "name"_a.none() = nb::none(), "argument"_a.none() = "",
199+
"description"_a.none() = "", "op_name"_a.none() = "",
200+
"Add a python-defined pass to the pass manager.")
160201
.def(
161202
"run",
162203
[](PyPassManager &passManager, PyOperationBase &op) {

mlir/lib/CAPI/IR/Pass.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,10 +145,14 @@ class ExternalPass : public Pass {
145145
: Pass(passID, opName), id(passID), name(name), argument(argument),
146146
description(description), dependentDialects(dependentDialects),
147147
callbacks(callbacks), userData(userData) {
148-
callbacks.construct(userData);
148+
if (callbacks.construct)
149+
callbacks.construct(userData);
149150
}
150151

151-
~ExternalPass() override { callbacks.destruct(userData); }
152+
~ExternalPass() override {
153+
if (callbacks.destruct)
154+
callbacks.destruct(userData);
155+
}
152156

153157
StringRef getName() const override { return name; }
154158
StringRef getArgument() const override { return argument; }

0 commit comments

Comments
 (0)