Skip to content

Commit e565ffb

Browse files
Merge changes from llvm#157369
Co-authored-by: Maksim Levental <[email protected]>
1 parent 01e68c5 commit e565ffb

File tree

7 files changed

+57
-122
lines changed

7 files changed

+57
-122
lines changed

mlir/lib/Bindings/Python/MainModule.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,4 @@ NB_MODULE(_mlir, m) {
139139
auto passManagerModule =
140140
m.def_submodule("passmanager", "MLIR Pass Management Bindings");
141141
populatePassManagerSubmodule(passManagerModule);
142-
auto passesModule =
143-
m.def_submodule("passes", "MLIR Pass Infrastructure Bindings");
144-
populatePassSubmodule(passesModule);
145142
}

mlir/lib/Bindings/Python/Pass.cpp

Lines changed: 37 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
#include "mlir/Bindings/Python/Nanobind.h"
1515
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
1616
// clang-format on
17-
#include "nanobind/trampoline.h"
1817

1918
namespace nb = nanobind;
2019
using namespace nb::literals;
@@ -23,81 +22,6 @@ using namespace mlir::python;
2322

2423
namespace {
2524

26-
// A base class for defining passes in Python
27-
// Users are expected to subclass this and implement the `run` method, e.g.
28-
// ```
29-
// class MyPass(Pass):
30-
// def __init__(self):
31-
// super().__init__("MyPass", ..)
32-
// # other init stuff..
33-
// def run(self, operation):
34-
// # do something with operation..
35-
// pass
36-
// ```
37-
class PyPassBase {
38-
public:
39-
PyPassBase(std::string name, std::string argument, std::string description,
40-
std::string opName)
41-
: name(std::move(name)), argument(std::move(argument)),
42-
description(std::move(description)), opName(std::move(opName)) {
43-
callbacks.construct = [](void *obj) {};
44-
callbacks.destruct = [](void *obj) {
45-
nb::handle(static_cast<PyObject *>(obj)).dec_ref();
46-
};
47-
callbacks.run = [](MlirOperation op, MlirExternalPass, void *obj) {
48-
auto handle = nb::handle(static_cast<PyObject *>(obj));
49-
nb::cast<PyPassBase *>(handle)->run(op);
50-
};
51-
callbacks.clone = [](void *obj) -> void * {
52-
nb::object copy = nb::module_::import_("copy");
53-
nb::object deepcopy = copy.attr("deepcopy");
54-
return deepcopy(obj).release().ptr();
55-
};
56-
callbacks.initialize = nullptr;
57-
}
58-
59-
// this method should be overridden by subclasses in Python.
60-
virtual void run(MlirOperation op) = 0;
61-
62-
virtual ~PyPassBase() = default;
63-
64-
// Make an MlirPass instance on-the-fly that wraps this object.
65-
// Note that passmanager will take the ownership of the returned
66-
// object and release it when appropriate.
67-
MlirPass make() {
68-
auto *obj = nb::find(this).release().ptr();
69-
return mlirCreateExternalPass(
70-
mlirTypeIDCreate(this), mlirStringRefCreate(name.data(), name.length()),
71-
mlirStringRefCreate(argument.data(), argument.length()),
72-
mlirStringRefCreate(description.data(), description.length()),
73-
mlirStringRefCreate(opName.data(), opName.size()), 0, nullptr,
74-
callbacks, obj);
75-
}
76-
77-
const std::string &getName() const { return name; }
78-
const std::string &getArgument() const { return argument; }
79-
const std::string &getDescription() const { return description; }
80-
const std::string &getOpName() const { return opName; }
81-
82-
private:
83-
MlirExternalPassCallbacks callbacks;
84-
85-
std::string name;
86-
std::string argument;
87-
std::string description;
88-
std::string opName;
89-
};
90-
91-
// A trampoline class upon PyPassBase.
92-
// Refer to
93-
// https://nanobind.readthedocs.io/en/latest/classes.html#overriding-virtual-functions-in-python
94-
class PyPass : PyPassBase {
95-
public:
96-
NB_TRAMPOLINE(PyPassBase, 1);
97-
98-
void run(MlirOperation op) override { NB_OVERRIDE_PURE(run, op); }
99-
};
100-
10125
/// Owning Wrapper around a PassManager.
10226
class PyPassManager {
10327
public:
@@ -130,26 +54,6 @@ class PyPassManager {
13054

13155
} // namespace
13256

133-
void mlir::python::populatePassSubmodule(nanobind::module_ &m) {
134-
//----------------------------------------------------------------------------
135-
// Mapping of the Python-defined Pass interface
136-
//----------------------------------------------------------------------------
137-
nb::class_<PyPassBase, PyPass>(m, "Pass")
138-
.def(nb::init<std::string, std::string, std::string, std::string>(),
139-
"name"_a, nb::kw_only(), "argument"_a = "", "description"_a = "",
140-
"op_name"_a = "", "Create a new Pass.")
141-
.def("run", &PyPassBase::run, "operation"_a,
142-
"Run the pass on the provided operation.")
143-
.def_prop_ro("name",
144-
[](const PyPassBase &self) { return self.getName(); })
145-
.def_prop_ro("argument",
146-
[](const PyPassBase &self) { return self.getArgument(); })
147-
.def_prop_ro("description",
148-
[](const PyPassBase &self) { return self.getDescription(); })
149-
.def_prop_ro("op_name",
150-
[](const PyPassBase &self) { return self.getOpName(); });
151-
}
152-
15357
/// Create the `mlir.passmanager` here.
15458
void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
15559
//----------------------------------------------------------------------------
@@ -256,11 +160,44 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
256160
"Add textual pipeline elements to the pass manager. Throws a "
257161
"ValueError if the pipeline can't be parsed.")
258162
.def(
259-
"add",
260-
[](PyPassManager &passManager, PyPassBase &pass) {
261-
mlirPassManagerAddOwnedPass(passManager.get(), pass.make());
163+
"add_python_pass",
164+
[](PyPassManager &passManager, const nb::callable &run,
165+
std::optional<std::string> &name, const std::string &argument,
166+
const std::string &description, const std::string &opName) {
167+
if (!name.has_value()) {
168+
name = nb::cast<std::string>(
169+
nb::borrow<nb::str>(run.attr("__name__")));
170+
}
171+
MlirTypeIDAllocator typeIDAllocator = mlirTypeIDAllocatorCreate();
172+
MlirTypeID passID =
173+
mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
174+
MlirExternalPassCallbacks callbacks;
175+
callbacks.construct = [](void *obj) {
176+
(void)nb::handle(static_cast<PyObject *>(obj)).inc_ref();
177+
};
178+
callbacks.destruct = [](void *obj) {
179+
(void)nb::handle(static_cast<PyObject *>(obj)).dec_ref();
180+
};
181+
callbacks.initialize = nullptr;
182+
callbacks.clone = [](void *) -> void * {
183+
throw std::runtime_error("Cloning Python passes not supported");
184+
};
185+
callbacks.run = [](MlirOperation op, MlirExternalPass,
186+
void *userData) {
187+
nb::borrow<nb::callable>(static_cast<PyObject *>(userData))(op);
188+
};
189+
auto externalPass = mlirCreateExternalPass(
190+
passID, mlirStringRefCreate(name->data(), name->length()),
191+
mlirStringRefCreate(argument.data(), argument.length()),
192+
mlirStringRefCreate(description.data(), description.length()),
193+
mlirStringRefCreate(opName.data(), opName.size()),
194+
/*nDependentDialects*/ 0, /*dependentDialects*/ nullptr,
195+
callbacks, /*userData*/ run.ptr());
196+
mlirPassManagerAddOwnedPass(passManager.get(), externalPass);
262197
},
263-
"pass"_a, "Add a python-defined pass to the pass manager.")
198+
"run"_a, "name"_a.none() = nb::none(), "argument"_a.none() = "",
199+
"description"_a.none() = "", "op_name"_a.none() = "",
200+
"Add a python-defined pass to the pass manager.")
264201
.def(
265202
"run",
266203
[](PyPassManager &passManager, PyOperationBase &op,

mlir/lib/Bindings/Python/Pass.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ namespace mlir {
1515
namespace python {
1616

1717
void populatePassManagerSubmodule(nanobind::module_ &m);
18-
void populatePassSubmodule(nanobind::module_ &m);
1918

2019
} // namespace python
2120
} // namespace mlir

mlir/lib/CAPI/IR/Pass.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,10 +145,14 @@ class ExternalPass : public Pass {
145145
: Pass(passID, opName), id(passID), name(name), argument(argument),
146146
description(description), dependentDialects(dependentDialects),
147147
callbacks(callbacks), userData(userData) {
148-
callbacks.construct(userData);
148+
if (callbacks.construct)
149+
callbacks.construct(userData);
149150
}
150151

151-
~ExternalPass() override { callbacks.destruct(userData); }
152+
~ExternalPass() override {
153+
if (callbacks.destruct)
154+
callbacks.destruct(userData);
155+
}
152156

153157
StringRef getName() const override { return name; }
154158
StringRef getArgument() const override { return argument; }

mlir/python/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ declare_mlir_python_sources(MLIRPythonSources.Core.Python
2020
SOURCES
2121
_mlir_libs/__init__.py
2222
ir.py
23-
passes.py
2423
passmanager.py
2524
rewrite.py
2625
dialects/_ods_common.py

mlir/python/mlir/passes.py

Lines changed: 0 additions & 5 deletions
This file was deleted.

mlir/test/python/pass.py renamed to mlir/test/python/python_pass.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import gc, sys
44
from mlir.ir import *
55
from mlir.passmanager import *
6-
from mlir.passes import *
76
from mlir.dialects.builtin import ModuleOp
87
from mlir.dialects import pdl
98
from mlir.rewrite import *
@@ -54,13 +53,6 @@ def testCustomPass():
5453
pdl_module = make_pdl_module()
5554
frozen = PDLModule(pdl_module).freeze()
5655

57-
class CustomPass(Pass):
58-
def __init__(self):
59-
super().__init__("CustomPass", op_name="builtin.module")
60-
61-
def run(self, m):
62-
apply_patterns_and_fold_greedily_with_op(m, frozen)
63-
6456
module = ModuleOp.parse(
6557
r"""
6658
module {
@@ -72,12 +64,24 @@ def run(self, m):
7264
"""
7365
)
7466

67+
def custom_pass_1(op):
68+
print("hello from pass 1!!!", file=sys.stderr)
69+
70+
class CustomPass2:
71+
def __call__(self, m):
72+
apply_patterns_and_fold_greedily_with_op(m, frozen)
73+
74+
custom_pass_2 = CustomPass2()
75+
7576
pm = PassManager("any")
7677
pm.enable_ir_printing()
7778

78-
# CHECK-LABEL: Dump After CustomPass
79+
# CHECK: hello from pass 1!!!
80+
# CHECK-LABEL: Dump After custom_pass_1
81+
pm.add_python_pass(custom_pass_1)
82+
# CHECK-LABEL: Dump After CustomPass2
7983
# CHECK: arith.muli
80-
pm.add(CustomPass())
84+
pm.add_python_pass(custom_pass_2, "CustomPass2")
8185
# CHECK-LABEL: Dump After ArithToLLVMConversionPass
8286
# CHECK: llvm.mul
8387
pm.add("convert-arith-to-llvm")

0 commit comments

Comments
 (0)