-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[MLIR][Python] Support Python-defined passes in MLIR #156000
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 8 commits
8386c87
cb82621
7556ca2
d5055a7
6d2f472
1a98ae8
751fe84
7966ddd
6a9ec66
2965a9e
1dde449
ca80408
c8c2fae
01e68c5
e565ffb
9f526c7
b6080ea
f72c83c
7a491af
46b833d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -139,4 +139,5 @@ NB_MODULE(_mlir, m) { | |
| auto passModule = | ||
| m.def_submodule("passmanager", "MLIR Pass Management Bindings"); | ||
| populatePassManagerSubmodule(passModule); | ||
| populatePassSubmodule(passModule); | ||
|
||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,7 +11,9 @@ | |
| #include "IRModule.h" | ||
| #include "mlir-c/Pass.h" | ||
| #include "mlir/Bindings/Python/Nanobind.h" | ||
|
|
||
| #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. | ||
| #include "nanobind/trampoline.h" | ||
|
||
|
|
||
| namespace nb = nanobind; | ||
| using namespace nb::literals; | ||
|
|
@@ -20,6 +22,79 @@ using namespace mlir::python; | |
|
|
||
| namespace { | ||
|
|
||
| // A base class for defining passes in Python | ||
| // Users are expected to subclass this and implement the `run` method, e.g. | ||
| // ``` | ||
| // class MyPass(mlir.passmanager.Pass): | ||
| // def run(self, operation): | ||
| // # do something with operation | ||
| // pass | ||
| // ``` | ||
| class PyPassBase { | ||
| public: | ||
| PyPassBase(std::string name, std::string argument, std::string description, | ||
| std::string opName) | ||
| : name(std::move(name)), argument(std::move(argument)), | ||
| description(std::move(description)), opName(std::move(opName)) { | ||
| callbacks.construct = [](void *obj) {}; | ||
| callbacks.destruct = [](void *obj) { | ||
| nb::handle(static_cast<PyObject *>(obj)).dec_ref(); | ||
| }; | ||
| callbacks.run = [](MlirOperation op, MlirExternalPass, void *obj) { | ||
| auto handle = nb::handle(static_cast<PyObject *>(obj)); | ||
| nb::cast<PyPassBase *>(handle)->run(op); | ||
| }; | ||
| callbacks.clone = [](void *obj) -> void * { | ||
| nb::object copy = nb::module_::import_("copy"); | ||
| nb::object deepcopy = copy.attr("deepcopy"); | ||
| return deepcopy(obj).release().ptr(); | ||
| }; | ||
| callbacks.initialize = nullptr; | ||
|
||
| } | ||
|
|
||
| // this method should be overridden by subclasses in Python. | ||
| virtual void run(MlirOperation op) = 0; | ||
|
|
||
| virtual ~PyPassBase() = default; | ||
|
|
||
| // Make an MlirPass instance on-the-fly that wraps this object. | ||
| // Note that passmanager will take the ownership of the returned | ||
| // object and release it when appropriate. | ||
| // Also, `*this` must remain alive as long as the returned object is alive. | ||
| MlirPass make() { | ||
| auto *obj = nb::find(this).release().ptr(); | ||
|
||
| return mlirCreateExternalPass( | ||
| mlirTypeIDCreate(this), mlirStringRefCreate(name.data(), name.length()), | ||
| mlirStringRefCreate(argument.data(), argument.length()), | ||
| mlirStringRefCreate(description.data(), description.length()), | ||
| mlirStringRefCreate(opName.data(), opName.size()), 0, nullptr, | ||
|
||
| callbacks, obj); | ||
| } | ||
|
|
||
| const std::string &getName() const { return name; } | ||
| const std::string &getArgument() const { return argument; } | ||
| const std::string &getDescription() const { return description; } | ||
| const std::string &getOpName() const { return opName; } | ||
|
|
||
| private: | ||
| MlirExternalPassCallbacks callbacks; | ||
|
|
||
| std::string name; | ||
| std::string argument; | ||
| std::string description; | ||
| std::string opName; | ||
| }; | ||
|
|
||
| // A trampoline class upon PyPassBase. | ||
| // Refer to | ||
| // https://nanobind.readthedocs.io/en/latest/classes.html#overriding-virtual-functions-in-python | ||
| class PyPass : PyPassBase { | ||
| public: | ||
| NB_TRAMPOLINE(PyPassBase, 1); | ||
|
|
||
| void run(MlirOperation op) override { NB_OVERRIDE_PURE(run, op); } | ||
| }; | ||
|
|
||
| /// Owning Wrapper around a PassManager. | ||
| class PyPassManager { | ||
| public: | ||
|
|
@@ -52,6 +127,26 @@ class PyPassManager { | |
|
|
||
| } // namespace | ||
|
|
||
| void mlir::python::populatePassSubmodule(nanobind::module_ &m) { | ||
| //---------------------------------------------------------------------------- | ||
| // Mapping of the Python-defined Pass interface | ||
| //---------------------------------------------------------------------------- | ||
| nb::class_<PyPassBase, PyPass>(m, "Pass") | ||
|
||
| .def(nb::init<std::string, std::string, std::string, std::string>(), | ||
| "name"_a, nb::kw_only(), "argument"_a = "", "description"_a = "", | ||
| "op_name"_a = "", "Create a new Pass.") | ||
| .def("run", &PyPassBase::run, "operation"_a, | ||
| "Run the pass on the provided operation.") | ||
| .def_prop_ro("name", | ||
| [](const PyPassBase &self) { return self.getName(); }) | ||
| .def_prop_ro("argument", | ||
| [](const PyPassBase &self) { return self.getArgument(); }) | ||
| .def_prop_ro("description", | ||
| [](const PyPassBase &self) { return self.getDescription(); }) | ||
| .def_prop_ro("op_name", | ||
| [](const PyPassBase &self) { return self.getOpName(); }); | ||
| } | ||
|
|
||
| /// Create the `mlir.passmanager` here. | ||
| void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { | ||
| //---------------------------------------------------------------------------- | ||
|
|
@@ -157,6 +252,12 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { | |
| "pipeline"_a, | ||
| "Add textual pipeline elements to the pass manager. Throws a " | ||
| "ValueError if the pipeline can't be parsed.") | ||
| .def( | ||
| "add", | ||
| [](PyPassManager &passManager, PyPassBase &pass) { | ||
| mlirPassManagerAddOwnedPass(passManager.get(), pass.make()); | ||
| }, | ||
| "pass"_a, "Add a python-defined pass to the pass manager.") | ||
| .def( | ||
| "run", | ||
| [](PyPassManager &passManager, PyOperationBase &op, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,79 @@ | ||
| # RUN: %PYTHON %s 2>&1 | FileCheck %s | ||
|
|
||
| import gc, sys | ||
| from mlir.ir import * | ||
| from mlir.passmanager import * | ||
| from mlir.dialects.builtin import ModuleOp | ||
| from mlir.dialects import pdl | ||
| from mlir.rewrite import * | ||
|
|
||
|
|
||
| def log(*args): | ||
| print(*args, file=sys.stderr) | ||
| sys.stderr.flush() | ||
|
|
||
|
|
||
| def run(f): | ||
| log("\nTEST:", f.__name__) | ||
| f() | ||
| gc.collect() | ||
| assert Context._get_live_count() == 0 | ||
|
|
||
|
|
||
| def make_pdl_module(): | ||
| with Location.unknown(): | ||
| pdl_module = Module.create() | ||
| with InsertionPoint(pdl_module.body): | ||
| # Change all arith.addi with index types to arith.muli. | ||
| @pdl.pattern(benefit=1, sym_name="addi_to_mul") | ||
| def pat(): | ||
| # Match arith.addi with index types. | ||
| index_type = pdl.TypeOp(IndexType.get()) | ||
| operand0 = pdl.OperandOp(index_type) | ||
| operand1 = pdl.OperandOp(index_type) | ||
| op0 = pdl.OperationOp( | ||
| name="arith.addi", args=[operand0, operand1], types=[index_type] | ||
| ) | ||
|
|
||
| # Replace the matched op with arith.muli. | ||
| @pdl.rewrite() | ||
| def rew(): | ||
| newOp = pdl.OperationOp( | ||
| name="arith.muli", args=[operand0, operand1], types=[index_type] | ||
| ) | ||
| pdl.ReplaceOp(op0, with_op=newOp) | ||
|
|
||
| return pdl_module | ||
|
|
||
|
|
||
| # CHECK-LABEL: TEST: testCustomPass | ||
| @run | ||
| def testCustomPass(): | ||
| with Context(): | ||
| pdl_module = make_pdl_module() | ||
|
|
||
| class CustomPass(Pass): | ||
| def __init__(self): | ||
| super().__init__("CustomPass", op_name="builtin.module") | ||
|
|
||
| def run(self, m): | ||
| frozen = PDLModule(pdl_module).freeze() | ||
| apply_patterns_and_fold_greedily_for_op(m, frozen) | ||
|
|
||
| module = ModuleOp.parse( | ||
| r""" | ||
| module { | ||
| func.func @add(%a: index, %b: index) -> index { | ||
| %sum = arith.addi %a, %b : index | ||
| return %sum : index | ||
| } | ||
| } | ||
| """ | ||
| ) | ||
|
|
||
| # CHECK-LABEL: Dump After CustomPass | ||
| # CHECK: arith.muli | ||
| pm = PassManager("any") | ||
| pm.enable_ir_printing() | ||
| pm.add(CustomPass()) | ||
| pm.run(module) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure which name is suitable.. So the current name is just a placeholder. And modifying
mlirApplyPatternsAndFoldGreedilyseems a breaking change.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is a breaking change, but not sure how widely and also C API is really best effort stable (meaning, we try not to break it). I probably should have included a Module suffix on the original, and then this could have gone without. How many usages in the wild can you find of this call?
(Nit: ForOp suffix makes me think it's related to ForOp).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmmm, I perform a simple check on public repos via GitHub search (https://github.com/search?q=mlirApplyPatternsAndFoldGreedily+&type=code) and there seems to be some use cases of this API so I'm not sure if we should change it or not (maybe leave to maintainers for this decision : ).
Currently I'm going to rename it to
mlirApplyPatternsAndFoldGreedilyWithOpto avoid such confusion withForOpas you mentioned : )There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you please extract this change in an independent PR? This is largely unrelated (your test python pass could just do another kind of rewrite right now).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Renamed in ca80408.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ahh got it.
I think I can do it via op/block/region/walk/erase APIs directly, but it maybe a little ugly since the normal rewriter API is not ported to python yet AFAIK?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep - no rewriter API in Python but you can use relevant methods on ops directly, etc. (Here's one example: https://github.com/libxsmm/tpp-mlir/pull/1064/files#diff-7aa62724b21b998da9cf032da6e6b77bbdb664258a5dca850f9509a5459646f6R110-R118 - @makslevental has more.)
That should be enough to demonstrate that your test pass is actually running and, if kept simple, will not be that ugly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can just use
apply_patterns_and_fold_greedilyhere and then addapply_patterns_and_fold_greedily_with_opin a follow-up (the existing test passes just fine withapply_patterns_and_fold_greedily).