[MLIR][Python] Add a function to register python-defined passes #157850
base: main
While simple test cases work well with the current approach, there are two problems.

First, the Python callable is currently "shared" between all pipelines (that is, among all instances of the pass registered with that callable). For example:

    def my_pass(op, pass_):
        ...

    register_pass("my-pass", my_pass)
    pm1.add("my-pass, ..., my-pass")
    ...
    pm2.add("my-pass")

In the example above, the object `my_pass` is shared by every `my-pass` instance in both `pm1` and `pm2`. This matters once the callable carries state:

    class AdvancedPass:
        def __init__(self):
            self.count = 0

        def __call__(self, op, pass_):
            self.count += 1

In this example, `count` is incremented by every instance of the pass in every pipeline, since they all share the same `AdvancedPass` object.

Second, the lifetime of this Python object has to be quite long (i.e. the whole lifetime of the Python program), since such a pass can be created anywhere after it has been registered. (And currently its lifetime is not extended via something like …)
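To make the first point concrete, here is a pure-Python sketch that does not touch the MLIR bindings. It assumes, as described above, that the registry keeps exactly one callable per pass argument and hands that same object to every pass instance it creates. `ToyPassRegistry` and `ToyPipeline` are hypothetical stand-ins, not part of `mlir.passmanager`.

```python
# Hypothetical model of the current behavior: one registered callable per
# pass argument, shared by every pass instance in every pipeline.
# ToyPassRegistry / ToyPipeline are illustrative stand-ins, not MLIR APIs.


class ToyPassRegistry:
    def __init__(self):
        self._callables = {}

    def register_pass(self, argument, run):
        # Store the callable itself; no copy is made.
        self._callables[argument] = run

    def lookup(self, argument):
        return self._callables[argument]


class ToyPipeline:
    def __init__(self, registry, pipeline):
        # Each comma-separated entry becomes one "pass instance", but every
        # instance refers to the same registered object.
        self.passes = [registry.lookup(a.strip()) for a in pipeline.split(",")]

    def run(self, op):
        for p in self.passes:
            p(op, None)  # `None` stands in for the pass handle.


class AdvancedPass:
    def __init__(self):
        self.count = 0

    def __call__(self, op, pass_):
        self.count += 1


registry = ToyPassRegistry()
shared = AdvancedPass()
registry.register_pass("my-pass", shared)

pm1 = ToyPipeline(registry, "my-pass, my-pass")
pm2 = ToyPipeline(registry, "my-pass")
pm1.run("module")
pm2.run("module")

# Three pass instances across two pipelines all incremented one counter.
print(shared.count)  # 3
```

The counter ends up at 3 rather than each instance seeing its own count of 1, which is exactly the sharing described in the first point.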
@llvm/pr-subscribers-mlir

Author: Twice (PragmaTwice)

Changes

This PR brings a new function `register_pass` to the MLIR Python bindings, which can register a Python-defined pass into MLIR (corresponding to `mlir::registerPass` in C++).

An example:

    def demo_pass(op, pass_):
        pass  # do something

    register_pass("my-python-demo-pass", demo_pass)
    pm = PassManager('any')
    pm.add("my-python-demo-pass, some-cpp-defined-pass ...")
    pm.run(...)

Full diff: https://github.com/llvm/llvm-project/pull/157850.diff

4 Files Affected:
diff --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h
index 0d2e19ee7fb0a..1328b4de5d4eb 100644
--- a/mlir/include/mlir-c/Pass.h
+++ b/mlir/include/mlir-c/Pass.h
@@ -184,6 +184,13 @@ MLIR_CAPI_EXPORTED MlirPass mlirCreateExternalPass(
intptr_t nDependentDialects, MlirDialectHandle *dependentDialects,
MlirExternalPassCallbacks callbacks, void *userData);
+MLIR_CAPI_EXPORTED void
+mlirRegisterExternalPass(MlirTypeID passID, MlirStringRef name,
+ MlirStringRef argument, MlirStringRef description,
+ MlirStringRef opName, intptr_t nDependentDialects,
+ MlirDialectHandle *dependentDialects,
+ MlirExternalPassCallbacks callbacks, void *userData);
+
/// This signals that the pass has failed. This is only valid to call during
/// the `run` callback of `MlirExternalPassCallbacks`.
/// See Pass::signalPassFailure().
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index 47ef5d8e9dd3b..558ab6a43d87b 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -52,6 +52,24 @@ class PyPassManager {
MlirPassManager passManager;
};
+MlirExternalPassCallbacks createExternalPassCallbacksForPythonCallable() {
+ MlirExternalPassCallbacks callbacks;
+ callbacks.construct = [](void *obj) {
+ (void)nb::handle(static_cast<PyObject *>(obj)).inc_ref();
+ };
+ callbacks.destruct = [](void *obj) {
+ (void)nb::handle(static_cast<PyObject *>(obj)).dec_ref();
+ };
+ callbacks.initialize = nullptr;
+ callbacks.clone = [](void *) -> void * {
+ throw std::runtime_error("Cloning Python passes not supported");
+ };
+ callbacks.run = [](MlirOperation op, MlirExternalPass pass, void *userData) {
+ nb::handle(static_cast<PyObject *>(userData))(op, pass);
+ };
+ return callbacks;
+}
+
} // namespace
/// Create the `mlir.passmanager` here.
@@ -63,6 +81,33 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
.def("signal_pass_failure",
[](MlirExternalPass pass) { mlirExternalPassSignalFailure(pass); });
+ //----------------------------------------------------------------------------
+ // Mapping of register_pass
+ //----------------------------------------------------------------------------
+ m.def(
+ "register_pass",
+ [](const std::string &argument, const nb::callable &run,
+ std::optional<std::string> &name, const std::string &description,
+ const std::string &opName) {
+ if (!name.has_value()) {
+ name =
+ nb::cast<std::string>(nb::borrow<nb::str>(run.attr("__name__")));
+ }
+ MlirTypeIDAllocator typeIDAllocator = mlirTypeIDAllocatorCreate();
+ MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
+ auto callbacks = createExternalPassCallbacksForPythonCallable();
+ mlirRegisterExternalPass(
+ passID, mlirStringRefCreate(name->data(), name->length()),
+ mlirStringRefCreate(argument.data(), argument.length()),
+ mlirStringRefCreate(description.data(), description.length()),
+ mlirStringRefCreate(opName.data(), opName.size()),
+ /*nDependentDialects*/ 0, /*dependentDialects*/ nullptr, callbacks,
+ /*userData*/ run.ptr());
+ },
+ "argument"_a, "run"_a, "name"_a.none() = nb::none(),
+ "description"_a.none() = "", "op_name"_a.none() = "",
+ "Register a python-defined pass.");
+
//----------------------------------------------------------------------------
// Mapping of the top-level PassManager
//----------------------------------------------------------------------------
@@ -178,21 +223,7 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
MlirTypeIDAllocator typeIDAllocator = mlirTypeIDAllocatorCreate();
MlirTypeID passID =
mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
- MlirExternalPassCallbacks callbacks;
- callbacks.construct = [](void *obj) {
- (void)nb::handle(static_cast<PyObject *>(obj)).inc_ref();
- };
- callbacks.destruct = [](void *obj) {
- (void)nb::handle(static_cast<PyObject *>(obj)).dec_ref();
- };
- callbacks.initialize = nullptr;
- callbacks.clone = [](void *) -> void * {
- throw std::runtime_error("Cloning Python passes not supported");
- };
- callbacks.run = [](MlirOperation op, MlirExternalPass pass,
- void *userData) {
- nb::handle(static_cast<PyObject *>(userData))(op, pass);
- };
+ auto callbacks = createExternalPassCallbacksForPythonCallable();
auto externalPass = mlirCreateExternalPass(
passID, mlirStringRefCreate(name->data(), name->length()),
mlirStringRefCreate(argument.data(), argument.length()),
diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp
index b0a6ec1ace3cc..8924f6d9ec6a9 100644
--- a/mlir/lib/CAPI/IR/Pass.cpp
+++ b/mlir/lib/CAPI/IR/Pass.cpp
@@ -216,6 +216,32 @@ MlirPass mlirCreateExternalPass(MlirTypeID passID, MlirStringRef name,
userData)));
}
+void mlirRegisterExternalPass(MlirTypeID passID, MlirStringRef name,
+ MlirStringRef argument, MlirStringRef description,
+ MlirStringRef opName, intptr_t nDependentDialects,
+ MlirDialectHandle *dependentDialects,
+ MlirExternalPassCallbacks callbacks,
+ void *userData) {
+ // here we clone these arguments as owned and pass them to
+ // the lambda as copies to avoid dangling refs,
+ // since the lambda below lives longer than the current function
+ std::string nameStr = unwrap(name).str();
+ std::string argumentStr = unwrap(argument).str();
+ std::string descriptionStr = unwrap(description).str();
+ std::string opNameStr = unwrap(opName).str();
+ std::vector<MlirDialectHandle> dependentDialectVec(
+ dependentDialects, dependentDialects + nDependentDialects);
+
+ mlir::registerPass([passID, nameStr, argumentStr, descriptionStr, opNameStr,
+ dependentDialectVec, callbacks, userData] {
+ return std::unique_ptr<mlir::Pass>(new mlir::ExternalPass(
+ unwrap(passID), nameStr, argumentStr, descriptionStr,
+ opNameStr.length() > 0 ? std::optional<StringRef>(opNameStr)
+ : std::nullopt,
+ dependentDialectVec, callbacks, userData));
+ });
+}
+
void mlirExternalPassSignalFailure(MlirExternalPass pass) {
unwrap(pass)->signalPassFailure();
}
diff --git a/mlir/test/python/python_pass.py b/mlir/test/python/python_pass.py
index 50c42102f66d3..0fbd96ec71ddc 100644
--- a/mlir/test/python/python_pass.py
+++ b/mlir/test/python/python_pass.py
@@ -89,7 +89,7 @@ def __call__(self, op, pass_):
# test signal_pass_failure
def custom_pass_that_fails(op, pass_):
- print("hello from pass that fails")
+ print("hello from pass that fails", file=sys.stderr)
pass_.signal_pass_failure()
pm = PassManager("any")
@@ -99,4 +99,44 @@ def custom_pass_that_fails(op, pass_):
try:
pm.run(module)
except Exception as e:
- print(f"caught exception: {e}")
+ print(f"caught exception: {e}", file=sys.stderr)
+
+
+# CHECK-LABEL: TEST: testRegisterPass
+@run
+def testRegisterPass():
+ with Context():
+ pdl_module = make_pdl_module()
+ frozen = PDLModule(pdl_module).freeze()
+
+ module = ModuleOp.parse(
+ r"""
+ module {
+ func.func @add(%a: i64, %b: i64) -> i64 {
+ %sum = arith.addi %a, %b : i64
+ return %sum : i64
+ }
+ }
+ """
+ )
+
+ def custom_pass_3(op, pass_):
+ print("hello from pass 3!!!", file=sys.stderr)
+
+ def custom_pass_4(op, pass_):
+ apply_patterns_and_fold_greedily(op, frozen)
+
+ register_pass("custom-pass-one", custom_pass_3)
+ register_pass("custom-pass-two", custom_pass_4)
+
+ pm = PassManager("any")
+ pm.enable_ir_printing()
+
+ # CHECK: hello from pass 3!!!
+ # CHECK-LABEL: Dump After custom_pass_3
+ # CHECK-LABEL: Dump After custom_pass_4
+ # CHECK: arith.muli
+ # CHECK-LABEL: Dump After ArithToLLVMConversionPass
+ # CHECK: llvm.mul
+ pm.add("custom-pass-one, custom-pass-two, convert-arith-to-llvm")
+ pm.run(module)
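The binding in the diff defaults `name` to the callable's `__name__`, which is why the test's IR-printing output is labeled `Dump After custom_pass_3` even though the pass is registered under the argument `custom-pass-one`. A minimal pure-Python model of just that metadata handling might look like this. `PassInfo` and `describe_registration` are hypothetical names, and nothing is actually registered with MLIR.

```python
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass(frozen=True)
class PassInfo:
    argument: str     # textual name used in pipeline strings
    name: str         # display name, e.g. in "Dump After <name>"
    description: str
    op_name: str      # "" becomes std::nullopt in the C API (no anchor op)


def describe_registration(
    argument: str,
    run: Callable,
    name: Optional[str] = None,
    description: str = "",
    op_name: str = "",
) -> PassInfo:
    # Mirrors the binding's defaulting rule: fall back to run.__name__.
    if name is None:
        name = run.__name__
    return PassInfo(argument, name, description, op_name)


def custom_pass_3(op, pass_):
    pass


info = describe_registration("custom-pass-one", custom_pass_3)
print(info.name)  # custom_pass_3
```

One consequence visible in this model: a callable instance such as `AdvancedPass()` has no `__name__` attribute, so it needs an explicit `name=`. The binding presumably behaves the same way, since it reads the same attribute.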
Thanks @PragmaTwice, this is good functionality to have! Having said that, I feel the design of passes-are-functions-with-metadata is showing its limitations. I note the following "strikes against" it (and contrast them with a class-oriented design):
I know you and @makslevental have already looked at alternatives, in particular a class-based design, and noted issues. Regarding @makslevental's comment:
Could it work if we pass the class object to the bindings and have the right callbacks instantiate the class from the C++ side? Presumably that would mean C++ code gets to manage the lifetimes of the instances. Maybe taking a weakref to the class definition would be enough (as the class definition should normally live until the end of the Python runtime).
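The suggestion above can be sketched in plain Python: the registry keeps only a `weakref.ref` to the pass class and constructs a fresh instance each time a pipeline needs the pass, so per-instance state is no longer shared. `ClassPassRegistry`, `CountingPass`, and `registration_outlives_class` are hypothetical. In the real bindings the construction would have to happen from the C++ `construct` callback, which this sketch does not model.

```python
import gc
import weakref


class ClassPassRegistry:
    def __init__(self):
        self._classes = {}

    def register_pass(self, argument, cls):
        # Only a weak reference: the registry does not keep the class alive.
        self._classes[argument] = weakref.ref(cls)

    def create(self, argument):
        cls = self._classes[argument]()
        if cls is None:
            raise RuntimeError(f"pass class for '{argument}' no longer exists")
        return cls()  # a fresh instance per pipeline entry


class CountingPass:
    def __init__(self):
        self.count = 0

    def __call__(self, op, pass_):
        self.count += 1


registry = ClassPassRegistry()
registry.register_pass("counting-pass", CountingPass)

a = registry.create("counting-pass")
b = registry.create("counting-pass")
a("module", None)
a("module", None)
b("module", None)
print(a.count, b.count)  # 2 1


def registration_outlives_class():
    """Return True if create() still works after the class is dropped."""
    reg = ClassPassRegistry()

    class Temporary:
        def __call__(self, op, pass_):
            pass

    reg.register_pass("temp", Temporary)
    del Temporary
    gc.collect()  # classes live in reference cycles; refcounting alone won't free them
    try:
        reg.create("temp")
        return True
    except RuntimeError:
        return False


print(registration_outlives_class())  # False
```

The second half shows the trade-off of the weakref: it works as long as the class is defined at module level (and so lives until interpreter shutdown), but a locally defined class can disappear while still registered.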
You will never be able to skirt the issue that Python leaks references like a sieve. "Functions with metadata" are already classes because of closures: there's no difference between "function that has tons of stuff in its closure" and "class with data" as far as lifetime management is concerned. There is no real way for C++ to control the lifetime of anything created in Python, not functions with data in their closures and definitely not data owned by classes, which will have all sorts of their own references. The reason I strongly advocated (and will continue to advocate) functions is that at least there's some signal to the user that they're doing questionable things with respect to lifetime (i.e., they've used things in the closure).
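The claim that a function with a non-trivial closure is, lifetime-wise, no different from a class holding data can be checked directly with `weakref`. The sketch below is illustrative only (`Payload`, `make_closure_pass`, `ClassPass`, and `payload_alive_while_holder_alive` are made-up names) and relies on CPython's reference counting freeing the payload as soon as the last strong reference disappears; other interpreters may free it later.

```python
import weakref


class Payload:
    """Stand-in for some large object a pass might need (e.g. a pattern set)."""


def make_closure_pass(payload):
    def run(op, pass_):
        return payload  # the closure cell keeps `payload` alive
    return run


class ClassPass:
    def __init__(self, payload):
        self.payload = payload  # the instance attribute keeps `payload` alive

    def __call__(self, op, pass_):
        return self.payload


def payload_alive_while_holder_alive(make_holder):
    payload = Payload()
    ref = weakref.ref(payload)
    holder = make_holder(payload)
    del payload
    alive_with_holder = ref() is not None
    del holder
    alive_without_holder = ref() is not None
    return alive_with_holder, alive_without_holder


print(payload_alive_while_holder_alive(make_closure_pass))  # (True, False)
print(payload_alive_while_holder_alive(ClassPass))          # (True, False)
```

Both holders behave identically: the payload lives exactly as long as the function or the instance does, so neither style by itself gives C++ more control over lifetimes.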
The first two strikes are style issues (OOP vs. free functions), not strikes, in my opinion. This one is real, and it's why I think registration should just not be a thing yet. As I said in the first PR, we don't have pipelines in Python except as strings, so this is "putting the cart before the horse".
Ahh, if you look at the first edition of this PR, you will find that it can be a method bound to … The "base class" form looks good at first, but, as you know, the "base class" cannot be a real …
Yup. I think the advantage of a class is that we can have a "factory object" to generate pass objects (instead of deep-copying the callable), as well as static properties/methods that may benefit users. But in the current design we can also have non-static properties on the callable objects, and we can use these as inputs for the pass description/argument/name.
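The two ways of getting a fresh pass object per pipeline entry mentioned above can be compared in a small sketch: registering a zero-argument factory (here, the class itself) versus registering a prototype callable and `copy.deepcopy`-ing it. Neither is what the PR implements. `ConfiguredPass`, `make_instances_from_factory`, and `make_instances_from_prototype` are hypothetical, and `description` illustrates the point that non-static properties can feed the pass metadata.

```python
import copy


class ConfiguredPass:
    def __init__(self, threshold=4):
        self.threshold = threshold  # non-static, per-object configuration
        self.count = 0

    @property
    def description(self):
        # Metadata derived from instance state, not from a static attribute.
        return f"counts ops (threshold={self.threshold})"

    def __call__(self, op, pass_):
        self.count += 1


def make_instances_from_factory(factory, n):
    # Factory style: construct a new object for every pass instance.
    return [factory() for _ in range(n)]


def make_instances_from_prototype(prototype, n):
    # Copy style: clone one configured object for every pass instance.
    return [copy.deepcopy(prototype) for _ in range(n)]


from_factory = make_instances_from_factory(ConfiguredPass, 2)
from_copies = make_instances_from_prototype(ConfiguredPass(threshold=8), 2)

from_factory[0]("module", None)
from_copies[1]("module", None)

print([p.count for p in from_factory])  # [1, 0]
print([p.count for p in from_copies])   # [0, 1]
print(from_copies[0].description)       # counts ops (threshold=8)
```

The factory form never copies user state, which avoids relying on `copy.deepcopy` for captured objects that may not support it; the copy form lets a pre-configured object (like `threshold=8` above) be reused without writing a separate factory.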