Skip to content
14 changes: 14 additions & 0 deletions mlir/include/mlir-c/Rewrite.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,12 @@ mlirRewriterBaseGetInsertionBlock(MlirRewriterBase rewriter);
MLIR_CAPI_EXPORTED MlirBlock
mlirRewriterBaseGetBlock(MlirRewriterBase rewriter);

/// Returns the operation right after the current insertion point
/// of the rewriter. A null MlirOperation will be returned
// if the current insertion point is at the end of the block.
MLIR_CAPI_EXPORTED MlirOperation
mlirRewriterBaseGetOperationAfterInsertion(MlirRewriterBase rewriter);

//===----------------------------------------------------------------------===//
/// Block and operation creation/insertion/cloning
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -310,6 +316,14 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
MlirModule op, MlirFrozenRewritePatternSet patterns,
MlirGreedyRewriteDriverConfig);

//===----------------------------------------------------------------------===//
/// PatternRewriter API
//===----------------------------------------------------------------------===//

/// Cast the PatternRewriter to a RewriterBase
MLIR_CAPI_EXPORTED MlirRewriterBase
mlirPatternRewriterAsBase(MlirPatternRewriter rewriter);

//===----------------------------------------------------------------------===//
/// PDLPatternModule API
//===----------------------------------------------------------------------===//
Expand Down
16 changes: 15 additions & 1 deletion mlir/lib/Bindings/Python/Rewrite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,21 @@ class PyFrozenRewritePatternSet {

/// Create the `mlir.rewrite` here.
void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
nb::class_<MlirPatternRewriter>(m, "PatternRewriter");
nb::class_<MlirPatternRewriter>(m, "PatternRewriter")
.def("ip", [](MlirPatternRewriter rewriter) {
MlirRewriterBase base = mlirPatternRewriterAsBase(rewriter);
MlirBlock block = mlirRewriterBaseGetInsertionBlock(base);
MlirOperation op = mlirRewriterBaseGetOperationAfterInsertion(base);
MlirOperation owner = mlirBlockGetParentOperation(block);
auto ctx = PyMlirContext::forContext(mlirRewriterBaseGetContext(base))
->getRef();
if (mlirOperationIsNull(op)) {
auto parent = PyOperation::forOperation(ctx, owner);
return PyInsertionPoint(PyBlock(parent, block));
}

return PyInsertionPoint(*PyOperation::forOperation(ctx, op).get());
});
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if there is a good way to cast Mlir* CAPI types into Py* C++ classes. It seems that here we don't need to care too much about lifetime of blocks/operations (as long as the insertion point does not escape from the scope of the rewrite callback). 🤔

I'll try to define something like class PyPatternRewriter and see if that makes the code cleaner.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in 68264af.

//----------------------------------------------------------------------------
// Mapping of the PDLResultList and PDLModule
//----------------------------------------------------------------------------
Expand Down
16 changes: 16 additions & 0 deletions mlir/lib/CAPI/Transforms/Rewrite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,18 @@ MlirBlock mlirRewriterBaseGetBlock(MlirRewriterBase rewriter) {
return wrap(unwrap(rewriter)->getBlock());
}

MlirOperation
mlirRewriterBaseGetOperationAfterInsertion(MlirRewriterBase rewriter) {
mlir::RewriterBase *base = unwrap(rewriter);
mlir::Block *block = base->getInsertionBlock();
auto it = base->getInsertionPoint();
if (it == block->end()) {
return {nullptr};
}

return wrap(std::addressof(*it));
}

//===----------------------------------------------------------------------===//
/// Block and operation creation/insertion/cloning
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -316,6 +328,10 @@ inline MlirPatternRewriter wrap(mlir::PatternRewriter *rewriter) {
return {rewriter};
}

MlirRewriterBase mlirPatternRewriterAsBase(MlirPatternRewriter rewriter) {
return wrap(static_cast<mlir::RewriterBase *>(unwrap(rewriter)));
}

//===----------------------------------------------------------------------===//
/// PDLPatternModule API
//===----------------------------------------------------------------------===//
Expand Down
91 changes: 89 additions & 2 deletions mlir/test/python/integration/dialects/pdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def construct_and_print_in_module(f):
print(module)
return f


def get_pdl_patterns():
# Create a rewrite from add to mul. This will match
# - operation name is arith.addi
Expand Down Expand Up @@ -121,8 +122,10 @@ def load_myint_dialect():


# This PDL pattern is to fold constant additions,
# i.e. add(constant0, constant1) -> constant2
# where constant2 = constant0 + constant1.
# including two patterns:
# 1. add(constant0, constant1) -> constant2
# where constant2 = constant0 + constant1;
# 2. add(x, 0) or add(0, x) -> x.
def get_pdl_pattern_fold():
m = Module.create()
i32 = IntegerType.get_signless(32)
Expand Down Expand Up @@ -237,3 +240,87 @@ def test_pdl_register_function_constraint(module_):
apply_patterns_and_fold_greedily(module_, frozen)

return module_


# This pattern is to expand constant to additions
# unless the constant is no more than 1,
# e.g. 3 -> 1 + 2 -> 1 + (1 + 1).
def get_pdl_pattern_expand():
m = Module.create()
i32 = IntegerType.get_signless(32)
with InsertionPoint(m.body):

@pdl.pattern(benefit=1, sym_name="myint_constant_expand")
def pat():
t = pdl.TypeOp(i32)
cst = pdl.AttributeOp()
pdl.apply_native_constraint([], "is_one", [cst])
op0 = pdl.OperationOp(
name="myint.constant", attributes={"value": cst}, types=[t]
)

@pdl.rewrite()
def rew():
expanded = pdl.apply_native_rewrite(
[pdl.OperationType.get()], "expand", [cst]
)
pdl.ReplaceOp(op0, with_op=expanded)

def is_one(rewriter, results, values):
cst = values[0].value
return cst <= 1

def expand(rewriter, results, values):
cst = values[0].value
c1 = cst // 2
c2 = cst - c1
with rewriter.ip():
op1 = Operation.create(
"myint.constant",
results=[i32],
attributes={"value": IntegerAttr.get(i32, c1)},
)
op2 = Operation.create(
"myint.constant",
results=[i32],
attributes={"value": IntegerAttr.get(i32, c2)},
)
res = Operation.create(
"myint.add", results=[i32], operands=[op1.result, op2.result]
)
results.append(res)
Comment on lines 273 to 291
Copy link
Member Author

@PragmaTwice PragmaTwice Oct 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function as an example of retrieving and using the insertion point of the rewriter.


pdl_module = PDLModule(m)
pdl_module.register_constraint_function("is_one", is_one)
pdl_module.register_rewrite_function("expand", expand)
return pdl_module.freeze()


# CHECK-LABEL: TEST: test_pdl_register_function_expand
# CHECK: %0 = "myint.constant"() {value = 1 : i32} : () -> i32
# CHECK: %1 = "myint.constant"() {value = 1 : i32} : () -> i32
# CHECK: %2 = "myint.add"(%0, %1) : (i32, i32) -> i32
# CHECK: %3 = "myint.constant"() {value = 1 : i32} : () -> i32
# CHECK: %4 = "myint.constant"() {value = 1 : i32} : () -> i32
# CHECK: %5 = "myint.constant"() {value = 1 : i32} : () -> i32
# CHECK: %6 = "myint.add"(%4, %5) : (i32, i32) -> i32
# CHECK: %7 = "myint.add"(%3, %6) : (i32, i32) -> i32
# CHECK: %8 = "myint.add"(%2, %7) : (i32, i32) -> i32
# CHECK: return %8 : i32
@construct_and_print_in_module
def test_pdl_register_function_expand(module_):
load_myint_dialect()

module_ = Module.parse(
"""
func.func @f() -> i32 {
%0 = "myint.constant"() { value = 5 }: () -> (i32)
return %0 : i32
}
"""
)

frozen = get_pdl_pattern_expand()
apply_patterns_and_fold_greedily(module_, frozen)

return module_