Skip to content

Conversation

@PragmaTwice
Copy link
Member

This PR brings a new function register_pass to MLIR python bindings, which can register a python-defined pass into MLIR (corresponding to mlir::registerPass in C++).

An example:

def demo_pass(op, pass_):
    pass # do something

register_pass("my-python-demo-pass", demo_pass)

pm = PassManager('any')
pm.add("my-python-demo-pass, some-cpp-defined-pass ...")
pm.run(..)

@PragmaTwice
Copy link
Member Author

PragmaTwice commented Sep 10, 2025

While simple test cases work well for the current register_pass (check here), there are still something we should consider:

First, currently the python callable is actually "shared" between all pipelines (in all instances of the pass registered with the callable). For example:

def my_pass(op, pass_):
    ...

register_pass("my-pass", my_pass)

pm1.add("my-pass, ..., my-pass")
...
pm2.add("my-pass")

In the example above, the object my_pass is shared between the three instances of that pass in these two pass managers. For simple passes (e.g. passes without states) it should be fine, but for complicated passes, e.g.

class AdvancedPass:
    def __init__(self):
        self.count = 0
    def __call__(self, op, pass_):
        self.count += 1

In this example the self.count will be increased to 2 and 3 rather than always be 1 in multiple instances of this pass. And this is not expected for most users and not aligned with mlir::Pass in C++ since each time we created a new ExternalPass by the factory lambda. Hence from my side, the python object need to be deepcopied before constructing the ExternalPass.

Second, the lifetime of this python object should be quite long (i.e. all time of the python program) since we can create such pass anywhere we want after we register it into the system. (and now we didn't extend its lifetime via something like inc_ref)

@PragmaTwice PragmaTwice marked this pull request as ready for review September 10, 2025 13:07
@llvmbot llvmbot added the mlir label Sep 10, 2025
@llvmbot
Copy link
Member

llvmbot commented Sep 10, 2025

@llvm/pr-subscribers-mlir

Author: Twice (PragmaTwice)

Changes

This PR brings a new function register_pass to MLIR python bindings, which can register a python-defined pass into MLIR (corresponding to mlir::registerPass in C++).

An example:

def demo_pass(op, pass_):
    pass # do something

register_pass("my-python-demo-pass", demo_pass)

pm = PassManager('any')
pm.add("my-python-demo-pass, some-cpp-defined-pass ...")
pm.run(..)

Full diff: https://github.com/llvm/llvm-project/pull/157850.diff

4 Files Affected:

  • (modified) mlir/include/mlir-c/Pass.h (+7)
  • (modified) mlir/lib/Bindings/Python/Pass.cpp (+46-15)
  • (modified) mlir/lib/CAPI/IR/Pass.cpp (+26)
  • (modified) mlir/test/python/python_pass.py (+42-2)
diff --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h
index 0d2e19ee7fb0a..1328b4de5d4eb 100644
--- a/mlir/include/mlir-c/Pass.h
+++ b/mlir/include/mlir-c/Pass.h
@@ -184,6 +184,13 @@ MLIR_CAPI_EXPORTED MlirPass mlirCreateExternalPass(
     intptr_t nDependentDialects, MlirDialectHandle *dependentDialects,
     MlirExternalPassCallbacks callbacks, void *userData);
 
+MLIR_CAPI_EXPORTED void
+mlirRegisterExternalPass(MlirTypeID passID, MlirStringRef name,
+                         MlirStringRef argument, MlirStringRef description,
+                         MlirStringRef opName, intptr_t nDependentDialects,
+                         MlirDialectHandle *dependentDialects,
+                         MlirExternalPassCallbacks callbacks, void *userData);
+
 /// This signals that the pass has failed. This is only valid to call during
 /// the `run` callback of `MlirExternalPassCallbacks`.
 /// See Pass::signalPassFailure().
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index 47ef5d8e9dd3b..558ab6a43d87b 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -52,6 +52,24 @@ class PyPassManager {
   MlirPassManager passManager;
 };
 
+MlirExternalPassCallbacks createExternalPassCallbacksForPythonCallable() {
+  MlirExternalPassCallbacks callbacks;
+  callbacks.construct = [](void *obj) {
+    (void)nb::handle(static_cast<PyObject *>(obj)).inc_ref();
+  };
+  callbacks.destruct = [](void *obj) {
+    (void)nb::handle(static_cast<PyObject *>(obj)).dec_ref();
+  };
+  callbacks.initialize = nullptr;
+  callbacks.clone = [](void *) -> void * {
+    throw std::runtime_error("Cloning Python passes not supported");
+  };
+  callbacks.run = [](MlirOperation op, MlirExternalPass pass, void *userData) {
+    nb::handle(static_cast<PyObject *>(userData))(op, pass);
+  };
+  return callbacks;
+}
+
 } // namespace
 
 /// Create the `mlir.passmanager` here.
@@ -63,6 +81,33 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
       .def("signal_pass_failure",
            [](MlirExternalPass pass) { mlirExternalPassSignalFailure(pass); });
 
+  //----------------------------------------------------------------------------
+  // Mapping of register_pass
+  //----------------------------------------------------------------------------
+  m.def(
+      "register_pass",
+      [](const std::string &argument, const nb::callable &run,
+         std::optional<std::string> &name, const std::string &description,
+         const std::string &opName) {
+        if (!name.has_value()) {
+          name =
+              nb::cast<std::string>(nb::borrow<nb::str>(run.attr("__name__")));
+        }
+        MlirTypeIDAllocator typeIDAllocator = mlirTypeIDAllocatorCreate();
+        MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
+        auto callbacks = createExternalPassCallbacksForPythonCallable();
+        mlirRegisterExternalPass(
+            passID, mlirStringRefCreate(name->data(), name->length()),
+            mlirStringRefCreate(argument.data(), argument.length()),
+            mlirStringRefCreate(description.data(), description.length()),
+            mlirStringRefCreate(opName.data(), opName.size()),
+            /*nDependentDialects*/ 0, /*dependentDialects*/ nullptr, callbacks,
+            /*userData*/ run.ptr());
+      },
+      "argument"_a, "run"_a, "name"_a.none() = nb::none(),
+      "description"_a.none() = "", "op_name"_a.none() = "",
+      "Register a python-defined pass.");
+
   //----------------------------------------------------------------------------
   // Mapping of the top-level PassManager
   //----------------------------------------------------------------------------
@@ -178,21 +223,7 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
             MlirTypeIDAllocator typeIDAllocator = mlirTypeIDAllocatorCreate();
             MlirTypeID passID =
                 mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
-            MlirExternalPassCallbacks callbacks;
-            callbacks.construct = [](void *obj) {
-              (void)nb::handle(static_cast<PyObject *>(obj)).inc_ref();
-            };
-            callbacks.destruct = [](void *obj) {
-              (void)nb::handle(static_cast<PyObject *>(obj)).dec_ref();
-            };
-            callbacks.initialize = nullptr;
-            callbacks.clone = [](void *) -> void * {
-              throw std::runtime_error("Cloning Python passes not supported");
-            };
-            callbacks.run = [](MlirOperation op, MlirExternalPass pass,
-                               void *userData) {
-              nb::handle(static_cast<PyObject *>(userData))(op, pass);
-            };
+            auto callbacks = createExternalPassCallbacksForPythonCallable();
             auto externalPass = mlirCreateExternalPass(
                 passID, mlirStringRefCreate(name->data(), name->length()),
                 mlirStringRefCreate(argument.data(), argument.length()),
diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp
index b0a6ec1ace3cc..8924f6d9ec6a9 100644
--- a/mlir/lib/CAPI/IR/Pass.cpp
+++ b/mlir/lib/CAPI/IR/Pass.cpp
@@ -216,6 +216,32 @@ MlirPass mlirCreateExternalPass(MlirTypeID passID, MlirStringRef name,
       userData)));
 }
 
+void mlirRegisterExternalPass(MlirTypeID passID, MlirStringRef name,
+                              MlirStringRef argument, MlirStringRef description,
+                              MlirStringRef opName, intptr_t nDependentDialects,
+                              MlirDialectHandle *dependentDialects,
+                              MlirExternalPassCallbacks callbacks,
+                              void *userData) {
+  // here we clone these arguments as owned and pass them to
+  // the lambda as copies to avoid dangling refs,
+  // since the lambda below lives longer than the current function
+  std::string nameStr = unwrap(name).str();
+  std::string argumentStr = unwrap(argument).str();
+  std::string descriptionStr = unwrap(description).str();
+  std::string opNameStr = unwrap(opName).str();
+  std::vector<MlirDialectHandle> dependentDialectVec(
+      dependentDialects, dependentDialects + nDependentDialects);
+
+  mlir::registerPass([passID, nameStr, argumentStr, descriptionStr, opNameStr,
+                      dependentDialectVec, callbacks, userData] {
+    return std::unique_ptr<mlir::Pass>(new mlir::ExternalPass(
+        unwrap(passID), nameStr, argumentStr, descriptionStr,
+        opNameStr.length() > 0 ? std::optional<StringRef>(opNameStr)
+                               : std::nullopt,
+        dependentDialectVec, callbacks, userData));
+  });
+}
+
 void mlirExternalPassSignalFailure(MlirExternalPass pass) {
   unwrap(pass)->signalPassFailure();
 }
diff --git a/mlir/test/python/python_pass.py b/mlir/test/python/python_pass.py
index 50c42102f66d3..0fbd96ec71ddc 100644
--- a/mlir/test/python/python_pass.py
+++ b/mlir/test/python/python_pass.py
@@ -89,7 +89,7 @@ def __call__(self, op, pass_):
 
         # test signal_pass_failure
         def custom_pass_that_fails(op, pass_):
-            print("hello from pass that fails")
+            print("hello from pass that fails", file=sys.stderr)
             pass_.signal_pass_failure()
 
         pm = PassManager("any")
@@ -99,4 +99,44 @@ def custom_pass_that_fails(op, pass_):
         try:
             pm.run(module)
         except Exception as e:
-            print(f"caught exception: {e}")
+            print(f"caught exception: {e}", file=sys.stderr)
+
+
+# CHECK-LABEL: TEST: testRegisterPass
+@run
+def testRegisterPass():
+    with Context():
+        pdl_module = make_pdl_module()
+        frozen = PDLModule(pdl_module).freeze()
+
+        module = ModuleOp.parse(
+            r"""
+            module {
+              func.func @add(%a: i64, %b: i64) -> i64 {
+                %sum = arith.addi %a, %b : i64
+                return %sum : i64
+              }
+            }
+        """
+        )
+
+        def custom_pass_3(op, pass_):
+            print("hello from pass 3!!!", file=sys.stderr)
+
+        def custom_pass_4(op, pass_):
+            apply_patterns_and_fold_greedily(op, frozen)
+
+        register_pass("custom-pass-one", custom_pass_3)
+        register_pass("custom-pass-two", custom_pass_4)
+
+        pm = PassManager("any")
+        pm.enable_ir_printing()
+
+        # CHECK: hello from pass 3!!!
+        # CHECK-LABEL: Dump After custom_pass_3
+        # CHECK-LABEL: Dump After custom_pass_4
+        # CHECK: arith.muli
+        # CHECK-LABEL: Dump After ArithToLLVMConversionPass
+        # CHECK: llvm.mul
+        pm.add("custom-pass-one, custom-pass-two, convert-arith-to-llvm")
+        pm.run(module)

@rolfmorel
Copy link
Contributor

rolfmorel commented Sep 14, 2025

Thanks @PragmaTwice - this is good functionality to have!

Having said that, I feel the design of passes-are-functions-with-metadata is showing its limitations. I note the following "strikes against" (and contrast them with a class-oriented design):

  1. needing to pass in an object to call a method on to signal pass failure - [MLIR][Python] Add the ability to signal pass failures in python-defined passes #157613
    • signal_pass_failure could just be a method that an instance of a Pass calls on its base class
  2. needing to pass dependent dialects as metadata - [MLIR][Python] Support Python-defined passes in MLIR #156000 (comment)
    • this could just be a static property of the class that inherits Pass. Registration could use this static property.
  3. the __call__-implementing objects being shared among pass managers (unless you go through contortions) - this PR.
    • if the registration method took a class definition (or some other factory for objects), there would be a lack of surprise regarding objects/instances not being shared among pass managers. That is, it would be the same as on the C++ side.

I know you and @makslevental have already looked at alternatives, in particular, a class-based design and noted issues. Regarding @makslevental's comment:

you have to create the object in C++ in order for the unique_ptr magic to work. So that won't blend with "subclass and instantiate a nanobind class in Python".

Could it work if we pass the class object to the bindings and have the right callbacks instantiate the class from the C++ side? Presumably that would mean C++-code gets to manage the lifetimes of the instances. Maybe taking a weakref to the class definition would be enough (as the class definition should normally live until the end of the Python runtime).

@makslevental
Copy link
Contributor

makslevental commented Sep 14, 2025

Presumably that would mean C++-code gets to manage the lifetimes of the instances.

You will never be able to skirt the issue that Python leaks references like a sieve. Already "functions with metadata" are classes because of closures - there's no difference between "function that has tons of stuff in its closure" and "class with data" as far as lifetime management is concerned. There is no real way for C++ to control the lifetime of anything created in Python - not functions with data in their closures and definitely not data owned by classes which will have all sorts of their own references. The reason I strongly advocated (and will continue to advocate) functions is because at least there's some signal to the user they're doing questionable things wrt lifetime (ie they've used things in the closure).

the call-implementing objects being share among pass managers (unless you go through contortions) - this PR.

The first two strikes are style issues (OOP vs free-function) not strikes in my opinion. This one is real and why I think registration should just not be a thing yet. Like I said in the first PR - we don't have pipelines in Python except as strings so this is "putting the cart before the horse".

@PragmaTwice
Copy link
Member Author

PragmaTwice commented Sep 15, 2025

signal_pass_failure could just be a method that an instance of a Pass calls on its base class

Ahh if you look at the first edition of this PR, you will find that it can be a method bound to self (although not as a method from base class). It is then refactored since the design is not so straightforward (attaching additional attributes to a callable).

The "base class" form looks good at first, but you know that, the "base class" cannot be a real mlir::Pass-derived class or MlirPass C API (the Pass will be passed as std::unique_ptr<Pass> to the pass manager and delete passPtr will be called by the pass manager for destructing), and instead it can only be a wrapper class so finally it is not so attractive as I initially thought. (I think all lifetime related issues still exist in that form.)

this could just be a static property of the class that inherits Pass

Yup. I think the advantage of class is that we can have a "factory object" to generate pass objects (instead of deepcopying the callable), and also static properties/methods that may benefit users. But in the current design we can also have non-static properties in the callable objects and we can use these as the input of pass descriptions/argument/name.. .

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants