Skip to content
This repository was archived by the owner on Oct 11, 2025. It is now read-only.

Commit d269268

Browse files
authored
[MLIR][Python] Add a python function to apply patterns with MlirOperation (#157487)
In llvm/llvm-project#94714, we add a python function `apply_patterns_and_fold_greedily` which accepts an `MlirModule` as the argument type. However, sometimes we want to apply patterns with an `MlirOperation` argument, and there is currently no python API to convert an `MlirOperation` to `MlirModule`. So here we overload this function `apply_patterns_and_fold_greedily` to do this (also a corresponding new C API `mlirApplyPatternsAndFoldGreedilyWithOp`)
1 parent 8411f12 commit d269268

File tree

3 files changed

+32
-10
lines changed

3 files changed

+32
-10
lines changed

mlir/include/mlir-c/Rewrite.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,10 @@ mlirFreezeRewritePattern(MlirRewritePatternSet op);
301301
MLIR_CAPI_EXPORTED void
302302
mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op);
303303

304+
MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedilyWithOp(
305+
MlirOperation op, MlirFrozenRewritePatternSet patterns,
306+
MlirGreedyRewriteDriverConfig);
307+
304308
MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
305309
MlirModule op, MlirFrozenRewritePatternSet patterns,
306310
MlirGreedyRewriteDriverConfig);

mlir/lib/Bindings/Python/Rewrite.cpp

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -99,14 +99,25 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
9999
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR,
100100
&PyFrozenRewritePatternSet::createFromCapsule);
101101
m.def(
102-
"apply_patterns_and_fold_greedily",
103-
[](MlirModule module, MlirFrozenRewritePatternSet set) {
104-
auto status = mlirApplyPatternsAndFoldGreedily(module, set, {});
105-
if (mlirLogicalResultIsFailure(status))
106-
// FIXME: Not sure this is the right error to throw here.
107-
throw nb::value_error("pattern application failed to converge");
108-
},
109-
"module"_a, "set"_a,
110-
"Applys the given patterns to the given module greedily while folding "
111-
"results.");
102+
"apply_patterns_and_fold_greedily",
103+
[](PyModule &module, MlirFrozenRewritePatternSet set) {
104+
auto status = mlirApplyPatternsAndFoldGreedily(module.get(), set, {});
105+
if (mlirLogicalResultIsFailure(status))
106+
throw std::runtime_error("pattern application failed to converge");
107+
},
108+
"module"_a, "set"_a,
109+
"Applys the given patterns to the given module greedily while folding "
110+
"results.")
111+
.def(
112+
"apply_patterns_and_fold_greedily",
113+
[](PyOperationBase &op, MlirFrozenRewritePatternSet set) {
114+
auto status = mlirApplyPatternsAndFoldGreedilyWithOp(
115+
op.getOperation(), set, {});
116+
if (mlirLogicalResultIsFailure(status))
117+
throw std::runtime_error(
118+
"pattern application failed to converge");
119+
},
120+
"op"_a, "set"_a,
121+
"Applys the given patterns to the given op greedily while folding "
122+
"results.");
112123
}

mlir/lib/CAPI/Transforms/Rewrite.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,13 @@ mlirApplyPatternsAndFoldGreedily(MlirModule op,
294294
return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
295295
}
296296

297+
MlirLogicalResult
298+
mlirApplyPatternsAndFoldGreedilyWithOp(MlirOperation op,
299+
MlirFrozenRewritePatternSet patterns,
300+
MlirGreedyRewriteDriverConfig) {
301+
return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
302+
}
303+
297304
//===----------------------------------------------------------------------===//
298305
/// PDLPatternModule API
299306
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)