Skip to content
This repository was archived by the owner on Oct 11, 2025. It is now read-only.

Commit ab0b696

Browse files
authored
[MLIR][Python] Add the ability to signal pass failures in python-defined passes (#157613)
This is a follow-up PR for #156000. In this PR we add the ability to signal pass failures (`signal_pass_failure()`) in python-defined passes. To achieve this, we expose `MlirExternalPass` via `nb::class_` with a method `signal_pass_failure()`, and the callable passed to `pm.add(..)` now accepts two arguments (`op: MlirOperation, pass_: MlirExternalPass`). For example: ```python def custom_pass_that_fails(op, pass_): if some_condition: pass_.signal_pass_failure() # do something ```
1 parent 93f4358 commit ab0b696

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

mlir/lib/Bindings/Python/Pass.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,13 @@ class PyPassManager {
5656

5757
/// Create the `mlir.passmanager` here.
5858
void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
59+
//----------------------------------------------------------------------------
60+
// Mapping of MlirExternalPass
61+
//----------------------------------------------------------------------------
62+
nb::class_<MlirExternalPass>(m, "ExternalPass")
63+
.def("signal_pass_failure",
64+
[](MlirExternalPass pass) { mlirExternalPassSignalFailure(pass); });
65+
5966
//----------------------------------------------------------------------------
6067
// Mapping of the top-level PassManager
6168
//----------------------------------------------------------------------------
@@ -182,9 +189,9 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
182189
callbacks.clone = [](void *) -> void * {
183190
throw std::runtime_error("Cloning Python passes not supported");
184191
};
185-
callbacks.run = [](MlirOperation op, MlirExternalPass,
192+
callbacks.run = [](MlirOperation op, MlirExternalPass pass,
186193
void *userData) {
187-
nb::borrow<nb::callable>(static_cast<PyObject *>(userData))(op);
194+
nb::handle(static_cast<PyObject *>(userData))(op, pass);
188195
};
189196
auto externalPass = mlirCreateExternalPass(
190197
passID, mlirStringRefCreate(name->data(), name->length()),

0 commit comments

Comments
 (0)