From 4689bc244c266bdabb7c8416ff06f9face681c45 Mon Sep 17 00:00:00 2001 From: PragmaTwice Date: Fri, 10 Oct 2025 00:03:10 +0800 Subject: [PATCH 01/15] [MLIR][Python] Support Python-defined rewrite patterns --- mlir/include/mlir-c/Rewrite.h | 33 +++++++++++ mlir/lib/Bindings/Python/Rewrite.cpp | 81 +++++++++++++++++++++++++- mlir/lib/CAPI/Transforms/Rewrite.cpp | 87 +++++++++++++++++++++++++++- mlir/test/python/rewrite.py | 49 ++++++++++++++++ 4 files changed, 245 insertions(+), 5 deletions(-) create mode 100644 mlir/test/python/rewrite.py diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h index 5dd285ee076c4..68bb112404170 100644 --- a/mlir/include/mlir-c/Rewrite.h +++ b/mlir/include/mlir-c/Rewrite.h @@ -38,6 +38,7 @@ DEFINE_C_API_STRUCT(MlirFrozenRewritePatternSet, void); DEFINE_C_API_STRUCT(MlirGreedyRewriteDriverConfig, void); DEFINE_C_API_STRUCT(MlirRewritePatternSet, void); DEFINE_C_API_STRUCT(MlirPatternRewriter, void); +DEFINE_C_API_STRUCT(MlirRewritePattern, const void); //===----------------------------------------------------------------------===// /// RewriterBase API inherited from OpBuilder @@ -324,6 +325,38 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily( MLIR_CAPI_EXPORTED MlirRewriterBase mlirPatternRewriterAsBase(MlirPatternRewriter rewriter); +//===----------------------------------------------------------------------===// +/// RewritePattern API +//===----------------------------------------------------------------------===// + +typedef unsigned short MlirPatternBenefit; + +typedef struct { + void (*construct)(void *userData); + void (*destruct)(void *userData); + MlirLogicalResult (*matchAndRewrite)(MlirRewritePattern pattern, + MlirOperation op, + MlirPatternRewriter rewriter, + void *userData); +} MlirRewritePatternCallbacks; + +MLIR_CAPI_EXPORTED MlirRewritePattern mlirOpRewritePattenCreate( + MlirStringRef rootName, MlirPatternBenefit benefit, MlirContext context, + MlirRewritePatternCallbacks callbacks, void *userData, + size_t nGeneratedNames, MlirStringRef *generatedNames); + +//===----------------------------------------------------------------------===// +/// RewritePatternSet API +//===----------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED MlirRewritePatternSet +mlirRewritePatternSetCreate(MlirContext context); + +MLIR_CAPI_EXPORTED void mlirRewritePatternSetDestroy(MlirRewritePatternSet set); + +MLIR_CAPI_EXPORTED void mlirRewritePatternSetAdd(MlirRewritePatternSet set, + MlirRewritePattern pattern); + //===----------------------------------------------------------------------===// /// PDLPatternModule API //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp index 9e3d9703c82e8..3740c59e62001 100644 --- a/mlir/lib/Bindings/Python/Rewrite.cpp +++ b/mlir/lib/Bindings/Python/Rewrite.cpp @@ -45,6 +45,14 @@ class PyPatternRewriter { return PyInsertionPoint(PyOperation::forOperation(ctx, op)); } + void replaceOp(MlirOperation op, MlirOperation newOp) { + mlirRewriterBaseReplaceOpWithOperation(base, op, newOp); + } + + void replaceOp(MlirOperation op, const std::vector &values) { + mlirRewriterBaseReplaceOpWithValues(base, op, values.size(), values.data()); + } + private: MlirRewriterBase base; PyMlirContextRef ctx; @@ -165,13 +173,82 @@ class PyFrozenRewritePatternSet { MlirFrozenRewritePatternSet set; }; +class PyRewritePatternSet { +public: + PyRewritePatternSet(MlirContext ctx) + : set(mlirRewritePatternSetCreate(ctx)), ctx(ctx) {} + ~PyRewritePatternSet() { mlirRewritePatternSetDestroy(set); } + + void add(MlirStringRef rootName, MlirPatternBenefit benefit, + const nb::callable &matchAndRewrite) { + MlirRewritePatternCallbacks callbacks; + callbacks.construct = [](void *userData) { + nb::handle(static_cast(userData)).inc_ref(); + }; + callbacks.destruct = [](void *userData) { + nb::handle(static_cast(userData)).dec_ref(); + }; + callbacks.matchAndRewrite = [](MlirRewritePattern pattern, MlirOperation op, + MlirPatternRewriter rewriter, + void *userData) -> MlirLogicalResult { + nb::handle f(static_cast(userData)); + nb::object res = f(op, PyPatternRewriter(rewriter), pattern); + return logicalResultFromObject(res); + }; + MlirRewritePattern pattern = mlirOpRewritePattenCreate( + rootName, benefit, ctx, callbacks, matchAndRewrite.ptr(), + /* nGeneratedNames */ 0, + /* generatedNames */ nullptr); + mlirRewritePatternSetAdd(set, pattern); + } + + PyFrozenRewritePatternSet freeze() { return mlirFreezeRewritePattern(set); } + +private: + MlirRewritePatternSet set; + MlirContext ctx; +}; + } // namespace /// Create the `mlir.rewrite` here. void mlir::python::populateRewriteSubmodule(nb::module_ &m) { + //---------------------------------------------------------------------------- + // Mapping of the PatternRewriter + //---------------------------------------------------------------------------- nb::class_(m, "PatternRewriter") .def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint, - "The current insertion point of the PatternRewriter."); + "The current insertion point of the PatternRewriter.") + .def("replace_op", [](PyPatternRewriter &self, MlirOperation op, + MlirOperation newOp) { self.replaceOp(op, newOp); }) + .def("replace_op", [](PyPatternRewriter &self, MlirOperation op, + const std::vector &values) { + self.replaceOp(op, values); + }); + + //---------------------------------------------------------------------------- + // Mapping of the RewritePatternSet + //---------------------------------------------------------------------------- + nb::class_(m, "RewritePattern"); + nb::class_(m, "RewritePatternSet") + .def( + "__init__", + [](PyRewritePatternSet &self, DefaultingPyMlirContext context) { + new (&self) PyRewritePatternSet(context.get()->get()); + }, + "context"_a = nb::none()) + .def( + "add", + [](PyRewritePatternSet &self, nb::handle root, const nb::callable &fn, + unsigned benefit) { + std::string opName = + nb::cast(root.attr("OPERATION_NAME")); + self.add(mlirStringRefCreate(opName.data(), opName.size()), benefit, + fn); + }, + "root"_a, "fn"_a, "benefit"_a = 1) + .def("freeze", &PyRewritePatternSet::freeze); + //---------------------------------------------------------------------------- // Mapping of the PDLResultList and PDLModule //---------------------------------------------------------------------------- @@ -237,7 +314,7 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) { .def( "freeze", [](PyPDLPatternModule &self) { - return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern( + return PyFrozenRewritePatternSet(mlirFreezeRewritePattern( mlirRewritePatternSetFromPDLPatternModule(self.get()))); }, nb::keep_alive<0, 1>()) diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp index c15a73b991f5d..f3430e2e78978 100644 --- a/mlir/lib/CAPI/Transforms/Rewrite.cpp +++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp @@ -17,6 +17,7 @@ #include "mlir/IR/PDLPatternMatch.h.inc" #include "mlir/IR/PatternMatch.h" #include "mlir/Rewrite/FrozenRewritePatternSet.h" +#include "mlir/Support/LLVM.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; @@ -270,9 +271,9 @@ void mlirIRRewriterDestroy(MlirRewriterBase rewriter) { /// RewritePatternSet and FrozenRewritePatternSet API //===----------------------------------------------------------------------===// -static inline mlir::RewritePatternSet &unwrap(MlirRewritePatternSet module) { +static inline mlir::RewritePatternSet *unwrap(MlirRewritePatternSet module) { assert(module.ptr && "unexpected null module"); - return *(static_cast(module.ptr)); + return static_cast(module.ptr); } static inline MlirRewritePatternSet wrap(mlir::RewritePatternSet *module) { @@ -291,7 +292,7 @@ wrap(mlir::FrozenRewritePatternSet *module) { } MlirFrozenRewritePatternSet mlirFreezeRewritePattern(MlirRewritePatternSet op) { - auto *m = new mlir::FrozenRewritePatternSet(std::move(unwrap(op))); + auto *m = new mlir::FrozenRewritePatternSet(std::move(*unwrap(op))); op.ptr = nullptr; return wrap(m); } @@ -332,6 +333,86 @@ MlirRewriterBase mlirPatternRewriterAsBase(MlirPatternRewriter rewriter) { return wrap(static_cast(unwrap(rewriter))); } +//===----------------------------------------------------------------------===// +/// RewritePattern API +//===----------------------------------------------------------------------===// + +inline const mlir::RewritePattern *unwrap(MlirRewritePattern pattern) { + assert(pattern.ptr && "unexpected null pattern"); + return static_cast(pattern.ptr); +} + +inline MlirRewritePattern wrap(const mlir::RewritePattern *pattern) { + return {pattern}; +} + +namespace mlir { + +class ExternalRewritePattern : public mlir::RewritePattern { +public: + ExternalRewritePattern(MlirRewritePatternCallbacks callbacks, void *userData, + StringRef rootName, PatternBenefit benefit, + MLIRContext *context, + ArrayRef generatedNames) + : RewritePattern(rootName, benefit, context, generatedNames), + callbacks(callbacks), userData(userData) { + if (callbacks.construct) + callbacks.construct(userData); + } + + ~ExternalRewritePattern() { + if (callbacks.destruct) + callbacks.destruct(userData); + } + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + return unwrap(callbacks.matchAndRewrite( + wrap(static_cast(this)), wrap(op), + wrap(&rewriter), userData)); + } + +private: + MlirRewritePatternCallbacks callbacks; + void *userData; +}; + +} // namespace mlir + +MlirRewritePattern mlirOpRewritePattenCreate( + MlirStringRef rootName, MlirPatternBenefit benefit, MlirContext context, + MlirRewritePatternCallbacks callbacks, void *userData, + size_t nGeneratedNames, MlirStringRef *generatedNames) { + std::vector generatedNamesVec; + generatedNamesVec.reserve(nGeneratedNames); + for (size_t i = 0; i < nGeneratedNames; ++i) { + generatedNamesVec.push_back(unwrap(generatedNames[i])); + } + return wrap(new mlir::ExternalRewritePattern( + callbacks, userData, unwrap(rootName), PatternBenefit(benefit), + unwrap(context), generatedNamesVec)); +} + +//===----------------------------------------------------------------------===// +/// RewritePatternSet API +//===----------------------------------------------------------------------===// + +MlirRewritePatternSet mlirRewritePatternSetCreate(MlirContext context) { + return wrap(new mlir::RewritePatternSet(unwrap(context))); +} + +void mlirRewritePatternSetDestroy(MlirRewritePatternSet set) { + delete unwrap(set); +} + +void mlirRewritePatternSetAdd(MlirRewritePatternSet set, + MlirRewritePattern pattern) { + std::unique_ptr patternPtr( + const_cast(unwrap(pattern))); + pattern.ptr = nullptr; + unwrap(set)->add(std::move(patternPtr)); +} + //===----------------------------------------------------------------------===// /// PDLPatternModule API //===----------------------------------------------------------------------===// diff --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py new file mode 100644 index 0000000000000..6aed936f94d87 --- /dev/null +++ b/mlir/test/python/rewrite.py @@ -0,0 +1,49 @@ +# RUN: %PYTHON %s 2>&1 | FileCheck %s + +import gc, sys +from mlir.ir import * +from mlir.passmanager import * +from mlir.dialects.builtin import ModuleOp +from mlir.dialects import arith +from mlir.rewrite import * + + +def log(*args): + print(*args, file=sys.stderr) + sys.stderr.flush() + + +def run(f): + log("\nTEST:", f.__name__) + f() + gc.collect() + assert Context._get_live_count() == 0 + +# CHECK-LABEL: TEST: testRewritePattern +@run +def testRewritePattern(): + def to_muli(op, rewriter, pattern): + with rewriter.ip: + new_op = arith.muli(op.operands[0], op.operands[1], loc=op.location) + rewriter.replace_op(op, new_op.owner) + + with Context(): + patterns = RewritePatternSet() + patterns.add(arith.AddIOp, to_muli) + frozen = patterns.freeze() + + module = ModuleOp.parse( + r""" + module { + func.func @add(%a: i64, %b: i64) -> i64 { + %sum = arith.addi %a, %b : i64 + return %sum : i64 + } + } + """ + ) + + apply_patterns_and_fold_greedily(module, frozen) + # CHECK: %0 = arith.muli %arg0, %arg1 : i64 + # CHECK: return %0 : i64 + print(module) From 61b87af618652e26f71400d7b238f1597a2ca364 Mon Sep 17 00:00:00 2001 From: PragmaTwice Date: Fri, 10 Oct 2025 00:56:46 +0800 Subject: [PATCH 02/15] format --- mlir/test/python/rewrite.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py index 6aed936f94d87..c7b6c1f19991e 100644 --- a/mlir/test/python/rewrite.py +++ b/mlir/test/python/rewrite.py @@ -19,6 +19,7 @@ def run(f): gc.collect() assert Context._get_live_count() == 0 + # CHECK-LABEL: TEST: testRewritePattern @run def testRewritePattern(): From 395627f8987187ca8f45dcefe6c9167b69a3f7d8 Mon Sep 17 00:00:00 2001 From: PragmaTwice Date: Fri, 10 Oct 2025 10:35:15 +0800 Subject: [PATCH 03/15] add docs for C API --- mlir/include/mlir-c/Rewrite.h | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h index 68bb112404170..cc021bcfba889 100644 --- a/mlir/include/mlir-c/Rewrite.h +++ b/mlir/include/mlir-c/Rewrite.h @@ -329,17 +329,28 @@ mlirPatternRewriterAsBase(MlirPatternRewriter rewriter); /// RewritePattern API //===----------------------------------------------------------------------===// +/// PatternBenefit represents the benefit of a pattern match. typedef unsigned short MlirPatternBenefit; +/// Callbacks to construct a rewrite pattern. typedef struct { + /// Optional constructor for the user data. + /// Set to nullptr to disable it. void (*construct)(void *userData); + /// Optional destructor for the user data. + /// Set to nullptr to disable it. void (*destruct)(void *userData); + /// The callback function to match against code rooted at the specified + /// operation, and perform the rewrite if the match is successful, + /// corresponding to RewritePattern::matchAndRewrite. MlirLogicalResult (*matchAndRewrite)(MlirRewritePattern pattern, MlirOperation op, MlirPatternRewriter rewriter, void *userData); } MlirRewritePatternCallbacks; +/// Create a rewrite pattern that matches the operation +/// with the given rootName, corresponding to mlir::OpRewritePattern. MLIR_CAPI_EXPORTED MlirRewritePattern mlirOpRewritePattenCreate( MlirStringRef rootName, MlirPatternBenefit benefit, MlirContext context, MlirRewritePatternCallbacks callbacks, void *userData, @@ -349,11 +360,14 @@ MLIR_CAPI_EXPORTED MlirRewritePattern mlirOpRewritePattenCreate( /// RewritePatternSet API //===----------------------------------------------------------------------===// +/// Create an empty MlirRewritePatternSet. MLIR_CAPI_EXPORTED MlirRewritePatternSet mlirRewritePatternSetCreate(MlirContext context); +/// Destruct the given MlirRewritePatternSet. MLIR_CAPI_EXPORTED void mlirRewritePatternSetDestroy(MlirRewritePatternSet set); +/// Add the given MlirRewritePattern into a MlirRewritePatternSet. MLIR_CAPI_EXPORTED void mlirRewritePatternSetAdd(MlirRewritePatternSet set, MlirRewritePattern pattern); From 0ddd081a3eb27348c7b87058edcf8eb437c796a0 Mon Sep 17 00:00:00 2001 From: PragmaTwice Date: Fri, 10 Oct 2025 10:45:09 +0800 Subject: [PATCH 04/15] add more docs and fix some name --- mlir/include/mlir-c/Rewrite.h | 10 ++++++++-- mlir/lib/Bindings/Python/Rewrite.cpp | 6 +++++- mlir/lib/CAPI/Transforms/Rewrite.cpp | 13 +++++++------ 3 files changed, 20 insertions(+), 9 deletions(-) diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h index cc021bcfba889..66a9a5de1669d 100644 --- a/mlir/include/mlir-c/Rewrite.h +++ b/mlir/include/mlir-c/Rewrite.h @@ -303,11 +303,15 @@ MLIR_CAPI_EXPORTED void mlirIRRewriterDestroy(MlirRewriterBase rewriter); /// FrozenRewritePatternSet API //===----------------------------------------------------------------------===// +/// Freeze the given MlirRewritePatternSet to a MlirFrozenRewritePatternSet. +/// Note that the ownership of the input set is transferred into the frozen set +/// after this call. MLIR_CAPI_EXPORTED MlirFrozenRewritePatternSet -mlirFreezeRewritePattern(MlirRewritePatternSet op); +mlirFreezeRewritePattern(MlirRewritePatternSet set); +/// Destroy the given MlirFrozenRewritePatternSet. MLIR_CAPI_EXPORTED void -mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op); +mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet set); MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedilyWithOp( MlirOperation op, MlirFrozenRewritePatternSet patterns, @@ -368,6 +372,8 @@ mlirRewritePatternSetCreate(MlirContext context); MLIR_CAPI_EXPORTED void mlirRewritePatternSetDestroy(MlirRewritePatternSet set); /// Add the given MlirRewritePattern into a MlirRewritePatternSet. +/// Note that the ownership of the pattern is transferred to the set after this +/// call. MLIR_CAPI_EXPORTED void mlirRewritePatternSetAdd(MlirRewritePatternSet set, MlirRewritePattern pattern); diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp index 3740c59e62001..9c99c6a4366b5 100644 --- a/mlir/lib/Bindings/Python/Rewrite.cpp +++ b/mlir/lib/Bindings/Python/Rewrite.cpp @@ -202,7 +202,11 @@ class PyRewritePatternSet { mlirRewritePatternSetAdd(set, pattern); } - PyFrozenRewritePatternSet freeze() { return mlirFreezeRewritePattern(set); } + PyFrozenRewritePatternSet freeze() { + MlirRewritePatternSet s = set; + set.ptr = nullptr; + return mlirFreezeRewritePattern(s); + } private: MlirRewritePatternSet set; diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp index f3430e2e78978..7e7a4f7715bb4 100644 --- a/mlir/lib/CAPI/Transforms/Rewrite.cpp +++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp @@ -291,15 +291,16 @@ wrap(mlir::FrozenRewritePatternSet *module) { return {module}; } -MlirFrozenRewritePatternSet mlirFreezeRewritePattern(MlirRewritePatternSet op) { - auto *m = new mlir::FrozenRewritePatternSet(std::move(*unwrap(op))); - op.ptr = nullptr; +MlirFrozenRewritePatternSet +mlirFreezeRewritePattern(MlirRewritePatternSet set) { + auto *m = new mlir::FrozenRewritePatternSet(std::move(*unwrap(set))); + set.ptr = nullptr; return wrap(m); } -void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op) { - delete unwrap(op); - op.ptr = nullptr; +void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet set) { + delete unwrap(set); + set.ptr = nullptr; } MlirLogicalResult From da4bb8b560b3bc49d5064281ef407c618d24787c Mon Sep 17 00:00:00 2001 From: PragmaTwice Date: Fri, 10 Oct 2025 11:28:45 +0800 Subject: [PATCH 05/15] add nb::sigs and python api docs --- mlir/lib/Bindings/Python/Rewrite.cpp | 55 ++++++++++++++++++++++------ 1 file changed, 43 insertions(+), 12 deletions(-) diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp index 9c99c6a4366b5..07559457f2f2f 100644 --- a/mlir/lib/Bindings/Python/Rewrite.cpp +++ b/mlir/lib/Bindings/Python/Rewrite.cpp @@ -53,6 +53,8 @@ class PyPatternRewriter { mlirRewriterBaseReplaceOpWithValues(base, op, values.size(), values.data()); } + void eraseOp(MlirOperation op) { mlirRewriterBaseEraseOp(base, op); } + private: MlirRewriterBase base; PyMlirContextRef ctx; @@ -177,7 +179,10 @@ class PyRewritePatternSet { public: PyRewritePatternSet(MlirContext ctx) : set(mlirRewritePatternSetCreate(ctx)), ctx(ctx) {} - ~PyRewritePatternSet() { mlirRewritePatternSetDestroy(set); } + ~PyRewritePatternSet() { + if (set.ptr) + mlirRewritePatternSetDestroy(set); + } void add(MlirStringRef rootName, MlirPatternBenefit benefit, const nb::callable &matchAndRewrite) { @@ -220,15 +225,37 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) { //---------------------------------------------------------------------------- // Mapping of the PatternRewriter //---------------------------------------------------------------------------- - nb::class_(m, "PatternRewriter") - .def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint, - "The current insertion point of the PatternRewriter.") - .def("replace_op", [](PyPatternRewriter &self, MlirOperation op, - MlirOperation newOp) { self.replaceOp(op, newOp); }) - .def("replace_op", [](PyPatternRewriter &self, MlirOperation op, - const std::vector &values) { - self.replaceOp(op, values); - }); + nb:: + class_(m, "PatternRewriter") + .def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint, + "The current insertion point of the PatternRewriter.") + .def( + "replace_op", + [](PyPatternRewriter &self, MlirOperation op, + MlirOperation newOp) { self.replaceOp(op, newOp); }, + "Replace an operation with a new operation.", + // clang-format off + nb::sig("def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") + ", new_op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ") -> None") + // clang-format on + ) + .def( + "replace_op", + [](PyPatternRewriter &self, MlirOperation op, + const std::vector &values) { + self.replaceOp(op, values); + }, + "Replace an operation with a list of values.", + // clang-format off + nb::sig("def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") + ", values: list[" MAKE_MLIR_PYTHON_QUALNAME("ir.Value") "]) -> None") + // clang-format on + ) + .def("erase_op", &PyPatternRewriter::eraseOp, "Erase an operation.", + // clang-format off + nb::sig("def erase_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ") -> None") + // clang-format on + ); //---------------------------------------------------------------------------- // Mapping of the RewritePatternSet @@ -250,8 +277,12 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) { self.add(mlirStringRefCreate(opName.data(), opName.size()), benefit, fn); }, - "root"_a, "fn"_a, "benefit"_a = 1) - .def("freeze", &PyRewritePatternSet::freeze); + "root"_a, "fn"_a, "benefit"_a = 1, + "Add a new rewrite pattern on the given root operation with the " + "callable as the matching and rewriting function and the given " + "benefit.") + .def("freeze", &PyRewritePatternSet::freeze, + "Freeze the pattern set into a frozen one."); //---------------------------------------------------------------------------- // Mapping of the PDLResultList and PDLModule From 5333a6ef08a3286b83494089b568f0f4087f77a7 Mon Sep 17 00:00:00 2001 From: PragmaTwice Date: Fri, 10 Oct 2025 11:42:21 +0800 Subject: [PATCH 06/15] add more examples --- mlir/lib/CAPI/Transforms/Rewrite.cpp | 1 - mlir/test/python/rewrite.py | 27 +++++++++++++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp index 7e7a4f7715bb4..d7c8e53f2bba6 100644 --- a/mlir/lib/CAPI/Transforms/Rewrite.cpp +++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp @@ -17,7 +17,6 @@ #include "mlir/IR/PDLPatternMatch.h.inc" #include "mlir/IR/PatternMatch.h" #include "mlir/Rewrite/FrozenRewritePatternSet.h" -#include "mlir/Support/LLVM.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; diff --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py index c7b6c1f19991e..cbc3a4043f96c 100644 --- a/mlir/test/python/rewrite.py +++ b/mlir/test/python/rewrite.py @@ -28,9 +28,18 @@ def to_muli(op, rewriter, pattern): new_op = arith.muli(op.operands[0], op.operands[1], loc=op.location) rewriter.replace_op(op, new_op.owner) + def constant_1_to_2(op, rewriter, pattern): + c = op.attributes["value"].value + if c != 1: + return True # failed to match + with rewriter.ip: + new_op = arith.constant(op.result.type, 2, loc=op.location) + rewriter.replace_op(op, [new_op]) + with Context(): patterns = RewritePatternSet() patterns.add(arith.AddIOp, to_muli) + patterns.add(arith.ConstantOp, constant_1_to_2) frozen = patterns.freeze() module = ModuleOp.parse( @@ -48,3 +57,21 @@ def to_muli(op, rewriter, pattern): # CHECK: %0 = arith.muli %arg0, %arg1 : i64 # CHECK: return %0 : i64 print(module) + + module = ModuleOp.parse( + r""" + module { + func.func @const() -> (i64, i64) { + %0 = arith.constant 1 : i64 + %1 = arith.constant 3 : i64 + return %0, %1 : i64, i64 + } + } + """ + ) + + apply_patterns_and_fold_greedily(module, frozen) + # CHECK: %c2_i64 = arith.constant 2 : i64 + # CHECK: %c3_i64 = arith.constant 3 : i64 + # CHECK: return %c2_i64, %c3_i64 : i64, i64 + print(module) From a57961fc66c12529e957086869e008e835b70a54 Mon Sep 17 00:00:00 2001 From: PragmaTwice Date: Fri, 10 Oct 2025 11:47:17 +0800 Subject: [PATCH 07/15] fix format --- mlir/test/python/rewrite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py index cbc3a4043f96c..4537068a5b9d5 100644 --- a/mlir/test/python/rewrite.py +++ b/mlir/test/python/rewrite.py @@ -31,7 +31,7 @@ def to_muli(op, rewriter, pattern): def constant_1_to_2(op, rewriter, pattern): c = op.attributes["value"].value if c != 1: - return True # failed to match + return True # failed to match with rewriter.ip: new_op = arith.constant(op.result.type, 2, loc=op.location) rewriter.replace_op(op, [new_op]) From 43da9a2cbe6d074fa863e6d500b18fc1d0a62894 Mon Sep 17 00:00:00 2001 From: Twice Date: Fri, 10 Oct 2025 12:59:58 +0800 Subject: [PATCH 08/15] Update mlir/lib/Bindings/Python/Rewrite.cpp Co-authored-by: Maksim Levental --- mlir/lib/Bindings/Python/Rewrite.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp index 07559457f2f2f..c938360756f03 100644 --- a/mlir/lib/Bindings/Python/Rewrite.cpp +++ b/mlir/lib/Bindings/Python/Rewrite.cpp @@ -250,7 +250,7 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) { nb::sig("def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ", values: list[" MAKE_MLIR_PYTHON_QUALNAME("ir.Value") "]) -> None") // clang-format on - ) + nb::arg("op"), nb::arg("values")) .def("erase_op", &PyPatternRewriter::eraseOp, "Erase an operation.", // clang-format off nb::sig("def erase_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ") -> None") From 75c2dd90ae36c92ef184dda1a27150f5ace66aaf Mon Sep 17 00:00:00 2001 From: PragmaTwice Date: Fri, 10 Oct 2025 13:06:55 +0800 Subject: [PATCH 09/15] reformat nb::sigs and add nb::args --- mlir/lib/Bindings/Python/Rewrite.cpp | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp index c938360756f03..078593955bf9c 100644 --- a/mlir/lib/Bindings/Python/Rewrite.cpp +++ b/mlir/lib/Bindings/Python/Rewrite.cpp @@ -233,10 +233,10 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) { "replace_op", [](PyPatternRewriter &self, MlirOperation op, MlirOperation newOp) { self.replaceOp(op, newOp); }, - "Replace an operation with a new operation.", + "Replace an operation with a new operation.", nb::arg("op"), + nb::arg("new_op"), // clang-format off - nb::sig("def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") - ", new_op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ") -> None") + nb::sig("def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ", new_op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ") -> None") // clang-format on ) .def( @@ -245,13 +245,14 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) { const std::vector &values) { self.replaceOp(op, values); }, - "Replace an operation with a list of values.", + "Replace an operation with a list of values.", nb::arg("op"), + nb::arg("values"), // clang-format off - nb::sig("def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") - ", values: list[" MAKE_MLIR_PYTHON_QUALNAME("ir.Value") "]) -> None") + nb::sig("def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ", values: list[" MAKE_MLIR_PYTHON_QUALNAME("ir.Value") "]) -> None") // clang-format on - nb::arg("op"), nb::arg("values")) + ) .def("erase_op", &PyPatternRewriter::eraseOp, "Erase an operation.", + nb::arg("op"), // clang-format off nb::sig("def erase_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ") -> None") // clang-format on From 64d98e42960545330ee4842f8d81d12664f12784 Mon Sep 17 00:00:00 2001 From: PragmaTwice Date: Fri, 10 Oct 2025 13:09:16 +0800 Subject: [PATCH 10/15] remove log() --- mlir/test/python/rewrite.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py index 4537068a5b9d5..546a4fb720a98 100644 --- a/mlir/test/python/rewrite.py +++ b/mlir/test/python/rewrite.py @@ -8,13 +8,8 @@ from mlir.rewrite import * -def log(*args): - print(*args, file=sys.stderr) - sys.stderr.flush() - - def run(f): - log("\nTEST:", f.__name__) + print("\nTEST:", f.__name__) f() gc.collect() assert Context._get_live_count() == 0 From 528f0e33e6ad4b31a49930b3089530fc181c813f Mon Sep 17 00:00:00 2001 From: PragmaTwice Date: Fri, 10 Oct 2025 14:22:33 +0800 Subject: [PATCH 11/15] remove patternbenefit typedef --- mlir/include/mlir-c/Rewrite.h | 5 +---- mlir/lib/Bindings/Python/Rewrite.cpp | 2 +- mlir/lib/CAPI/Transforms/Rewrite.cpp | 2 +- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h index 66a9a5de1669d..2db1d84cd1d89 100644 --- a/mlir/include/mlir-c/Rewrite.h +++ b/mlir/include/mlir-c/Rewrite.h @@ -333,9 +333,6 @@ mlirPatternRewriterAsBase(MlirPatternRewriter rewriter); /// RewritePattern API //===----------------------------------------------------------------------===// -/// PatternBenefit represents the benefit of a pattern match. -typedef unsigned short MlirPatternBenefit; - /// Callbacks to construct a rewrite pattern. typedef struct { /// Optional constructor for the user data. @@ -356,7 +353,7 @@ typedef struct { /// Create a rewrite pattern that matches the operation /// with the given rootName, corresponding to mlir::OpRewritePattern. MLIR_CAPI_EXPORTED MlirRewritePattern mlirOpRewritePattenCreate( - MlirStringRef rootName, MlirPatternBenefit benefit, MlirContext context, + MlirStringRef rootName, unsigned benefit, MlirContext context, MlirRewritePatternCallbacks callbacks, void *userData, size_t nGeneratedNames, MlirStringRef *generatedNames); diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp index 078593955bf9c..1d44da1cd94dd 100644 --- a/mlir/lib/Bindings/Python/Rewrite.cpp +++ b/mlir/lib/Bindings/Python/Rewrite.cpp @@ -184,7 +184,7 @@ class PyRewritePatternSet { mlirRewritePatternSetDestroy(set); } - void add(MlirStringRef rootName, MlirPatternBenefit benefit, + void add(MlirStringRef rootName, unsigned benefit, const nb::callable &matchAndRewrite) { MlirRewritePatternCallbacks callbacks; callbacks.construct = [](void *userData) { diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp index d7c8e53f2bba6..5d79a66d1d033 100644 --- a/mlir/lib/CAPI/Transforms/Rewrite.cpp +++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp @@ -380,7 +380,7 @@ class ExternalRewritePattern : public mlir::RewritePattern { } // namespace mlir MlirRewritePattern mlirOpRewritePattenCreate( - MlirStringRef rootName, MlirPatternBenefit benefit, MlirContext context, + MlirStringRef rootName, unsigned benefit, MlirContext context, MlirRewritePatternCallbacks callbacks, void *userData, size_t nGeneratedNames, MlirStringRef *generatedNames) { std::vector generatedNamesVec; From 64c9d307797696d4fcac3896c707e64ca34940b0 Mon Sep 17 00:00:00 2001 From: Twice Date: Sat, 11 Oct 2025 10:01:05 +0800 Subject: [PATCH 12/15] Update mlir/lib/CAPI/Transforms/Rewrite.cpp Co-authored-by: Maksim Levental --- mlir/lib/CAPI/Transforms/Rewrite.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp index 5d79a66d1d033..1947e7ee9d7fa 100644 --- a/mlir/lib/CAPI/Transforms/Rewrite.cpp +++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp @@ -338,7 +338,7 @@ MlirRewriterBase mlirPatternRewriterAsBase(MlirPatternRewriter rewriter) { //===----------------------------------------------------------------------===// inline const mlir::RewritePattern *unwrap(MlirRewritePattern pattern) { - assert(pattern.ptr && "unexpected null pattern"); + assert(pattern.ptr && "expected non-null pattern"); return static_cast(pattern.ptr); } From 14d1fb49c7bc6628aa93315043efcc362b15a052 Mon Sep 17 00:00:00 2001 From: PragmaTwice Date: Sat, 11 Oct 2025 10:05:47 +0800 Subject: [PATCH 13/15] remove context count check --- mlir/test/python/rewrite.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py index 546a4fb720a98..c904c8c5c65bb 100644 --- a/mlir/test/python/rewrite.py +++ b/mlir/test/python/rewrite.py @@ -1,6 +1,5 @@ # RUN: %PYTHON %s 2>&1 | FileCheck %s -import gc, sys from mlir.ir import * from mlir.passmanager import * from mlir.dialects.builtin import ModuleOp @@ -11,8 +10,6 @@ def run(f): print("\nTEST:", f.__name__) f() - gc.collect() - assert Context._get_live_count() == 0 # CHECK-LABEL: TEST: testRewritePattern From 3cc8cfed1f79c716871516951c1e1665ed667cf5 Mon Sep 17 00:00:00 2001 From: PragmaTwice Date: Sat, 11 Oct 2025 10:14:04 +0800 Subject: [PATCH 14/15] remove pattern parameter in callback --- mlir/lib/Bindings/Python/Rewrite.cpp | 4 ++-- mlir/test/python/rewrite.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp index 1d44da1cd94dd..d506b7fc9bc7b 100644 --- a/mlir/lib/Bindings/Python/Rewrite.cpp +++ b/mlir/lib/Bindings/Python/Rewrite.cpp @@ -193,11 +193,11 @@ class PyRewritePatternSet { callbacks.destruct = [](void *userData) { nb::handle(static_cast(userData)).dec_ref(); }; - callbacks.matchAndRewrite = [](MlirRewritePattern pattern, MlirOperation op, + callbacks.matchAndRewrite = [](MlirRewritePattern, MlirOperation op, MlirPatternRewriter rewriter, void *userData) -> MlirLogicalResult { nb::handle f(static_cast(userData)); - nb::object res = f(op, PyPatternRewriter(rewriter), pattern); + nb::object res = f(op, PyPatternRewriter(rewriter)); return logicalResultFromObject(res); }; MlirRewritePattern pattern = mlirOpRewritePattenCreate( diff --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py index c904c8c5c65bb..acf7db23db914 100644 --- a/mlir/test/python/rewrite.py +++ b/mlir/test/python/rewrite.py @@ -15,12 +15,12 @@ def run(f): # CHECK-LABEL: TEST: testRewritePattern @run def testRewritePattern(): - def to_muli(op, rewriter, pattern): + def to_muli(op, rewriter): with rewriter.ip: new_op = arith.muli(op.operands[0], op.operands[1], loc=op.location) rewriter.replace_op(op, new_op.owner) - def constant_1_to_2(op, rewriter, pattern): + def constant_1_to_2(op, rewriter): c = op.attributes["value"].value if c != 1: return True # failed to match From 65d914742acf2702aabe8188fad94ab61dcad1d6 Mon Sep 17 00:00:00 2001 From: PragmaTwice Date: Sat, 11 Oct 2025 10:53:00 +0800 Subject: [PATCH 15/15] use DEFINE_C_API_PTR_METHODS --- mlir/include/mlir/CAPI/Rewrite.h | 2 ++ mlir/lib/CAPI/Transforms/Rewrite.cpp | 18 ------------------ 2 files changed, 2 insertions(+), 18 deletions(-) diff --git a/mlir/include/mlir/CAPI/Rewrite.h b/mlir/include/mlir/CAPI/Rewrite.h index 1038c0a575cf2..8cd51edf0837b 100644 --- a/mlir/include/mlir/CAPI/Rewrite.h +++ b/mlir/include/mlir/CAPI/Rewrite.h @@ -20,5 +20,7 @@ #include "mlir/IR/PatternMatch.h" DEFINE_C_API_PTR_METHODS(MlirRewriterBase, mlir::RewriterBase) +DEFINE_C_API_PTR_METHODS(MlirRewritePattern, const mlir::RewritePattern) +DEFINE_C_API_PTR_METHODS(MlirRewritePatternSet, mlir::RewritePatternSet) #endif // MLIR_CAPIREWRITER_H diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp index 1947e7ee9d7fa..70dee598c9535 100644 --- a/mlir/lib/CAPI/Transforms/Rewrite.cpp +++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp @@ -270,15 +270,6 @@ void mlirIRRewriterDestroy(MlirRewriterBase rewriter) { /// RewritePatternSet and FrozenRewritePatternSet API //===----------------------------------------------------------------------===// -static inline mlir::RewritePatternSet *unwrap(MlirRewritePatternSet module) { - assert(module.ptr && "unexpected null module"); - return static_cast(module.ptr); -} - -static inline MlirRewritePatternSet wrap(mlir::RewritePatternSet *module) { - return {module}; -} - static inline mlir::FrozenRewritePatternSet * unwrap(MlirFrozenRewritePatternSet module) { assert(module.ptr && "unexpected null module"); @@ -337,15 +328,6 @@ MlirRewriterBase mlirPatternRewriterAsBase(MlirPatternRewriter rewriter) { /// RewritePattern API //===----------------------------------------------------------------------===// -inline const mlir::RewritePattern *unwrap(MlirRewritePattern pattern) { - assert(pattern.ptr && "expected non-null pattern"); - return static_cast(pattern.ptr); -} - -inline MlirRewritePattern wrap(const mlir::RewritePattern *pattern) { - return {pattern}; -} - namespace mlir { class ExternalRewritePattern : public mlir::RewritePattern {