-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[MLIR][Python] Support Python-defined passes in MLIR #156000
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 12 commits
8386c87
cb82621
7556ca2
d5055a7
6d2f472
1a98ae8
751fe84
7966ddd
6a9ec66
2965a9e
1dde449
ca80408
c8c2fae
01e68c5
e565ffb
9f526c7
b6080ea
f72c83c
7a491af
46b833d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,7 +11,9 @@ | |
| #include "IRModule.h" | ||
| #include "mlir-c/Pass.h" | ||
| #include "mlir/Bindings/Python/Nanobind.h" | ||
|
|
||
| #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. | ||
| #include "nanobind/trampoline.h" | ||
|
||
|
|
||
| namespace nb = nanobind; | ||
| using namespace nb::literals; | ||
|
|
@@ -20,6 +22,81 @@ using namespace mlir::python; | |
|
|
||
| namespace { | ||
|
|
||
| // A base class for defining passes in Python | ||
| // Users are expected to subclass this and implement the `run` method, e.g. | ||
| // ``` | ||
| // class MyPass(mlir.passmanager.Pass): | ||
| // def __init__(self): | ||
| // super().__init__("MyPass", ..) | ||
| // # other init stuff.. | ||
| // def run(self, operation): | ||
| // # do something with operation.. | ||
| // pass | ||
| // ``` | ||
| class PyPassBase { | ||
| public: | ||
| PyPassBase(std::string name, std::string argument, std::string description, | ||
| std::string opName) | ||
| : name(std::move(name)), argument(std::move(argument)), | ||
| description(std::move(description)), opName(std::move(opName)) { | ||
| callbacks.construct = [](void *obj) {}; | ||
| callbacks.destruct = [](void *obj) { | ||
| nb::handle(static_cast<PyObject *>(obj)).dec_ref(); | ||
| }; | ||
| callbacks.run = [](MlirOperation op, MlirExternalPass, void *obj) { | ||
| auto handle = nb::handle(static_cast<PyObject *>(obj)); | ||
| nb::cast<PyPassBase *>(handle)->run(op); | ||
| }; | ||
| callbacks.clone = [](void *obj) -> void * { | ||
| nb::object copy = nb::module_::import_("copy"); | ||
| nb::object deepcopy = copy.attr("deepcopy"); | ||
| return deepcopy(obj).release().ptr(); | ||
| }; | ||
| callbacks.initialize = nullptr; | ||
|
||
| } | ||
|
|
||
| // this method should be overridden by subclasses in Python. | ||
| virtual void run(MlirOperation op) = 0; | ||
|
|
||
| virtual ~PyPassBase() = default; | ||
|
|
||
| // Make an MlirPass instance on-the-fly that wraps this object. | ||
| // Note that passmanager will take the ownership of the returned | ||
| // object and release it when appropriate. | ||
| MlirPass make() { | ||
|
||
| auto *obj = nb::find(this).release().ptr(); | ||
|
||
| return mlirCreateExternalPass( | ||
| mlirTypeIDCreate(this), mlirStringRefCreate(name.data(), name.length()), | ||
| mlirStringRefCreate(argument.data(), argument.length()), | ||
| mlirStringRefCreate(description.data(), description.length()), | ||
| mlirStringRefCreate(opName.data(), opName.size()), 0, nullptr, | ||
|
||
| callbacks, obj); | ||
| } | ||
|
|
||
| const std::string &getName() const { return name; } | ||
| const std::string &getArgument() const { return argument; } | ||
| const std::string &getDescription() const { return description; } | ||
| const std::string &getOpName() const { return opName; } | ||
|
|
||
| private: | ||
| MlirExternalPassCallbacks callbacks; | ||
|
|
||
| std::string name; | ||
| std::string argument; | ||
| std::string description; | ||
| std::string opName; | ||
| }; | ||
|
|
||
| // A trampoline class upon PyPassBase. | ||
| // Refer to | ||
| // https://nanobind.readthedocs.io/en/latest/classes.html#overriding-virtual-functions-in-python | ||
| class PyPass : PyPassBase { | ||
| public: | ||
| NB_TRAMPOLINE(PyPassBase, 1); | ||
|
|
||
| void run(MlirOperation op) override { NB_OVERRIDE_PURE(run, op); } | ||
| }; | ||
|
|
||
| /// Owning Wrapper around a PassManager. | ||
| class PyPassManager { | ||
| public: | ||
|
|
@@ -52,6 +129,26 @@ class PyPassManager { | |
|
|
||
| } // namespace | ||
|
|
||
| void mlir::python::populatePassSubmodule(nanobind::module_ &m) { | ||
| //---------------------------------------------------------------------------- | ||
| // Mapping of the Python-defined Pass interface | ||
| //---------------------------------------------------------------------------- | ||
| nb::class_<PyPassBase, PyPass>(m, "Pass") | ||
|
||
| .def(nb::init<std::string, std::string, std::string, std::string>(), | ||
| "name"_a, nb::kw_only(), "argument"_a = "", "description"_a = "", | ||
| "op_name"_a = "", "Create a new Pass.") | ||
| .def("run", &PyPassBase::run, "operation"_a, | ||
| "Run the pass on the provided operation.") | ||
| .def_prop_ro("name", | ||
| [](const PyPassBase &self) { return self.getName(); }) | ||
| .def_prop_ro("argument", | ||
| [](const PyPassBase &self) { return self.getArgument(); }) | ||
| .def_prop_ro("description", | ||
| [](const PyPassBase &self) { return self.getDescription(); }) | ||
| .def_prop_ro("op_name", | ||
| [](const PyPassBase &self) { return self.getOpName(); }); | ||
| } | ||
|
|
||
| /// Create the `mlir.passmanager` here. | ||
| void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { | ||
| //---------------------------------------------------------------------------- | ||
|
|
@@ -157,6 +254,12 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { | |
| "pipeline"_a, | ||
| "Add textual pipeline elements to the pass manager. Throws a " | ||
| "ValueError if the pipeline can't be parsed.") | ||
| .def( | ||
| "add", | ||
| [](PyPassManager &passManager, PyPassBase &pass) { | ||
| mlirPassManagerAddOwnedPass(passManager.get(), pass.make()); | ||
| }, | ||
| "pass"_a, "Add a python-defined pass to the pass manager.") | ||
| .def( | ||
| "run", | ||
| [](PyPassManager &passManager, PyOperationBase &op, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,83 @@ | ||
| # RUN: %PYTHON %s 2>&1 | FileCheck %s | ||
|
|
||
| import gc, sys | ||
| from mlir.ir import * | ||
| from mlir.passmanager import * | ||
| from mlir.dialects.builtin import ModuleOp | ||
| from mlir.dialects import pdl | ||
| from mlir.rewrite import * | ||
|
|
||
|
|
||
| def log(*args): | ||
| print(*args, file=sys.stderr) | ||
| sys.stderr.flush() | ||
|
|
||
|
|
||
| def run(f): | ||
| log("\nTEST:", f.__name__) | ||
| f() | ||
| gc.collect() | ||
| assert Context._get_live_count() == 0 | ||
|
|
||
|
|
||
| def make_pdl_module(): | ||
| with Location.unknown(): | ||
| pdl_module = Module.create() | ||
| with InsertionPoint(pdl_module.body): | ||
| # Change all arith.addi with index types to arith.muli. | ||
| @pdl.pattern(benefit=1, sym_name="addi_to_mul") | ||
| def pat(): | ||
| # Match arith.addi with index types. | ||
| i64_type = pdl.TypeOp(IntegerType.get_signless(64)) | ||
| operand0 = pdl.OperandOp(i64_type) | ||
| operand1 = pdl.OperandOp(i64_type) | ||
| op0 = pdl.OperationOp( | ||
| name="arith.addi", args=[operand0, operand1], types=[i64_type] | ||
| ) | ||
|
|
||
| # Replace the matched op with arith.muli. | ||
| @pdl.rewrite() | ||
| def rew(): | ||
| newOp = pdl.OperationOp( | ||
| name="arith.muli", args=[operand0, operand1], types=[i64_type] | ||
| ) | ||
| pdl.ReplaceOp(op0, with_op=newOp) | ||
|
|
||
| return pdl_module | ||
|
|
||
|
|
||
| # CHECK-LABEL: TEST: testCustomPass | ||
| @run | ||
| def testCustomPass(): | ||
| with Context(): | ||
| pdl_module = make_pdl_module() | ||
|
|
||
| class CustomPass(Pass): | ||
| def __init__(self): | ||
| super().__init__("CustomPass", op_name="builtin.module") | ||
|
|
||
| def run(self, m): | ||
| frozen = PDLModule(pdl_module).freeze() | ||
| apply_patterns_and_fold_greedily_with_op(m, frozen) | ||
|
|
||
| module = ModuleOp.parse( | ||
| r""" | ||
| module { | ||
| func.func @add(%a: i64, %b: i64) -> i64 { | ||
| %sum = arith.addi %a, %b : i64 | ||
| return %sum : i64 | ||
| } | ||
| } | ||
| """ | ||
| ) | ||
|
|
||
| pm = PassManager("any") | ||
| pm.enable_ir_printing() | ||
|
|
||
| # CHECK-LABEL: Dump After CustomPass | ||
| # CHECK: arith.muli | ||
| pm.add(CustomPass()) | ||
| # CHECK-LABEL: Dump After ArithToLLVMConversionPass | ||
| # CHECK: llvm.mul | ||
| pm.add("convert-arith-to-llvm") | ||
| pm.run(module) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems that
passis a keyword in python so that name likemlir.pass.Passdoesn't work. 🤔There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah yes, pass vs pass :-) I don't have a good suggestion given how established that is MLIR/compilers and Python side ... MlirPass or PyPass or pydefpass would be most obvious.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Currently the full name of the base class is
mlir.passmanager.Pass. Is it good enough or we'd better to rename it with another module name?Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Given that PEP8 recommends appending a underscore in case of names of arguments / attributes clashing with reserved keywords, I feel
mlir.pass_.Passis an option as it is not too surprising for me personally.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think putting a trailing underscore in a namespace is a little awkward. How about
mlir.passes? ormlir.passinfra? or I dunno something like that that I'm failing to come up with right now...There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Got it. For now, I will rename it to
mlir.passes.Pass. Let me know if anyone think of a better name : )There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done in 01e68c5.