-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[MLIR][Python] Support Python-defined passes in MLIR #156000
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 17 commits
8386c87
cb82621
7556ca2
d5055a7
6d2f472
1a98ae8
751fe84
7966ddd
6a9ec66
2965a9e
1dde449
ca80408
c8c2fae
01e68c5
e565ffb
9f526c7
b6080ea
f72c83c
7a491af
46b833d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,8 +10,10 @@ | |
|
|
||
| #include "IRModule.h" | ||
| #include "mlir-c/Pass.h" | ||
| // clang-format off | ||
| #include "mlir/Bindings/Python/Nanobind.h" | ||
| #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. | ||
| // clang-format on | ||
|
|
||
| namespace nb = nanobind; | ||
| using namespace nb::literals; | ||
|
|
@@ -157,6 +159,45 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { | |
| "pipeline"_a, | ||
| "Add textual pipeline elements to the pass manager. Throws a " | ||
| "ValueError if the pipeline can't be parsed.") | ||
| .def( | ||
| "add_python_pass", | ||
| [](PyPassManager &passManager, const nb::callable &run, | ||
| std::optional<std::string> &name, const std::string &argument, | ||
| const std::string &description, const std::string &opName) { | ||
| if (!name.has_value()) { | ||
| name = nb::cast<std::string>( | ||
| nb::borrow<nb::str>(run.attr("__name__"))); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: what happens on lambdas?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. >>> x = lambda: "bob"
>>> x.__name__
'<lambda>'
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not the best thing ever but it doesn't blow up (and also I don't see anyone using a lambda here...). |
||
| } | ||
| MlirTypeIDAllocator typeIDAllocator = mlirTypeIDAllocatorCreate(); | ||
| MlirTypeID passID = | ||
| mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator); | ||
| MlirExternalPassCallbacks callbacks; | ||
| callbacks.construct = [](void *obj) { | ||
| (void)nb::handle(static_cast<PyObject *>(obj)).inc_ref(); | ||
| }; | ||
| callbacks.destruct = [](void *obj) { | ||
| (void)nb::handle(static_cast<PyObject *>(obj)).dec_ref(); | ||
| }; | ||
| callbacks.initialize = nullptr; | ||
| callbacks.clone = [](void *) -> void * { | ||
| throw std::runtime_error("Cloning Python passes not supported"); | ||
| }; | ||
| callbacks.run = [](MlirOperation op, MlirExternalPass, | ||
| void *userData) { | ||
| nb::borrow<nb::callable>(static_cast<PyObject *>(userData))(op); | ||
| }; | ||
| auto externalPass = mlirCreateExternalPass( | ||
| passID, mlirStringRefCreate(name->data(), name->length()), | ||
| mlirStringRefCreate(argument.data(), argument.length()), | ||
| mlirStringRefCreate(description.data(), description.length()), | ||
| mlirStringRefCreate(opName.data(), opName.size()), | ||
| /*nDependentDialects*/ 0, /*dependentDialects*/ nullptr, | ||
| callbacks, /*userData*/ run.ptr()); | ||
| mlirPassManagerAddOwnedPass(passManager.get(), externalPass); | ||
| }, | ||
| "run"_a, "name"_a.none() = nb::none(), "argument"_a.none() = "", | ||
| "description"_a.none() = "", "op_name"_a.none() = "", | ||
| "Add a python-defined pass to the pass manager.") | ||
| .def( | ||
| "run", | ||
| [](PyPassManager &passManager, PyOperationBase &op, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -99,14 +99,24 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) { | |
| .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, | ||
| &PyFrozenRewritePatternSet::createFromCapsule); | ||
| m.def( | ||
| "apply_patterns_and_fold_greedily", | ||
| [](MlirModule module, MlirFrozenRewritePatternSet set) { | ||
| auto status = mlirApplyPatternsAndFoldGreedily(module, set, {}); | ||
| if (mlirLogicalResultIsFailure(status)) | ||
| // FIXME: Not sure this is the right error to throw here. | ||
| throw nb::value_error("pattern application failed to converge"); | ||
| }, | ||
| "module"_a, "set"_a, | ||
| "Applys the given patterns to the given module greedily while folding " | ||
| "results."); | ||
| "apply_patterns_and_fold_greedily", | ||
| [](MlirModule module, MlirFrozenRewritePatternSet set) { | ||
| auto status = mlirApplyPatternsAndFoldGreedily(module, set, {}); | ||
| if (mlirLogicalResultIsFailure(status)) | ||
| throw std::runtime_error("pattern application failed to converge"); | ||
| }, | ||
| "module"_a, "set"_a, | ||
| "Applys the given patterns to the given module greedily while folding " | ||
| "results.") | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Some kind of explanation of how this works would help a lot. Doesn't have to be in this file, but I'm struggling to understand what this pass does from the description you've provided.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the pass itself or this method? the pass itself is from the original PR for exposing FrozenPatternRewriter but you can take a look at #157487 which just landed and slightly refactored. |
||
| .def( | ||
| "apply_patterns_and_fold_greedily_with_op", | ||
| [](MlirOperation op, MlirFrozenRewritePatternSet set) { | ||
| auto status = mlirApplyPatternsAndFoldGreedilyWithOp(op, set, {}); | ||
| if (mlirLogicalResultIsFailure(status)) | ||
| throw std::runtime_error( | ||
| "pattern application failed to converge"); | ||
| }, | ||
| "op"_a, "set"_a, | ||
| "Applys the given patterns to the given op greedily while folding " | ||
| "results."); | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,88 @@ | ||
| # RUN: %PYTHON %s 2>&1 | FileCheck %s | ||
|
|
||
| import gc, sys | ||
| from mlir.ir import * | ||
| from mlir.passmanager import * | ||
| from mlir.dialects.builtin import ModuleOp | ||
| from mlir.dialects import pdl | ||
| from mlir.rewrite import * | ||
|
|
||
|
|
||
| def log(*args): | ||
| print(*args, file=sys.stderr) | ||
| sys.stderr.flush() | ||
|
|
||
|
|
||
| def run(f): | ||
| log("\nTEST:", f.__name__) | ||
| f() | ||
| gc.collect() | ||
| assert Context._get_live_count() == 0 | ||
|
|
||
|
|
||
| def make_pdl_module(): | ||
| with Location.unknown(): | ||
| pdl_module = Module.create() | ||
| with InsertionPoint(pdl_module.body): | ||
| # Change all arith.addi with index types to arith.muli. | ||
| @pdl.pattern(benefit=1, sym_name="addi_to_mul") | ||
| def pat(): | ||
| # Match arith.addi with index types. | ||
| i64_type = pdl.TypeOp(IntegerType.get_signless(64)) | ||
| operand0 = pdl.OperandOp(i64_type) | ||
| operand1 = pdl.OperandOp(i64_type) | ||
| op0 = pdl.OperationOp( | ||
| name="arith.addi", args=[operand0, operand1], types=[i64_type] | ||
| ) | ||
|
|
||
| # Replace the matched op with arith.muli. | ||
| @pdl.rewrite() | ||
| def rew(): | ||
| newOp = pdl.OperationOp( | ||
| name="arith.muli", args=[operand0, operand1], types=[i64_type] | ||
| ) | ||
| pdl.ReplaceOp(op0, with_op=newOp) | ||
|
|
||
| return pdl_module | ||
|
|
||
|
|
||
| # CHECK-LABEL: TEST: testCustomPass | ||
| @run | ||
| def testCustomPass(): | ||
| with Context(): | ||
| pdl_module = make_pdl_module() | ||
| frozen = PDLModule(pdl_module).freeze() | ||
|
|
||
| module = ModuleOp.parse( | ||
| r""" | ||
| module { | ||
| func.func @add(%a: i64, %b: i64) -> i64 { | ||
| %sum = arith.addi %a, %b : i64 | ||
| return %sum : i64 | ||
| } | ||
| } | ||
| """ | ||
| ) | ||
|
|
||
| def custom_pass_1(op): | ||
| print("hello from pass 1!!!", file=sys.stderr) | ||
|
|
||
| class CustomPass2: | ||
| def __call__(self, m): | ||
| apply_patterns_and_fold_greedily_with_op(m, frozen) | ||
|
|
||
| custom_pass_2 = CustomPass2() | ||
|
|
||
| pm = PassManager("any") | ||
| pm.enable_ir_printing() | ||
|
|
||
| # CHECK: hello from pass 1!!! | ||
| # CHECK-LABEL: Dump After custom_pass_1 | ||
| pm.add_python_pass(custom_pass_1) | ||
| # CHECK-LABEL: Dump After CustomPass2 | ||
| # CHECK: arith.muli | ||
| pm.add_python_pass(custom_pass_2, "CustomPass2") | ||
| # CHECK-LABEL: Dump After ArithToLLVMConversionPass | ||
| # CHECK: llvm.mul | ||
| pm.add("convert-arith-to-llvm") | ||
| pm.run(module) |
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about just
"add_pass"(similar to C++-API)? Or even just reusing the current"add"and dispatching on arg types.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me! Done in 46b833d : )