Skip to content

Commit 28638ab

Browse files
committed
drop the setattr design
1 parent 2241e27 commit 28638ab

File tree

2 files changed

+12
-32
lines changed

2 files changed

+12
-32
lines changed

mlir/lib/Bindings/Python/Pass.cpp

Lines changed: 7 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,12 @@ class PyPassManager {
5656

5757
/// Create the `mlir.passmanager` here.
5858
void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
59-
constexpr const char *mlirExternalPassAttr = "__mlir_external_pass__";
59+
//----------------------------------------------------------------------------
60+
// Mapping of MlirExternalPass
61+
//----------------------------------------------------------------------------
62+
nb::class_<MlirExternalPass>(m, "ExternalPass")
63+
.def("signal_failure",
64+
[](MlirExternalPass pass) { mlirExternalPassSignalFailure(pass); });
6065

6166
//----------------------------------------------------------------------------
6267
// Mapping of the top-level PassManager
@@ -186,27 +191,8 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
186191
};
187192
callbacks.run = [](MlirOperation op, MlirExternalPass pass,
188193
void *userData) {
189-
auto callable =
190-
nb::borrow<nb::callable>(static_cast<PyObject *>(userData));
191-
nb::setattr(callable, mlirExternalPassAttr,
192-
nb::capsule(pass.ptr));
193-
callable(op);
194-
// delete it to avoid that it is used after
195-
// the external pass is freed by the pass manager
196-
nb::delattr(callable, mlirExternalPassAttr);
194+
nb::handle(static_cast<PyObject *>(userData))(op, pass);
197195
};
198-
nb::setattr(run, "signal_pass_failure", nb::cpp_function([run]() {
199-
nb::capsule cap;
200-
try {
201-
cap = run.attr(mlirExternalPassAttr);
202-
} catch (nb::python_error &e) {
203-
throw std::runtime_error(
204-
"signal_pass_failure() should always be called "
205-
"from the __call__ method");
206-
}
207-
mlirExternalPassSignalFailure(
208-
MlirExternalPass{cap.data()});
209-
}));
210196
auto externalPass = mlirCreateExternalPass(
211197
passID, mlirStringRefCreate(name->data(), name->length()),
212198
mlirStringRefCreate(argument.data(), argument.length()),

mlir/test/python/python_pass.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,12 @@ def testCustomPass():
6464
"""
6565
)
6666

67-
def custom_pass_1(op):
67+
def custom_pass_1(op, pass_):
6868
print("hello from pass 1!!!", file=sys.stderr)
6969

7070
class CustomPass2:
71-
def __call__(self, m):
72-
apply_patterns_and_fold_greedily(m, frozen)
71+
def __call__(self, op, pass_):
72+
apply_patterns_and_fold_greedily(op, frozen)
7373

7474
custom_pass_2 = CustomPass2()
7575

@@ -89,9 +89,9 @@ def __call__(self, m):
8989

9090
# test signal_pass_failure
9191
class CustomPassThatFails:
92-
def __call__(self, m):
92+
def __call__(self, op, pass_):
9393
print("hello from pass that fails")
94-
self.signal_pass_failure()
94+
pass_.signal_failure()
9595

9696
custom_pass_that_fails = CustomPassThatFails()
9797

@@ -103,9 +103,3 @@ def __call__(self, m):
103103
pm.run(module)
104104
except Exception as e:
105105
print(f"caught exception: {e}")
106-
107-
# CHECK: caught exception: signal_pass_failure() should always be called from the __call__ method
108-
try:
109-
custom_pass_that_fails.signal_pass_failure()
110-
except Exception as e:
111-
print(f"caught exception: {e}")

0 commit comments

Comments
 (0)