Skip to content

Commit d5055a7

Browse files
committed
add test case
1 parent 7556ca2 commit d5055a7

File tree

5 files changed

+106
-13
lines changed

5 files changed

+106
-13
lines changed

mlir/include/mlir-c/Rewrite.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,10 @@ mlirFreezeRewritePattern(MlirRewritePatternSet op);
301301
MLIR_CAPI_EXPORTED void
302302
mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op);
303303

304+
MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedilyForOp(
305+
MlirOperation op, MlirFrozenRewritePatternSet patterns,
306+
MlirGreedyRewriteDriverConfig);
307+
304308
MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
305309
MlirModule op, MlirFrozenRewritePatternSet patterns,
306310
MlirGreedyRewriteDriverConfig);

mlir/lib/Bindings/Python/Pass.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,10 @@
99
#include "Pass.h"
1010

1111
#include "IRModule.h"
12-
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
1312
#include "mlir-c/Pass.h"
14-
#include "mlir-c/Support.h"
1513
#include "mlir/Bindings/Python/Nanobind.h"
1614
#include "nanobind/trampoline.h"
17-
#include "llvm/Support/ErrorHandling.h"
15+
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
1816

1917
namespace nb = nanobind;
2018
using namespace nb::literals;

mlir/lib/Bindings/Python/Rewrite.cpp

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -99,14 +99,25 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
9999
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR,
100100
&PyFrozenRewritePatternSet::createFromCapsule);
101101
m.def(
102-
"apply_patterns_and_fold_greedily",
103-
[](MlirModule module, MlirFrozenRewritePatternSet set) {
104-
auto status = mlirApplyPatternsAndFoldGreedily(module, set, {});
105-
if (mlirLogicalResultIsFailure(status))
106-
// FIXME: Not sure this is the right error to throw here.
107-
throw nb::value_error("pattern application failed to converge");
108-
},
109-
"module"_a, "set"_a,
110-
"Applys the given patterns to the given module greedily while folding "
111-
"results.");
102+
"apply_patterns_and_fold_greedily",
103+
[](MlirModule module, MlirFrozenRewritePatternSet set) {
104+
auto status = mlirApplyPatternsAndFoldGreedily(module, set, {});
105+
if (mlirLogicalResultIsFailure(status))
106+
// FIXME: Not sure this is the right error to throw here.
107+
throw nb::value_error("pattern application failed to converge");
108+
},
109+
"module"_a, "set"_a,
110+
"Applys the given patterns to the given module greedily while folding "
111+
"results.")
112+
.def(
113+
"apply_patterns_and_fold_greedily_for_op",
114+
[](MlirOperation op, MlirFrozenRewritePatternSet set) {
115+
auto status = mlirApplyPatternsAndFoldGreedilyForOp(op, set, {});
116+
if (mlirLogicalResultIsFailure(status))
117+
// FIXME: Not sure this is the right error to throw here.
118+
throw nb::value_error("pattern application failed to converge");
119+
},
120+
"op"_a, "set"_a,
121+
"Applys the given patterns to the given op greedily while folding "
122+
"results.");
112123
}

mlir/lib/CAPI/Transforms/Rewrite.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,13 @@ mlirApplyPatternsAndFoldGreedily(MlirModule op,
294294
return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
295295
}
296296

297+
MlirLogicalResult
298+
mlirApplyPatternsAndFoldGreedilyForOp(MlirOperation op,
299+
MlirFrozenRewritePatternSet patterns,
300+
MlirGreedyRewriteDriverConfig) {
301+
return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
302+
}
303+
297304
//===----------------------------------------------------------------------===//
298305
/// PDLPatternModule API
299306
//===----------------------------------------------------------------------===//

mlir/test/python/pass.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# RUN: %PYTHON %s 2>&1 | FileCheck %s
2+
3+
import gc, sys
4+
from mlir.ir import *
5+
from mlir.passmanager import *
6+
from mlir.dialects.builtin import ModuleOp
7+
from mlir.dialects import pdl
8+
from mlir.rewrite import *
9+
10+
def log(*args):
11+
print(*args, file=sys.stderr)
12+
sys.stderr.flush()
13+
14+
15+
def run(f):
16+
log("\nTEST:", f.__name__)
17+
f()
18+
gc.collect()
19+
assert Context._get_live_count() == 0
20+
21+
def make_pdl_module():
22+
with Location.unknown():
23+
pdl_module = Module.create()
24+
with InsertionPoint(pdl_module.body):
25+
# Change all arith.addi with index types to arith.muli.
26+
@pdl.pattern(benefit=1, sym_name="addi_to_mul")
27+
def pat():
28+
# Match arith.addi with index types.
29+
index_type = pdl.TypeOp(IndexType.get())
30+
operand0 = pdl.OperandOp(index_type)
31+
operand1 = pdl.OperandOp(index_type)
32+
op0 = pdl.OperationOp(
33+
name="arith.addi", args=[operand0, operand1], types=[index_type]
34+
)
35+
36+
# Replace the matched op with arith.muli.
37+
@pdl.rewrite()
38+
def rew():
39+
newOp = pdl.OperationOp(
40+
name="arith.muli", args=[operand0, operand1], types=[index_type]
41+
)
42+
pdl.ReplaceOp(op0, with_op=newOp)
43+
44+
return pdl_module
45+
46+
# CHECK-LABEL: TEST: testCustomPass
47+
@run
48+
def testCustomPass():
49+
with Context():
50+
pdl_module = make_pdl_module()
51+
52+
class CustomPass(Pass):
53+
def __init__(self):
54+
super().__init__("CustomPass", op_name="builtin.module")
55+
def run(self, m):
56+
frozen = PDLModule(pdl_module).freeze()
57+
apply_patterns_and_fold_greedily_for_op(m, frozen)
58+
59+
module = ModuleOp.parse(r"""
60+
module {
61+
func.func @add(%a: index, %b: index) -> index {
62+
%sum = arith.addi %a, %b : index
63+
return %sum : index
64+
}
65+
}
66+
""")
67+
68+
# CHECK-LABEL: Dump After CustomPass
69+
# CHECK: arith.muli
70+
pm = PassManager('any')
71+
pm.enable_ir_printing()
72+
pm.add(CustomPass())
73+
pm.run(module)

0 commit comments

Comments
 (0)