Skip to content

Commit aac4eb5

Browse files
authored
[MLIR][Python] Add a python function to apply patterns with MlirOperation (#157487)
In #94714, we add a python function `apply_patterns_and_fold_greedily` which accepts an `MlirModule` as the argument type. However, sometimes we want to apply patterns with an `MlirOperation` argument, and there is currently no python API to convert an `MlirOperation` to `MlirModule`. So here we overload this function `apply_patterns_and_fold_greedily` to do this (also a corresponding new C API `mlirApplyPatternsAndFoldGreedilyWithOp`)
1 parent 9f7877f commit aac4eb5

File tree

4 files changed

+68
-25
lines changed

4 files changed

+68
-25
lines changed

mlir/include/mlir-c/Rewrite.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,10 @@ mlirFreezeRewritePattern(MlirRewritePatternSet op);
301301
MLIR_CAPI_EXPORTED void
302302
mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op);
303303

304+
MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedilyWithOp(
305+
MlirOperation op, MlirFrozenRewritePatternSet patterns,
306+
MlirGreedyRewriteDriverConfig);
307+
304308
MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
305309
MlirModule op, MlirFrozenRewritePatternSet patterns,
306310
MlirGreedyRewriteDriverConfig);

mlir/lib/Bindings/Python/Rewrite.cpp

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -99,14 +99,25 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
9999
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR,
100100
&PyFrozenRewritePatternSet::createFromCapsule);
101101
m.def(
102-
"apply_patterns_and_fold_greedily",
103-
[](MlirModule module, MlirFrozenRewritePatternSet set) {
104-
auto status = mlirApplyPatternsAndFoldGreedily(module, set, {});
105-
if (mlirLogicalResultIsFailure(status))
106-
// FIXME: Not sure this is the right error to throw here.
107-
throw nb::value_error("pattern application failed to converge");
108-
},
109-
"module"_a, "set"_a,
110-
"Applys the given patterns to the given module greedily while folding "
111-
"results.");
102+
"apply_patterns_and_fold_greedily",
103+
[](PyModule &module, MlirFrozenRewritePatternSet set) {
104+
auto status = mlirApplyPatternsAndFoldGreedily(module.get(), set, {});
105+
if (mlirLogicalResultIsFailure(status))
106+
throw std::runtime_error("pattern application failed to converge");
107+
},
108+
"module"_a, "set"_a,
109+
"Applys the given patterns to the given module greedily while folding "
110+
"results.")
111+
.def(
112+
"apply_patterns_and_fold_greedily",
113+
[](PyOperationBase &op, MlirFrozenRewritePatternSet set) {
114+
auto status = mlirApplyPatternsAndFoldGreedilyWithOp(
115+
op.getOperation(), set, {});
116+
if (mlirLogicalResultIsFailure(status))
117+
throw std::runtime_error(
118+
"pattern application failed to converge");
119+
},
120+
"op"_a, "set"_a,
121+
"Applys the given patterns to the given op greedily while folding "
122+
"results.");
112123
}

mlir/lib/CAPI/Transforms/Rewrite.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,13 @@ mlirApplyPatternsAndFoldGreedily(MlirModule op,
294294
return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
295295
}
296296

297+
MlirLogicalResult
298+
mlirApplyPatternsAndFoldGreedilyWithOp(MlirOperation op,
299+
MlirFrozenRewritePatternSet patterns,
300+
MlirGreedyRewriteDriverConfig) {
301+
return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
302+
}
303+
297304
//===----------------------------------------------------------------------===//
298305
/// PDLPatternModule API
299306
//===----------------------------------------------------------------------===//

mlir/test/python/integration/dialects/pdl.py

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,7 @@ def construct_and_print_in_module(f):
1616
print(module)
1717
return f
1818

19-
20-
# CHECK-LABEL: TEST: test_add_to_mul
21-
# CHECK: arith.muli
22-
@construct_and_print_in_module
23-
def test_add_to_mul(module_):
24-
index_type = IndexType.get()
25-
26-
# Create a test case.
27-
@module(sym_name="ir")
28-
def ir():
29-
@func.func(index_type, index_type)
30-
def add_func(a, b):
31-
return arith.addi(a, b)
32-
19+
def get_pdl_patterns():
3320
# Create a rewrite from add to mul. This will match
3421
# - operation name is arith.addi
3522
# - operands are index types.
@@ -61,7 +48,41 @@ def rew():
6148
# not yet captured Python side/has sharp edges. So best to construct the
6249
# module and PDL module in same scope.
6350
# FIXME: This should be made more robust.
64-
frozen = PDLModule(m).freeze()
51+
return PDLModule(m).freeze()
52+
53+
54+
# CHECK-LABEL: TEST: test_add_to_mul
55+
# CHECK: arith.muli
56+
@construct_and_print_in_module
57+
def test_add_to_mul(module_):
58+
index_type = IndexType.get()
59+
60+
# Create a test case.
61+
@module(sym_name="ir")
62+
def ir():
63+
@func.func(index_type, index_type)
64+
def add_func(a, b):
65+
return arith.addi(a, b)
66+
67+
frozen = get_pdl_patterns()
6568
# Could apply frozen pattern set multiple times.
6669
apply_patterns_and_fold_greedily(module_, frozen)
6770
return module_
71+
72+
73+
# CHECK-LABEL: TEST: test_add_to_mul_with_op
74+
# CHECK: arith.muli
75+
@construct_and_print_in_module
76+
def test_add_to_mul_with_op(module_):
77+
index_type = IndexType.get()
78+
79+
# Create a test case.
80+
@module(sym_name="ir")
81+
def ir():
82+
@func.func(index_type, index_type)
83+
def add_func(a, b):
84+
return arith.addi(a, b)
85+
86+
frozen = get_pdl_patterns()
87+
apply_patterns_and_fold_greedily(module_.operation, frozen)
88+
return module_

0 commit comments

Comments
 (0)