Skip to content

Commit da4bb8b

Browse files
committed
add nb::sigs and python api docs
1 parent 0ddd081 commit da4bb8b

File tree

1 file changed

+43
-12
lines changed

1 file changed

+43
-12
lines changed

mlir/lib/Bindings/Python/Rewrite.cpp

Lines changed: 43 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ class PyPatternRewriter {
5353
mlirRewriterBaseReplaceOpWithValues(base, op, values.size(), values.data());
5454
}
5555

56+
void eraseOp(MlirOperation op) { mlirRewriterBaseEraseOp(base, op); }
57+
5658
private:
5759
MlirRewriterBase base;
5860
PyMlirContextRef ctx;
@@ -177,7 +179,10 @@ class PyRewritePatternSet {
177179
public:
178180
PyRewritePatternSet(MlirContext ctx)
179181
: set(mlirRewritePatternSetCreate(ctx)), ctx(ctx) {}
180-
~PyRewritePatternSet() { mlirRewritePatternSetDestroy(set); }
182+
~PyRewritePatternSet() {
183+
if (set.ptr)
184+
mlirRewritePatternSetDestroy(set);
185+
}
181186

182187
void add(MlirStringRef rootName, MlirPatternBenefit benefit,
183188
const nb::callable &matchAndRewrite) {
@@ -220,15 +225,37 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
220225
//----------------------------------------------------------------------------
221226
// Mapping of the PatternRewriter
222227
//----------------------------------------------------------------------------
223-
nb::class_<PyPatternRewriter>(m, "PatternRewriter")
224-
.def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint,
225-
"The current insertion point of the PatternRewriter.")
226-
.def("replace_op", [](PyPatternRewriter &self, MlirOperation op,
227-
MlirOperation newOp) { self.replaceOp(op, newOp); })
228-
.def("replace_op", [](PyPatternRewriter &self, MlirOperation op,
229-
const std::vector<MlirValue> &values) {
230-
self.replaceOp(op, values);
231-
});
228+
nb::
229+
class_<PyPatternRewriter>(m, "PatternRewriter")
230+
.def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint,
231+
"The current insertion point of the PatternRewriter.")
232+
.def(
233+
"replace_op",
234+
[](PyPatternRewriter &self, MlirOperation op,
235+
MlirOperation newOp) { self.replaceOp(op, newOp); },
236+
"Replace an operation with a new operation.",
237+
// clang-format off
238+
nb::sig("def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation")
239+
", new_op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ") -> None")
240+
// clang-format on
241+
)
242+
.def(
243+
"replace_op",
244+
[](PyPatternRewriter &self, MlirOperation op,
245+
const std::vector<MlirValue> &values) {
246+
self.replaceOp(op, values);
247+
},
248+
"Replace an operation with a list of values.",
249+
// clang-format off
250+
nb::sig("def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation")
251+
", values: list[" MAKE_MLIR_PYTHON_QUALNAME("ir.Value") "]) -> None")
252+
// clang-format on
253+
)
254+
.def("erase_op", &PyPatternRewriter::eraseOp, "Erase an operation.",
255+
// clang-format off
256+
nb::sig("def erase_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ") -> None")
257+
// clang-format on
258+
);
232259

233260
//----------------------------------------------------------------------------
234261
// Mapping of the RewritePatternSet
@@ -250,8 +277,12 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
250277
self.add(mlirStringRefCreate(opName.data(), opName.size()), benefit,
251278
fn);
252279
},
253-
"root"_a, "fn"_a, "benefit"_a = 1)
254-
.def("freeze", &PyRewritePatternSet::freeze);
280+
"root"_a, "fn"_a, "benefit"_a = 1,
281+
"Add a new rewrite pattern on the given root operation with the "
282+
"callable as the matching and rewriting function and the given "
283+
"benefit.")
284+
.def("freeze", &PyRewritePatternSet::freeze,
285+
"Freeze the pattern set into a frozen one.");
255286

256287
//----------------------------------------------------------------------------
257288
// Mapping of the PDLResultList and PDLModule

0 commit comments

Comments
 (0)