Merged
Changes from 8 commits
4 changes: 4 additions & 0 deletions mlir/include/mlir-c/Rewrite.h
@@ -301,6 +301,10 @@ mlirFreezeRewritePattern(MlirRewritePatternSet op);
MLIR_CAPI_EXPORTED void
mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op);

MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedilyForOp(
MlirOperation op, MlirFrozenRewritePatternSet patterns,
MlirGreedyRewriteDriverConfig);
Member Author

Not sure which name is suitable, so the current name is just a placeholder. Modifying mlirApplyPatternsAndFoldGreedily seems like a breaking change.

Member

It is a breaking change, but I'm not sure how widely it would be felt, and the C API is really only best-effort stable (meaning we try not to break it). I probably should have included a Module suffix on the original, and then this one could have gone without a suffix. How many usages of this call can you find in the wild?

(Nit: ForOp suffix makes me think it's related to ForOp).

Member Author

Hmm, I ran a quick check on public repos via GitHub search (https://github.com/search?q=mlirApplyPatternsAndFoldGreedily+&type=code), and there seem to be some uses of this API, so I'm not sure whether we should change it (maybe leave this decision to the maintainers : ).

For now I'm going to rename it to mlirApplyPatternsAndFoldGreedilyWithOp to avoid the confusion with ForOp that you mentioned : )

Collaborator

Can you please extract this change in an independent PR? This is largely unrelated (your test python pass could just do another kind of rewrite right now).

Member Author

Renamed in ca80408.

Member Author (@PragmaTwice, Sep 1, 2025)

> Can you please extract this change in an independent PR? This is largely unrelated (your test python pass could just do another kind of rewrite right now).

Ahh, got it.

I think I can do it via the op/block/region/walk/erase APIs directly, but it may be a little ugly, since the normal rewriter API is not ported to Python yet, AFAIK.

Contributor

Yep - no rewriter API in Python but you can use relevant methods on ops directly, etc. (Here's one example: https://github.com/libxsmm/tpp-mlir/pull/1064/files#diff-7aa62724b21b998da9cf032da6e6b77bbdb664258a5dca850f9509a5459646f6R110-R118 - @makslevental has more.)

That should be enough to demonstrate that your test pass is actually running and, if kept simple, will not be that ugly.

Contributor

We can just use apply_patterns_and_fold_greedily here and then add apply_patterns_and_fold_greedily_with_op in a follow-up (the existing test passes just fine with apply_patterns_and_fold_greedily).


MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
MlirModule op, MlirFrozenRewritePatternSet patterns,
MlirGreedyRewriteDriverConfig);
1 change: 1 addition & 0 deletions mlir/lib/Bindings/Python/MainModule.cpp
@@ -139,4 +139,5 @@ NB_MODULE(_mlir, m) {
auto passModule =
m.def_submodule("passmanager", "MLIR Pass Management Bindings");
populatePassManagerSubmodule(passModule);
populatePassSubmodule(passModule);
Member Author

It seems that pass is a keyword in Python, so a name like mlir.pass.Pass doesn't work. 🤔

Member

Ah yes, pass vs pass :-) I don't have a good suggestion, given how established the term is on both the MLIR/compilers side and the Python side ... MlirPass or PyPass or pydefpass would be the most obvious.

Member Author

Currently the full name of the base class is mlir.passmanager.Pass. Is that good enough, or should we move it under another module name?

Contributor (@rolfmorel, Sep 1, 2025)

Given that PEP8 recommends appending an underscore when the names of arguments / attributes clash with reserved keywords, I feel mlir.pass_.Pass is an option, as it is not too surprising to me personally.

Contributor

I think putting a trailing underscore in a namespace is a little awkward. How about mlir.passes? or mlir.passinfra? or I dunno something like that that I'm failing to come up with right now...

Member Author

Got it. For now, I will rename it to mlir.passes.Pass. Let me know if anyone thinks of a better name : )
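
The constraint behind this rename can be checked with the standard library alone. The sketch below only parses the candidate import statements (nothing is actually imported, so MLIR does not need to be installed). The helper name `parses` is made up for illustration.

```python
import keyword


def parses(source: str) -> bool:
    """Return True if `source` is syntactically valid Python, False otherwise."""
    try:
        # compile() only parses and compiles; it does not execute the import.
        compile(source, "<example>", "exec")
        return True
    except SyntaxError:
        return False


# `pass` is a reserved word, so it cannot be a component of a dotted module path.
print(keyword.iskeyword("pass"))                 # True
print(parses("from mlir.pass import Pass"))      # False

# The alternatives raised in this thread are ordinary identifiers.
print(parses("from mlir.pass_ import Pass"))     # True
print(parses("from mlir.passes import Pass"))    # True
```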

Member Author

Done in 01e68c5.

}
101 changes: 101 additions & 0 deletions mlir/lib/Bindings/Python/Pass.cpp
@@ -11,7 +11,9 @@
#include "IRModule.h"
#include "mlir-c/Pass.h"
#include "mlir/Bindings/Python/Nanobind.h"

#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
#include "nanobind/trampoline.h"
Member Author

The empty line works around clang-format; otherwise it reports that "mlir-c/Bindings/Python/Interop.h" should be placed before "mlir/Bindings/Python/Nanobind.h" : )

Contributor

you should do

// clang-format off
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir-c/Bindings/Python/Interop.h" // ON WINDOWS This is expected after nanobind.
// clang-format on

(the ON WINDOWS isn't currently there but it should be because I discovered recently that it's in fact only windows that this matters for).

Member Author

Ahhh thank you for your suggestion. It is done in c8c2fae.


namespace nb = nanobind;
using namespace nb::literals;
@@ -20,6 +22,79 @@ using namespace mlir::python;

namespace {

// A base class for defining passes in Python
// Users are expected to subclass this and implement the `run` method, e.g.
// ```
// class MyPass(mlir.passmanager.Pass):
// def run(self, operation):
// # do something with operation
// pass
// ```
class PyPassBase {
public:
PyPassBase(std::string name, std::string argument, std::string description,
std::string opName)
: name(std::move(name)), argument(std::move(argument)),
description(std::move(description)), opName(std::move(opName)) {
callbacks.construct = [](void *obj) {};
callbacks.destruct = [](void *obj) {
nb::handle(static_cast<PyObject *>(obj)).dec_ref();
};
callbacks.run = [](MlirOperation op, MlirExternalPass, void *obj) {
auto handle = nb::handle(static_cast<PyObject *>(obj));
nb::cast<PyPassBase *>(handle)->run(op);
};
callbacks.clone = [](void *obj) -> void * {
nb::object copy = nb::module_::import_("copy");
nb::object deepcopy = copy.attr("deepcopy");
return deepcopy(obj).release().ptr();
};
callbacks.initialize = nullptr;
Member Author (@PragmaTwice, Sep 1, 2025)

initialize is not ported to the Python side yet.

}

// this method should be overridden by subclasses in Python.
virtual void run(MlirOperation op) = 0;

virtual ~PyPassBase() = default;

// Make an MlirPass instance on-the-fly that wraps this object.
// Note that passmanager will take the ownership of the returned
// object and release it when appropriate.
// Also, `*this` must remain alive as long as the returned object is alive.
MlirPass make() {
auto *obj = nb::find(this).release().ptr();
Member Author (@PragmaTwice, Sep 1, 2025)

About the lifetime of the Python object: nb::find increases the ref count here (and release() avoids decreasing it again when the temporary goes away), and when the ExternalPass is destructed, callbacks.destruct is invoked, which calls dec_ref() on this object.
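
The ownership hand-off described in this comment can be illustrated with a pure-Python analogue. This is only a sketch and not the binding code: `ExternalPassSlot` is a hypothetical stand-in for the C++ external pass, and the observed behaviour relies on CPython's reference counting.

```python
import gc
import weakref


class ExternalPassSlot:
    """Hypothetical stand-in for the C++ ExternalPass that owns the Python object."""

    def __init__(self, py_pass):
        # Analogue of `nb::find(this).release()`: take a strong reference and keep it.
        self._owned = py_pass

    def destruct(self):
        # Analogue of `callbacks.destruct`: drop the reference (dec_ref).
        self._owned = None


class MyPass:
    def run(self, operation):
        return operation


py_pass = MyPass()
probe = weakref.ref(py_pass)  # observes the object without owning it
slot = ExternalPassSlot(py_pass)

del py_pass  # the caller's own reference goes away
gc.collect()
held_while_owned = probe() is not None

slot.destruct()  # the pass manager destroys the external pass
gc.collect()
freed_after_destruct = probe() is None

print(held_while_owned, freed_after_destruct)
```

Under CPython both flags should come out true: the object survives the caller dropping its name because the slot still holds a reference, and it is reclaimed once the slot releases it.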

return mlirCreateExternalPass(
mlirTypeIDCreate(this), mlirStringRefCreate(name.data(), name.length()),
mlirStringRefCreate(argument.data(), argument.length()),
mlirStringRefCreate(description.data(), description.length()),
mlirStringRefCreate(opName.data(), opName.size()), 0, nullptr,
Member Author

dependent dialects are not yet supported.

Contributor (@rolfmorel, Sep 8, 2025)

This is relevant for the discussion on whether a function as a pass suffices. I think dependent dialects demonstrate that, at least when it comes to registering them, passes are not just functions. They also have a bit of metadata.

Member Author

Hmm, as you can see, we pass 0, nullptr here, and dependent dialects can be supported by changing them to nDialects, dialectListPtr. So I don't think this demonstrates a problem with the design (ExternalPass already supports it).

A function in Python is just an object that implements the __call__ method. So I think we can bind any information to this object by assigning it to attributes of the object, if needed. 🤔
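
What this comment describes can be sketched in plain Python. The attribute name `dependent_dialects` is used here only as an illustration of attaching metadata; it is not an existing MLIR Python API.

```python
def my_pass(operation):
    """A pass written as a plain function; the body is a placeholder."""
    return operation


# Functions are ordinary objects, so arbitrary attributes can be attached to them.
# `dependent_dialects` is a hypothetical attribute name chosen for this example.
my_pass.dependent_dialects = ["arith", "func"]


class CallablePass:
    """An equivalent callable object that declares the same metadata up front."""

    dependent_dialects = ["arith", "func"]

    def __call__(self, operation):
        return operation


# Both kinds of object are callable and expose the metadata in the same way.
for candidate in (my_pass, CallablePass()):
    print(callable(candidate), candidate.dependent_dialects)
```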

Contributor

@PragmaTwice and I already discussed this offline: to support dependent dialects we need to add nanobind bindings for MlirDialectHandle. That is easy to do, since it's just an opaque handle. If it's high priority I can quickly do that as a follow-up.

Contributor

With Python we can monkey patch arbitrary things arbitrarily. Doesn't mean that's a good idea / good to base your design on. As such, binding new attributes to an object is not the way to go.

If you and @makslevental still prefer just passing a function, then this metadata could (/has to) be passed as arguments to the API that registers the callback as a pass. Alternatively, have both a Pass class to inherit with a dependent_dialects property (and potentially with __call__ implemented as a call to run()) which the registration API automatically uses and have a mechanism for wrapping up a callback. The wrapping-up mechanism could also be a factory method on Pass though maybe that doesn't make things simpler when it comes to lifetime management.

Contributor

> then this metadata could (/has to) be passed as arguments to the API that registers the callback as a pass

Yes that's exactly the follow-up I had in mind.

> The wrapping-up mechanism could also be a factory method on Pass though maybe that doesn't make things simpler when it comes to lifetime management.

It's the lifetime management that is the issue - the C++ APIs expect ownership of the Pass object. But there's simply no way to express "unique ownership" in Python. That's why I rewrote @PragmaTwice's original PR (which isn't very different from what you propose) to only manage the lifetime of a single Python object - the run callback.
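
A rough sketch of the shape being discussed, where the registration call takes the metadata as arguments and keeps only the callback alive, might look like the following. Everything here (`ToyPassManager`, `PassRecord`, the keyword names) is hypothetical and written in plain Python; it is not the MLIR `PassManager` API.

```python
from dataclasses import dataclass, field
from typing import Callable, List


@dataclass
class PassRecord:
    """Metadata plus the single Python object whose lifetime is managed: `run`."""

    run: Callable[[object], None]
    name: str
    op_name: str = ""
    dependent_dialects: List[str] = field(default_factory=list)


class ToyPassManager:
    """Hypothetical stand-in for a pass manager; not the MLIR PassManager."""

    def __init__(self):
        self._records: List[PassRecord] = []

    def add(self, run, *, name=None, op_name="", dependent_dialects=()):
        # Metadata arrives as keyword arguments instead of attributes on `run`.
        record = PassRecord(run, name or run.__name__, op_name, list(dependent_dialects))
        self._records.append(record)
        return record

    def run(self, operation):
        # Invoke every registered callback in registration order.
        for record in self._records:
            record.run(operation)


visited = []


def count_ops(operation):
    visited.append(operation)


pm = ToyPassManager()
record = pm.add(count_ops, op_name="builtin.module", dependent_dialects=["arith"])
pm.run("fake-module")

print(record.name, record.dependent_dialects, visited)
```

The point of this shape is that the manager never needs to own a user-defined Python class instance; it holds the callable and a plain record of metadata.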

Contributor

Okay maybe we can use nanobind itself to get the right semantics - let me try again using https://nanobind.readthedocs.io/en/latest/ownership.html#unique-pointers.

Contributor

nah this won't blend - you have to create the object in C++ in order for the unique_ptr magic to work. So that won't blend with "subclass and instantiate a nanobind class in Python".

callbacks, obj);
}

const std::string &getName() const { return name; }
const std::string &getArgument() const { return argument; }
const std::string &getDescription() const { return description; }
const std::string &getOpName() const { return opName; }

private:
MlirExternalPassCallbacks callbacks;

std::string name;
std::string argument;
std::string description;
std::string opName;
};

// A trampoline class upon PyPassBase.
// Refer to
// https://nanobind.readthedocs.io/en/latest/classes.html#overriding-virtual-functions-in-python
class PyPass : PyPassBase {
public:
NB_TRAMPOLINE(PyPassBase, 1);

void run(MlirOperation op) override { NB_OVERRIDE_PURE(run, op); }
};

/// Owning Wrapper around a PassManager.
class PyPassManager {
public:
@@ -52,6 +127,26 @@ class PyPassManager {

} // namespace

void mlir::python::populatePassSubmodule(nanobind::module_ &m) {
//----------------------------------------------------------------------------
// Mapping of the Python-defined Pass interface
//----------------------------------------------------------------------------
nb::class_<PyPassBase, PyPass>(m, "Pass")
Contributor

isn't this backwards? cf <PyOperation, PyOperationBase>

Member Author

Ahh this is a quite interesting case that I just did some research on : )

Refer to the nanobind documentation (https://nanobind.readthedocs.io/en/latest/api_core.html#_CPPv4I0DpEN8nanobind6class_E):

template<typename T, typename ...Ts>
class class_ : public object

> The variable length parameter Ts is optional and can be used to specify the base class of T and/or an alias needed to realize trampoline classes.

So in the case you mentioned (<PyOperation, PyOperationBase>), the parameter Ts (PyOperationBase) is a base class of T (PyOperation), while in the case here (<PyPassBase, PyPass>), Ts (PyPass) is a trampoline class of T (PyPassBase). So I think both of them are correct.

For example, in the documentation of trampoline classes (https://nanobind.readthedocs.io/en/latest/classes.html#overriding-virtual-functions-in-python), we can see such an instance of nb::class_:

nb::class_<Dog, PyDog /* <-- trampoline */>(m, "Dog")

And here PyDog is a derived class of Dog but also a trampoline class of Dog : )

.def(nb::init<std::string, std::string, std::string, std::string>(),
"name"_a, nb::kw_only(), "argument"_a = "", "description"_a = "",
"op_name"_a = "", "Create a new Pass.")
.def("run", &PyPassBase::run, "operation"_a,
"Run the pass on the provided operation.")
.def_prop_ro("name",
[](const PyPassBase &self) { return self.getName(); })
.def_prop_ro("argument",
[](const PyPassBase &self) { return self.getArgument(); })
.def_prop_ro("description",
[](const PyPassBase &self) { return self.getDescription(); })
.def_prop_ro("op_name",
[](const PyPassBase &self) { return self.getOpName(); });
}

/// Create the `mlir.passmanager` here.
void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
//----------------------------------------------------------------------------
@@ -157,6 +252,12 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
"pipeline"_a,
"Add textual pipeline elements to the pass manager. Throws a "
"ValueError if the pipeline can't be parsed.")
.def(
"add",
[](PyPassManager &passManager, PyPassBase &pass) {
mlirPassManagerAddOwnedPass(passManager.get(), pass.make());
},
"pass"_a, "Add a python-defined pass to the pass manager.")
.def(
"run",
[](PyPassManager &passManager, PyOperationBase &op,
1 change: 1 addition & 0 deletions mlir/lib/Bindings/Python/Pass.h
@@ -15,6 +15,7 @@ namespace mlir {
namespace python {

void populatePassManagerSubmodule(nanobind::module_ &m);
void populatePassSubmodule(nanobind::module_ &m);

} // namespace python
} // namespace mlir
31 changes: 21 additions & 10 deletions mlir/lib/Bindings/Python/Rewrite.cpp
@@ -99,14 +99,25 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR,
&PyFrozenRewritePatternSet::createFromCapsule);
m.def(
-      "apply_patterns_and_fold_greedily",
-      [](MlirModule module, MlirFrozenRewritePatternSet set) {
-        auto status = mlirApplyPatternsAndFoldGreedily(module, set, {});
-        if (mlirLogicalResultIsFailure(status))
-          // FIXME: Not sure this is the right error to throw here.
-          throw nb::value_error("pattern application failed to converge");
-      },
-      "module"_a, "set"_a,
-      "Applys the given patterns to the given module greedily while folding "
-      "results.");
+       "apply_patterns_and_fold_greedily",
+       [](MlirModule module, MlirFrozenRewritePatternSet set) {
+         auto status = mlirApplyPatternsAndFoldGreedily(module, set, {});
+         if (mlirLogicalResultIsFailure(status))
+           // FIXME: Not sure this is the right error to throw here.
+           throw nb::value_error("pattern application failed to converge");
+       },
+       "module"_a, "set"_a,
+       "Applys the given patterns to the given module greedily while folding "
+       "results.")
+      .def(
+          "apply_patterns_and_fold_greedily_for_op",
+          [](MlirOperation op, MlirFrozenRewritePatternSet set) {
+            auto status = mlirApplyPatternsAndFoldGreedilyForOp(op, set, {});
+            if (mlirLogicalResultIsFailure(status))
+              // FIXME: Not sure this is the right error to throw here.
+              throw nb::value_error("pattern application failed to converge");
+          },
+          "op"_a, "set"_a,
+          "Applys the given patterns to the given op greedily while folding "
+          "results.");
}
7 changes: 7 additions & 0 deletions mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -294,6 +294,13 @@ mlirApplyPatternsAndFoldGreedily(MlirModule op,
return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
}

MlirLogicalResult
mlirApplyPatternsAndFoldGreedilyForOp(MlirOperation op,
MlirFrozenRewritePatternSet patterns,
MlirGreedyRewriteDriverConfig) {
return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
}

//===----------------------------------------------------------------------===//
/// PDLPatternModule API
//===----------------------------------------------------------------------===//
79 changes: 79 additions & 0 deletions mlir/test/python/pass.py
@@ -0,0 +1,79 @@
# RUN: %PYTHON %s 2>&1 | FileCheck %s

import gc, sys
from mlir.ir import *
from mlir.passmanager import *
from mlir.dialects.builtin import ModuleOp
from mlir.dialects import pdl
from mlir.rewrite import *


def log(*args):
    print(*args, file=sys.stderr)
    sys.stderr.flush()


def run(f):
    log("\nTEST:", f.__name__)
    f()
    gc.collect()
    assert Context._get_live_count() == 0


def make_pdl_module():
    with Location.unknown():
        pdl_module = Module.create()
        with InsertionPoint(pdl_module.body):
            # Change all arith.addi with index types to arith.muli.
            @pdl.pattern(benefit=1, sym_name="addi_to_mul")
            def pat():
                # Match arith.addi with index types.
                index_type = pdl.TypeOp(IndexType.get())
                operand0 = pdl.OperandOp(index_type)
                operand1 = pdl.OperandOp(index_type)
                op0 = pdl.OperationOp(
                    name="arith.addi", args=[operand0, operand1], types=[index_type]
                )

                # Replace the matched op with arith.muli.
                @pdl.rewrite()
                def rew():
                    newOp = pdl.OperationOp(
                        name="arith.muli", args=[operand0, operand1], types=[index_type]
                    )
                    pdl.ReplaceOp(op0, with_op=newOp)

        return pdl_module


# CHECK-LABEL: TEST: testCustomPass
@run
def testCustomPass():
    with Context():
        pdl_module = make_pdl_module()

        class CustomPass(Pass):
            def __init__(self):
                super().__init__("CustomPass", op_name="builtin.module")

            def run(self, m):
                frozen = PDLModule(pdl_module).freeze()
                apply_patterns_and_fold_greedily_for_op(m, frozen)

        module = ModuleOp.parse(
            r"""
            module {
              func.func @add(%a: index, %b: index) -> index {
                %sum = arith.addi %a, %b : index
                return %sum : index
              }
            }
            """
        )

        # CHECK-LABEL: Dump After CustomPass
        # CHECK: arith.muli
        pm = PassManager("any")
        pm.enable_ir_printing()
        pm.add(CustomPass())
        pm.run(module)