@@ -53,6 +53,8 @@ class PyPatternRewriter {
5353 mlirRewriterBaseReplaceOpWithValues (base, op, values.size (), values.data ());
5454 }
5555
56+ void eraseOp (MlirOperation op) { mlirRewriterBaseEraseOp (base, op); }
57+
5658private:
5759 MlirRewriterBase base;
5860 PyMlirContextRef ctx;
@@ -177,7 +179,10 @@ class PyRewritePatternSet {
177179public:
178180 PyRewritePatternSet (MlirContext ctx)
179181 : set(mlirRewritePatternSetCreate(ctx)), ctx(ctx) {}
180- ~PyRewritePatternSet () { mlirRewritePatternSetDestroy (set); }
182+ ~PyRewritePatternSet () {
183+ if (set.ptr )
184+ mlirRewritePatternSetDestroy (set);
185+ }
181186
182187 void add (MlirStringRef rootName, MlirPatternBenefit benefit,
183188 const nb::callable &matchAndRewrite) {
@@ -220,15 +225,37 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
220225 // ----------------------------------------------------------------------------
221226 // Mapping of the PatternRewriter
222227 // ----------------------------------------------------------------------------
223- nb::class_<PyPatternRewriter>(m, " PatternRewriter" )
224- .def_prop_ro (" ip" , &PyPatternRewriter::getInsertionPoint,
225- " The current insertion point of the PatternRewriter." )
226- .def (" replace_op" , [](PyPatternRewriter &self, MlirOperation op,
227- MlirOperation newOp) { self.replaceOp (op, newOp); })
228- .def (" replace_op" , [](PyPatternRewriter &self, MlirOperation op,
229- const std::vector<MlirValue> &values) {
230- self.replaceOp (op, values);
231- });
228+ nb::
229+ class_<PyPatternRewriter>(m, " PatternRewriter" )
230+ .def_prop_ro (" ip" , &PyPatternRewriter::getInsertionPoint,
231+ " The current insertion point of the PatternRewriter." )
232+ .def (
233+ " replace_op" ,
234+ [](PyPatternRewriter &self, MlirOperation op,
235+ MlirOperation newOp) { self.replaceOp (op, newOp); },
236+ " Replace an operation with a new operation." ,
237+ // clang-format off
238+ nb::sig (" def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME (" ir.Operation" )
239+ " , new_op: " MAKE_MLIR_PYTHON_QUALNAME (" ir.Operation" ) " ) -> None" )
240+ // clang-format on
241+ )
242+ .def (
243+ " replace_op" ,
244+ [](PyPatternRewriter &self, MlirOperation op,
245+ const std::vector<MlirValue> &values) {
246+ self.replaceOp (op, values);
247+ },
248+ " Replace an operation with a list of values." ,
249+ // clang-format off
250+ nb::sig (" def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME (" ir.Operation" )
251+ " , values: list[" MAKE_MLIR_PYTHON_QUALNAME (" ir.Value" ) " ]) -> None" )
252+ // clang-format on
253+ )
254+ .def (" erase_op" , &PyPatternRewriter::eraseOp, " Erase an operation." ,
255+ // clang-format off
256+ nb::sig (" def erase_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME (" ir.Operation" ) " ) -> None" )
257+ // clang-format on
258+ );
232259
233260 // ----------------------------------------------------------------------------
234261 // Mapping of the RewritePatternSet
@@ -250,8 +277,12 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
250277 self.add (mlirStringRefCreate (opName.data (), opName.size ()), benefit,
251278 fn);
252279 },
253- " root" _a, " fn" _a, " benefit" _a = 1 )
254- .def (" freeze" , &PyRewritePatternSet::freeze);
280+ " root" _a, " fn" _a, " benefit" _a = 1 ,
281+ " Add a new rewrite pattern on the given root operation with the "
282+ " callable as the matching and rewriting function and the given "
283+ " benefit." )
284+ .def (" freeze" , &PyRewritePatternSet::freeze,
285+ " Freeze the pattern set into a frozen one." );
255286
256287 // ----------------------------------------------------------------------------
257288 // Mapping of the PDLResultList and PDLModule
0 commit comments