Skip to content

Commit 28d65b8

Browse files
committed
[MLIR][Python] Expose the insertion point of pattern rewriter
1 parent 129394e commit 28d65b8

File tree

4 files changed

+116
-3
lines changed

4 files changed

+116
-3
lines changed

mlir/include/mlir-c/Rewrite.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,9 @@ mlirRewriterBaseGetInsertionBlock(MlirRewriterBase rewriter);
101101
MLIR_CAPI_EXPORTED MlirBlock
102102
mlirRewriterBaseGetBlock(MlirRewriterBase rewriter);
103103

104+
MLIR_CAPI_EXPORTED MlirOperation
105+
mlirRewriterBaseGetOperationAfterInsertion(MlirRewriterBase rewriter);
106+
104107
//===----------------------------------------------------------------------===//
105108
/// Block and operation creation/insertion/cloning
106109
//===----------------------------------------------------------------------===//
@@ -310,6 +313,14 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
310313
MlirModule op, MlirFrozenRewritePatternSet patterns,
311314
MlirGreedyRewriteDriverConfig);
312315

316+
//===----------------------------------------------------------------------===//
317+
/// PatternRewriter API
318+
//===----------------------------------------------------------------------===//
319+
320+
/// Cast the PatternRewriter to a RewriterBase
321+
MLIR_CAPI_EXPORTED MlirRewriterBase
322+
mlirPatternRewriterAsBase(MlirPatternRewriter rewriter);
323+
313324
//===----------------------------------------------------------------------===//
314325
/// PDLPatternModule API
315326
//===----------------------------------------------------------------------===//

mlir/lib/Bindings/Python/Rewrite.cpp

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,21 @@ class PyFrozenRewritePatternSet {
143143

144144
/// Create the `mlir.rewrite` here.
145145
void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
146-
nb::class_<MlirPatternRewriter>(m, "PatternRewriter");
146+
nb::class_<MlirPatternRewriter>(m, "PatternRewriter")
147+
.def("ip", [](MlirPatternRewriter rewriter) {
148+
MlirRewriterBase base = mlirPatternRewriterAsBase(rewriter);
149+
MlirBlock block = mlirRewriterBaseGetInsertionBlock(base);
150+
MlirOperation op = mlirRewriterBaseGetOperationAfterInsertion(base);
151+
MlirOperation owner = mlirBlockGetParentOperation(block);
152+
auto ctx = PyMlirContext::forContext(mlirRewriterBaseGetContext(base))
153+
->getRef();
154+
if (mlirOperationIsNull(op)) {
155+
auto parent = PyOperation::forOperation(ctx, owner);
156+
return PyInsertionPoint(PyBlock(parent, block));
157+
}
158+
159+
return PyInsertionPoint(*PyOperation::forOperation(ctx, op).get());
160+
});
147161
//----------------------------------------------------------------------------
148162
// Mapping of the PDLResultList and PDLModule
149163
//----------------------------------------------------------------------------

mlir/lib/CAPI/Transforms/Rewrite.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,18 @@ MlirBlock mlirRewriterBaseGetBlock(MlirRewriterBase rewriter) {
7070
return wrap(unwrap(rewriter)->getBlock());
7171
}
7272

73+
MlirOperation
74+
mlirRewriterBaseGetOperationAfterInsertion(MlirRewriterBase rewriter) {
75+
mlir::RewriterBase *base = unwrap(rewriter);
76+
mlir::Block *block = base->getInsertionBlock();
77+
auto it = base->getInsertionPoint();
78+
if (it == block->end()) {
79+
return {nullptr};
80+
}
81+
82+
return wrap(std::addressof(*it));
83+
}
84+
7385
//===----------------------------------------------------------------------===//
7486
/// Block and operation creation/insertion/cloning
7587
//===----------------------------------------------------------------------===//
@@ -316,6 +328,10 @@ inline MlirPatternRewriter wrap(mlir::PatternRewriter *rewriter) {
316328
return {rewriter};
317329
}
318330

331+
MlirRewriterBase mlirPatternRewriterAsBase(MlirPatternRewriter rewriter) {
332+
return wrap(static_cast<mlir::RewriterBase *>(unwrap(rewriter)));
333+
}
334+
319335
//===----------------------------------------------------------------------===//
320336
/// PDLPatternModule API
321337
//===----------------------------------------------------------------------===//

mlir/test/python/integration/dialects/pdl.py

Lines changed: 74 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,10 @@ def load_myint_dialect():
121121

122122

123123
# This PDL pattern is to fold constant additions,
124-
# i.e. add(constant0, constant1) -> constant2
125-
# where constant2 = constant0 + constant1.
124+
# including two patterns:
125+
# 1. add(constant0, constant1) -> constant2
126+
# where constant2 = constant0 + constant1;
127+
# 2. add(x, 0) or add(0, x) -> x.
126128
def get_pdl_pattern_fold():
127129
m = Module.create()
128130
i32 = IntegerType.get_signless(32)
@@ -237,3 +239,73 @@ def test_pdl_register_function_constraint(module_):
237239
apply_patterns_and_fold_greedily(module_, frozen)
238240

239241
return module_
242+
243+
244+
# This pattern is to expand constant to additions
245+
# unless the constant is no more than 1,
246+
# e.g. 3 -> 1 + 2 -> 1 + (1 + 1).
247+
def get_pdl_pattern_expand():
248+
m = Module.create()
249+
i32 = IntegerType.get_signless(32)
250+
with InsertionPoint(m.body):
251+
252+
@pdl.pattern(benefit=1, sym_name="myint_constant_expand")
253+
def pat():
254+
t = pdl.TypeOp(i32)
255+
cst = pdl.AttributeOp()
256+
pdl.apply_native_constraint([], "is_one", [cst])
257+
op0 = pdl.OperationOp(name="myint.constant", attributes={"value": cst}, types=[t])
258+
259+
@pdl.rewrite()
260+
def rew():
261+
expanded = pdl.apply_native_rewrite([pdl.OperationType.get()], "expand", [cst])
262+
pdl.ReplaceOp(op0, with_op=expanded)
263+
264+
def is_one(rewriter, results, values):
265+
cst = values[0].value
266+
return cst <= 1
267+
268+
def expand(rewriter, results, values):
269+
cst = values[0].value
270+
c1 = cst // 2
271+
c2 = cst - c1
272+
with rewriter.ip():
273+
op1 = Operation.create("myint.constant", results=[i32], attributes={"value": IntegerAttr.get(i32, c1)})
274+
op2 = Operation.create("myint.constant", results=[i32], attributes={"value": IntegerAttr.get(i32, c2)})
275+
res = Operation.create("myint.add", results=[i32], operands=[op1.result, op2.result])
276+
results.append(res)
277+
278+
pdl_module = PDLModule(m)
279+
pdl_module.register_constraint_function("is_one", is_one)
280+
pdl_module.register_rewrite_function("expand", expand)
281+
return pdl_module.freeze()
282+
283+
284+
# CHECK-LABEL: TEST: test_pdl_register_function_expand
285+
# CHECK: %0 = "myint.constant"() {value = 1 : i32} : () -> i32
286+
# CHECK: %1 = "myint.constant"() {value = 1 : i32} : () -> i32
287+
# CHECK: %2 = "myint.add"(%0, %1) : (i32, i32) -> i32
288+
# CHECK: %3 = "myint.constant"() {value = 1 : i32} : () -> i32
289+
# CHECK: %4 = "myint.constant"() {value = 1 : i32} : () -> i32
290+
# CHECK: %5 = "myint.constant"() {value = 1 : i32} : () -> i32
291+
# CHECK: %6 = "myint.add"(%4, %5) : (i32, i32) -> i32
292+
# CHECK: %7 = "myint.add"(%3, %6) : (i32, i32) -> i32
293+
# CHECK: %8 = "myint.add"(%2, %7) : (i32, i32) -> i32
294+
# CHECK: return %8 : i32
295+
@construct_and_print_in_module
296+
def test_pdl_register_function_expand(module_):
297+
load_myint_dialect()
298+
299+
module_ = Module.parse(
300+
"""
301+
func.func @f() -> i32 {
302+
%0 = "myint.constant"() { value = 5 }: () -> (i32)
303+
return %0 : i32
304+
}
305+
"""
306+
)
307+
308+
frozen = get_pdl_pattern_expand()
309+
apply_patterns_and_fold_greedily(module_, frozen)
310+
311+
return module_

0 commit comments

Comments
 (0)