Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions mlir/include/mlir-c/Rewrite.h
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,10 @@ mlirFreezeRewritePattern(MlirRewritePatternSet op);
MLIR_CAPI_EXPORTED void
mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op);

/// Greedily applies the frozen pattern set `patterns` to the operation `op`
/// and returns the result as an MlirLogicalResult. This is the
/// operation-rooted counterpart of mlirApplyPatternsAndFoldGreedily, which
/// takes an MlirModule instead.
/// NOTE(review): the MlirGreedyRewriteDriverConfig parameter is unnamed and
/// does not appear to be consulted by the implementation added in this
/// change -- confirm before relying on a custom driver configuration.
MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedilyWithOp(
MlirOperation op, MlirFrozenRewritePatternSet patterns,
MlirGreedyRewriteDriverConfig);

MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
MlirModule op, MlirFrozenRewritePatternSet patterns,
MlirGreedyRewriteDriverConfig);
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Bindings/Python/MainModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,4 +139,5 @@ NB_MODULE(_mlir, m) {
auto passModule =
m.def_submodule("passmanager", "MLIR Pass Management Bindings");
populatePassManagerSubmodule(passModule);
populatePassSubmodule(passModule);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that `pass` is a keyword in Python, so a name like `mlir.pass.Pass` doesn't work. 🤔

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah yes, pass vs pass :-) I don't have a good suggestion, given how established that name is on both the MLIR/compilers side and the Python side ... MlirPass or PyPass or pydefpass would be the most obvious.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently the full name of the base class is mlir.passmanager.Pass. Is that good enough, or should we move it under a different module name?

Copy link
Contributor

@rolfmorel rolfmorel Sep 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given that PEP8 recommends appending an underscore when the names of arguments / attributes clash with reserved keywords, I feel mlir.pass_.Pass is an option, as it is not too surprising to me personally.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think putting a trailing underscore in a namespace is a little awkward. How about mlir.passes? or mlir.passinfra? or I dunno something like that that I'm failing to come up with right now...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it. For now, I will rename it to mlir.passes.Pass. Let me know if anyone thinks of a better name : )

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in 01e68c5.

}
103 changes: 103 additions & 0 deletions mlir/lib/Bindings/Python/Pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
#include "IRModule.h"
#include "mlir-c/Pass.h"
#include "mlir/Bindings/Python/Nanobind.h"

#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
#include "nanobind/trampoline.h"
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The empty line is there to work around clang-format: otherwise it will report that "mlir-c/Bindings/Python/Interop.h" should be put before "mlir/Bindings/Python/Nanobind.h" : )

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you should do

// clang-format off
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir-c/Bindings/Python/Interop.h" // ON WINDOWS This is expected after nanobind.
// clang-format on

(the ON WINDOWS isn't currently there but it should be because I discovered recently that it's in fact only windows that this matters for).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahhh thank you for your suggestion. It is done in c8c2fae.


namespace nb = nanobind;
using namespace nb::literals;
Expand All @@ -20,6 +22,81 @@ using namespace mlir::python;

namespace {

// A base class for defining passes in Python.
// Users are expected to subclass this and implement the `run` method, e.g.
// ```
// class MyPass(mlir.passmanager.Pass):
//     def __init__(self):
//         super().__init__("MyPass", ..)
//         # other init stuff..
//
//     def run(self, operation):
//         # do something with operation..
//         pass
// ```
class PyPassBase {
public:
PyPassBase(std::string name, std::string argument, std::string description,
std::string opName)
: name(std::move(name)), argument(std::move(argument)),
description(std::move(description)), opName(std::move(opName)) {
callbacks.construct = [](void *obj) {};
callbacks.destruct = [](void *obj) {
nb::handle(static_cast<PyObject *>(obj)).dec_ref();
};
callbacks.run = [](MlirOperation op, MlirExternalPass, void *obj) {
auto handle = nb::handle(static_cast<PyObject *>(obj));
nb::cast<PyPassBase *>(handle)->run(op);
};
callbacks.clone = [](void *obj) -> void * {
nb::object copy = nb::module_::import_("copy");
nb::object deepcopy = copy.attr("deepcopy");
return deepcopy(obj).release().ptr();
};
callbacks.initialize = nullptr;
Copy link
Member Author

@PragmaTwice PragmaTwice Sep 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

initialize is not ported to python side yet.

}

// this method should be overridden by subclasses in Python.
virtual void run(MlirOperation op) = 0;

virtual ~PyPassBase() = default;

// Make an MlirPass instance on-the-fly that wraps this object.
// Note that passmanager will take the ownership of the returned
// object and release it when appropriate.
MlirPass make() {
Copy link
Member Author

@PragmaTwice PragmaTwice Sep 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason that in this design I add a new class PyPassBase and lazily construct the MlirPass while adding into the pass manager instead of just wrapping MlirPass as a python object is that: the lifetime of MlirPass is a bit tricky.

If we don't add it to the pass manager, the ownership of the MlirPass stays in our hands (and we are responsible for releasing it eventually); after it is added to a pass manager, the ownership is transferred to the manager, and delete myPassPtr will be called when the manager is destructed on the C++ side. So the ownership state of the MlirPass changes once it is added to a pass manager.

By lazily constructing the MlirPass only when adding it to a pass manager, we avoid exposing the MlirPass object to the Python world, so we don't need to care about how to interoperate between the C++ delete passPtr and the Python object ref-count. We can also ensure that the lifetime of the MlirPass is always handled by the pass manager.

This approach might seem a bit roundabout, but it works. I’m also open to trying other methods if there are any : ) (Maybe I can check how the Region/Block binding APIs work, since the situation is similar? Not sure yet.) WDYT?

auto *obj = nb::find(this).release().ptr();
Copy link
Member Author

@PragmaTwice PragmaTwice Sep 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

About the lifetime of the Python object: here we increase the ref count via nb::find (and by calling release() we avoid decreasing the count here), and when the ExternalPass is destructed, callbacks.destruct is called so that dec_ref() is called for this object.

return mlirCreateExternalPass(
mlirTypeIDCreate(this), mlirStringRefCreate(name.data(), name.length()),
mlirStringRefCreate(argument.data(), argument.length()),
mlirStringRefCreate(description.data(), description.length()),
mlirStringRefCreate(opName.data(), opName.size()), 0, nullptr,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dependent dialects are not yet supported.

Copy link
Contributor

@rolfmorel rolfmorel Sep 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is relevant for the discussion on whether a function as a pass suffices. I think dependent dialects demonstrate that, at least when it comes to registering them, passes are not just functions. They also have a bit of metadata.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, as you can see we pass 0, nullptr here, and it can be supported if we change them to nDialects, dialectListPtr. So I don't think it is a proof of a problem in this design. (it is supported by ExternalPass)

A function in python is just an object with __call__ method implemented. So I think we can bind any information with this object by assigning them to attrs of this object, if needed. 🤔

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@PragmaTwice and I already discussed this offline: to support dependent dialects we need to add nanobindings for MlirDialectHandle. Easy to do since it's just an opaque handle. If it's a high-priority I can quickly do that as a follow-up.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With Python we can monkey patch arbitrary things arbitrarily. Doesn't mean that's a good idea / good to base your design on. As such, binding new attributes to an object is not the way to go.

If you and @makslevental still prefer just passing a function, then this metadata could (/has to) be passed as arguments to the API that registers the callback as a pass. Alternatively, have both a Pass class to inherit with a dependent_dialects property (and potentially with __call__ implemented as a call to run()) which the registration API automatically uses and have a mechanism for wrapping up a callback. The wrapping-up mechanism could also be a factory method on Pass though maybe that doesn't make things simpler when it comes to lifetime management.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

then this metadata could (/has to) be passed as arguments to the API that registers the callback as a pass

Yes that's exactly the follow-up I had in mind.

The wrapping-up mechanism could also be a factory method on Pass though maybe that doesn't make things simpler when it comes to lifetime management.

It's the lifetime management that is the issue - the C++ APIs expect ownership of the Pass object. But there's simply no way to express "unique ownership" in Python. That's why I rewrote @PragmaTwice's original PR (which isn't very different from what you propose) to only manage the lifetime of a single Python object - the run callback.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay maybe we can use nanobind itself to get the right semantics - let me try again using https://nanobind.readthedocs.io/en/latest/ownership.html#unique-pointers.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nah this won't blend - you have to create the object in C++ in order for the unique_ptr magic to work. So that won't blend with "subclass and instantiate a nanobind class in Python".

callbacks, obj);
}

const std::string &getName() const { return name; }
const std::string &getArgument() const { return argument; }
const std::string &getDescription() const { return description; }
const std::string &getOpName() const { return opName; }

private:
MlirExternalPassCallbacks callbacks;

std::string name;
std::string argument;
std::string description;
std::string opName;
};

// Trampoline class for PyPassBase: it lets a Python subclass provide the
// implementation of the C++ virtual method `run`.
// Refer to
// https://nanobind.readthedocs.io/en/latest/classes.html#overriding-virtual-functions-in-python
//
// The base is inherited publicly (the `class` keyword would otherwise default
// to private inheritance) so that a PyPass is usable wherever a PyPassBase is
// expected, matching the trampoline pattern shown in the nanobind
// documentation.
class PyPass : public PyPassBase {
public:
  // One trampoline slot is reserved: `run` is the only virtual method that is
  // forwarded to Python.
  NB_TRAMPOLINE(PyPassBase, 1);

  // Forwards to the Python-side `run` override. `run` is pure virtual in
  // PyPassBase, hence the *_PURE variant of the override macro.
  void run(MlirOperation op) override { NB_OVERRIDE_PURE(run, op); }
};

/// Owning Wrapper around a PassManager.
class PyPassManager {
public:
Expand Down Expand Up @@ -52,6 +129,26 @@ class PyPassManager {

} // namespace

void mlir::python::populatePassSubmodule(nanobind::module_ &m) {
//----------------------------------------------------------------------------
// Mapping of the Python-defined Pass interface
//----------------------------------------------------------------------------
nb::class_<PyPassBase, PyPass>(m, "Pass")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isn't this backwards? cf <PyOperation, PyOperationBase>

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahh this is a quite interesting case that I just did some research on : )

Refer to the nanobind documentation (https://nanobind.readthedocs.io/en/latest/api_core.html#_CPPv4I0DpEN8nanobind6class_E):

template<typename T, typename ...Ts>
class class_ : public object

The variable length parameter Ts is optional and can be used to specify the base class of T and/or an alias needed to realize trampoline classes.

So in the case you mentioned (<PyOperation, PyOperationBase>), the parameter Ts (PyOperationBase) is a base class of T (PyOperation); and in the case here (<PyPassBase, PyPass>), Ts (PyPass) is a trampoline class of T (PyPassBase). So I think both of them are correct here.

For example, in the documentation of trampoline classes (https://nanobind.readthedocs.io/en/latest/classes.html#overriding-virtual-functions-in-python), we can see such an instance of nb::class_:

nb::class_<Dog, PyDog /* <-- trampoline */>(m, "Dog")

And here PyDog is a derived class of Dog but also a trampoline class of Dog : )

.def(nb::init<std::string, std::string, std::string, std::string>(),
"name"_a, nb::kw_only(), "argument"_a = "", "description"_a = "",
"op_name"_a = "", "Create a new Pass.")
.def("run", &PyPassBase::run, "operation"_a,
"Run the pass on the provided operation.")
.def_prop_ro("name",
[](const PyPassBase &self) { return self.getName(); })
.def_prop_ro("argument",
[](const PyPassBase &self) { return self.getArgument(); })
.def_prop_ro("description",
[](const PyPassBase &self) { return self.getDescription(); })
.def_prop_ro("op_name",
[](const PyPassBase &self) { return self.getOpName(); });
}

/// Create the `mlir.passmanager` here.
void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
//----------------------------------------------------------------------------
Expand Down Expand Up @@ -157,6 +254,12 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
"pipeline"_a,
"Add textual pipeline elements to the pass manager. Throws a "
"ValueError if the pipeline can't be parsed.")
.def(
"add",
[](PyPassManager &passManager, PyPassBase &pass) {
mlirPassManagerAddOwnedPass(passManager.get(), pass.make());
},
"pass"_a, "Add a python-defined pass to the pass manager.")
.def(
"run",
[](PyPassManager &passManager, PyOperationBase &op,
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Bindings/Python/Pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ namespace mlir {
namespace python {

void populatePassManagerSubmodule(nanobind::module_ &m);
void populatePassSubmodule(nanobind::module_ &m);

} // namespace python
} // namespace mlir
Expand Down
31 changes: 21 additions & 10 deletions mlir/lib/Bindings/Python/Rewrite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,25 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR,
&PyFrozenRewritePatternSet::createFromCapsule);
m.def(
"apply_patterns_and_fold_greedily",
[](MlirModule module, MlirFrozenRewritePatternSet set) {
auto status = mlirApplyPatternsAndFoldGreedily(module, set, {});
if (mlirLogicalResultIsFailure(status))
// FIXME: Not sure this is the right error to throw here.
throw nb::value_error("pattern application failed to converge");
},
"module"_a, "set"_a,
"Applys the given patterns to the given module greedily while folding "
"results.");
"apply_patterns_and_fold_greedily",
[](MlirModule module, MlirFrozenRewritePatternSet set) {
auto status = mlirApplyPatternsAndFoldGreedily(module, set, {});
if (mlirLogicalResultIsFailure(status))
// FIXME: Not sure this is the right error to throw here.
throw nb::value_error("pattern application failed to converge");
},
"module"_a, "set"_a,
"Applys the given patterns to the given module greedily while folding "
"results.")
.def(
"apply_patterns_and_fold_greedily_with_op",
[](MlirOperation op, MlirFrozenRewritePatternSet set) {
auto status = mlirApplyPatternsAndFoldGreedilyWithOp(op, set, {});
if (mlirLogicalResultIsFailure(status))
// FIXME: Not sure this is the right error to throw here.
throw nb::value_error("pattern application failed to converge");
},
"op"_a, "set"_a,
"Applys the given patterns to the given op greedily while folding "
"results.");
}
7 changes: 7 additions & 0 deletions mlir/lib/CAPI/Transforms/Rewrite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,13 @@ mlirApplyPatternsAndFoldGreedily(MlirModule op,
return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
}

/// Runs the greedy pattern driver on the operation wrapped by `op`, using the
/// frozen pattern set wrapped by `patterns`, and returns the wrapped result.
/// The driver-config argument is accepted for signature compatibility but is
/// not read here.
MlirLogicalResult
mlirApplyPatternsAndFoldGreedilyWithOp(MlirOperation op,
                                       MlirFrozenRewritePatternSet patterns,
                                       MlirGreedyRewriteDriverConfig) {
  // Unwrap the C handles into their C++ counterparts first.
  auto *rootOp = unwrap(op);
  auto &frozenPatterns = *unwrap(patterns);

  auto status = mlir::applyPatternsGreedily(rootOp, frozenPatterns);
  return wrap(status);
}

//===----------------------------------------------------------------------===//
/// PDLPatternModule API
//===----------------------------------------------------------------------===//
Expand Down
83 changes: 83 additions & 0 deletions mlir/test/python/pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# RUN: %PYTHON %s 2>&1 | FileCheck %s

import gc, sys
from mlir.ir import *
from mlir.passmanager import *
from mlir.dialects.builtin import ModuleOp
from mlir.dialects import pdl
from mlir.rewrite import *


def log(*args):
    """Print ``args`` to stderr and flush immediately.

    Flushing after every call keeps this output ordered relative to anything
    else written to stderr while the test runs.
    """
    print(*args, file=sys.stderr, flush=True)


def run(f):
    """Decorator-style test runner: execute ``f`` immediately.

    Logs a ``TEST: <name>`` banner, calls ``f``, forces a garbage collection
    and then asserts that no ``Context`` is still alive. Nothing is returned,
    so a name decorated with ``@run`` is rebound to ``None``.
    """
    log("\nTEST:", f.__name__)
    f()
    # Collect before checking so that unreachable contexts are not counted.
    gc.collect()
    assert Context._get_live_count() == 0


def make_pdl_module():
    """Build a module containing one PDL pattern, ``addi_to_mul``.

    The pattern matches ``arith.addi`` with two i64 operands and an i64 result
    and replaces it with ``arith.muli`` on the same operands. Returns the
    created module.
    """
    with Location.unknown():
        pdl_module = Module.create()
        with InsertionPoint(pdl_module.body):
            # Change all arith.addi with i64 types to arith.muli.
            @pdl.pattern(benefit=1, sym_name="addi_to_mul")
            def pat():
                # Match arith.addi whose operands and result are i64.
                i64_type = pdl.TypeOp(IntegerType.get_signless(64))
                operand0 = pdl.OperandOp(i64_type)
                operand1 = pdl.OperandOp(i64_type)
                op0 = pdl.OperationOp(
                    name="arith.addi", args=[operand0, operand1], types=[i64_type]
                )

                # Replace the matched op with arith.muli.
                @pdl.rewrite()
                def rew():
                    newOp = pdl.OperationOp(
                        name="arith.muli", args=[operand0, operand1], types=[i64_type]
                    )
                    pdl.ReplaceOp(op0, with_op=newOp)

    return pdl_module


# CHECK-LABEL: TEST: testCustomPass
@run
def testCustomPass():
    """Add a Python-defined pass to a PassManager and run it before a built-in pass.

    The custom pass applies the PDL pattern from ``make_pdl_module``; the IR
    dumps enabled below are what the FileCheck directives in this function
    match against.
    """
    with Context():
        pdl_module = make_pdl_module()

        class CustomPass(Pass):
            def __init__(self):
                super().__init__("CustomPass", op_name="builtin.module")

            def run(self, m):
                # Freeze the PDL patterns and apply them greedily to the
                # operation handed to the pass.
                frozen = PDLModule(pdl_module).freeze()
                apply_patterns_and_fold_greedily_with_op(m, frozen)

        module = ModuleOp.parse(
            r"""
module {
func.func @add(%a: i64, %b: i64) -> i64 {
%sum = arith.addi %a, %b : i64
return %sum : i64
}
}
"""
        )

        pm = PassManager("any")
        pm.enable_ir_printing()

        # CHECK-LABEL: Dump After CustomPass
        # CHECK: arith.muli
        pm.add(CustomPass())
        # CHECK-LABEL: Dump After ArithToLLVMConversionPass
        # CHECK: llvm.mul
        pm.add("convert-arith-to-llvm")
        pm.run(module)