Skip to content

Commit 09870a6

Browse files
committed
[MLIR][Python] Add a function to register python-defined passes
1 parent 735522a commit 09870a6

File tree

4 files changed

+121
-17
lines changed

4 files changed

+121
-17
lines changed

mlir/include/mlir-c/Pass.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,13 @@ MLIR_CAPI_EXPORTED MlirPass mlirCreateExternalPass(
184184
intptr_t nDependentDialects, MlirDialectHandle *dependentDialects,
185185
MlirExternalPassCallbacks callbacks, void *userData);
186186

187+
MLIR_CAPI_EXPORTED void
188+
mlirRegisterExternalPass(MlirTypeID passID, MlirStringRef name,
189+
MlirStringRef argument, MlirStringRef description,
190+
MlirStringRef opName, intptr_t nDependentDialects,
191+
MlirDialectHandle *dependentDialects,
192+
MlirExternalPassCallbacks callbacks, void *userData);
193+
187194
/// This signals that the pass has failed. This is only valid to call during
188195
/// the `run` callback of `MlirExternalPassCallbacks`.
189196
/// See Pass::signalPassFailure().

mlir/lib/Bindings/Python/Pass.cpp

Lines changed: 46 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,24 @@ class PyPassManager {
5252
MlirPassManager passManager;
5353
};
5454

55+
MlirExternalPassCallbacks createExternalPassCallbacksForPythonCallable() {
56+
MlirExternalPassCallbacks callbacks;
57+
callbacks.construct = [](void *obj) {
58+
(void)nb::handle(static_cast<PyObject *>(obj)).inc_ref();
59+
};
60+
callbacks.destruct = [](void *obj) {
61+
(void)nb::handle(static_cast<PyObject *>(obj)).dec_ref();
62+
};
63+
callbacks.initialize = nullptr;
64+
callbacks.clone = [](void *) -> void * {
65+
throw std::runtime_error("Cloning Python passes not supported");
66+
};
67+
callbacks.run = [](MlirOperation op, MlirExternalPass pass, void *userData) {
68+
nb::handle(static_cast<PyObject *>(userData))(op, pass);
69+
};
70+
return callbacks;
71+
}
72+
5573
} // namespace
5674

5775
/// Create the `mlir.passmanager` here.
@@ -63,6 +81,33 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
6381
.def("signal_pass_failure",
6482
[](MlirExternalPass pass) { mlirExternalPassSignalFailure(pass); });
6583

84+
//----------------------------------------------------------------------------
85+
// Mapping of register_pass
86+
//----------------------------------------------------------------------------
87+
m.def(
88+
"register_pass",
89+
[](const std::string &argument, const nb::callable &run,
90+
std::optional<std::string> &name, const std::string &description,
91+
const std::string &opName) {
92+
if (!name.has_value()) {
93+
name =
94+
nb::cast<std::string>(nb::borrow<nb::str>(run.attr("__name__")));
95+
}
96+
MlirTypeIDAllocator typeIDAllocator = mlirTypeIDAllocatorCreate();
97+
MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
98+
auto callbacks = createExternalPassCallbacksForPythonCallable();
99+
mlirRegisterExternalPass(
100+
passID, mlirStringRefCreate(name->data(), name->length()),
101+
mlirStringRefCreate(argument.data(), argument.length()),
102+
mlirStringRefCreate(description.data(), description.length()),
103+
mlirStringRefCreate(opName.data(), opName.size()),
104+
/*nDependentDialects*/ 0, /*dependentDialects*/ nullptr, callbacks,
105+
/*userData*/ run.ptr());
106+
},
107+
"argument"_a, "run"_a, "name"_a.none() = nb::none(),
108+
"description"_a.none() = "", "op_name"_a.none() = "",
109+
"Register a python-defined pass.");
110+
66111
//----------------------------------------------------------------------------
67112
// Mapping of the top-level PassManager
68113
//----------------------------------------------------------------------------
@@ -178,21 +223,7 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
178223
MlirTypeIDAllocator typeIDAllocator = mlirTypeIDAllocatorCreate();
179224
MlirTypeID passID =
180225
mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
181-
MlirExternalPassCallbacks callbacks;
182-
callbacks.construct = [](void *obj) {
183-
(void)nb::handle(static_cast<PyObject *>(obj)).inc_ref();
184-
};
185-
callbacks.destruct = [](void *obj) {
186-
(void)nb::handle(static_cast<PyObject *>(obj)).dec_ref();
187-
};
188-
callbacks.initialize = nullptr;
189-
callbacks.clone = [](void *) -> void * {
190-
throw std::runtime_error("Cloning Python passes not supported");
191-
};
192-
callbacks.run = [](MlirOperation op, MlirExternalPass pass,
193-
void *userData) {
194-
nb::handle(static_cast<PyObject *>(userData))(op, pass);
195-
};
226+
auto callbacks = createExternalPassCallbacksForPythonCallable();
196227
auto externalPass = mlirCreateExternalPass(
197228
passID, mlirStringRefCreate(name->data(), name->length()),
198229
mlirStringRefCreate(argument.data(), argument.length()),

mlir/lib/CAPI/IR/Pass.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,32 @@ MlirPass mlirCreateExternalPass(MlirTypeID passID, MlirStringRef name,
216216
userData)));
217217
}
218218

219+
void mlirRegisterExternalPass(MlirTypeID passID, MlirStringRef name,
220+
MlirStringRef argument, MlirStringRef description,
221+
MlirStringRef opName, intptr_t nDependentDialects,
222+
MlirDialectHandle *dependentDialects,
223+
MlirExternalPassCallbacks callbacks,
224+
void *userData) {
225+
// here we clone these arguments as owned and pass them to
226+
// the lambda as copies to avoid dangling refs,
227+
// since the lambda below lives longer than the current function
228+
std::string nameStr = unwrap(name).str();
229+
std::string argumentStr = unwrap(argument).str();
230+
std::string descriptionStr = unwrap(description).str();
231+
std::string opNameStr = unwrap(opName).str();
232+
std::vector<MlirDialectHandle> dependentDialectVec(
233+
dependentDialects, dependentDialects + nDependentDialects);
234+
235+
mlir::registerPass([passID, nameStr, argumentStr, descriptionStr, opNameStr,
236+
dependentDialectVec, callbacks, userData] {
237+
return std::unique_ptr<mlir::Pass>(new mlir::ExternalPass(
238+
unwrap(passID), nameStr, argumentStr, descriptionStr,
239+
opNameStr.length() > 0 ? std::optional<StringRef>(opNameStr)
240+
: std::nullopt,
241+
dependentDialectVec, callbacks, userData));
242+
});
243+
}
244+
219245
void mlirExternalPassSignalFailure(MlirExternalPass pass) {
220246
unwrap(pass)->signalPassFailure();
221247
}

mlir/test/python/python_pass.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def __call__(self, op, pass_):
8989

9090
# test signal_pass_failure
9191
def custom_pass_that_fails(op, pass_):
92-
print("hello from pass that fails")
92+
print("hello from pass that fails", file=sys.stderr)
9393
pass_.signal_pass_failure()
9494

9595
pm = PassManager("any")
@@ -99,4 +99,44 @@ def custom_pass_that_fails(op, pass_):
9999
try:
100100
pm.run(module)
101101
except Exception as e:
102-
print(f"caught exception: {e}")
102+
print(f"caught exception: {e}", file=sys.stderr)
103+
104+
105+
# CHECK-LABEL: TEST: testRegisterPass
106+
@run
107+
def testRegisterPass():
108+
with Context():
109+
pdl_module = make_pdl_module()
110+
frozen = PDLModule(pdl_module).freeze()
111+
112+
module = ModuleOp.parse(
113+
r"""
114+
module {
115+
func.func @add(%a: i64, %b: i64) -> i64 {
116+
%sum = arith.addi %a, %b : i64
117+
return %sum : i64
118+
}
119+
}
120+
"""
121+
)
122+
123+
def custom_pass_3(op, pass_):
124+
print("hello from pass 3!!!", file=sys.stderr)
125+
126+
def custom_pass_4(op, pass_):
127+
apply_patterns_and_fold_greedily(op, frozen)
128+
129+
register_pass("custom-pass-one", custom_pass_3)
130+
register_pass("custom-pass-two", custom_pass_4)
131+
132+
pm = PassManager("any")
133+
pm.enable_ir_printing()
134+
135+
# CHECK: hello from pass 3!!!
136+
# CHECK-LABEL: Dump After custom_pass_3
137+
# CHECK-LABEL: Dump After custom_pass_4
138+
# CHECK: arith.muli
139+
# CHECK-LABEL: Dump After ArithToLLVMConversionPass
140+
# CHECK: llvm.mul
141+
pm.add("custom-pass-one, custom-pass-two, convert-arith-to-llvm")
142+
pm.run(module)

0 commit comments

Comments
 (0)