Skip to content
14 changes: 14 additions & 0 deletions mlir/include/mlir-c/Rewrite.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,12 @@ mlirRewriterBaseGetInsertionBlock(MlirRewriterBase rewriter);
MLIR_CAPI_EXPORTED MlirBlock
mlirRewriterBaseGetBlock(MlirRewriterBase rewriter);

/// Returns the operation right after the current insertion point
/// of the rewriter. A null MlirOperation will be returned
// if the current insertion point is at the end of the block.
MLIR_CAPI_EXPORTED MlirOperation
mlirRewriterBaseGetOperationAfterInsertion(MlirRewriterBase rewriter);

//===----------------------------------------------------------------------===//
/// Block and operation creation/insertion/cloning
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -310,6 +316,14 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
MlirModule op, MlirFrozenRewritePatternSet patterns,
MlirGreedyRewriteDriverConfig);

//===----------------------------------------------------------------------===//
/// PatternRewriter API
//===----------------------------------------------------------------------===//

/// Cast the PatternRewriter to a RewriterBase
MLIR_CAPI_EXPORTED MlirRewriterBase
mlirPatternRewriterAsBase(MlirPatternRewriter rewriter);

//===----------------------------------------------------------------------===//
/// PDLPatternModule API
//===----------------------------------------------------------------------===//
Expand Down
3 changes: 3 additions & 0 deletions mlir/lib/Bindings/Python/IRCore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2046,6 +2046,9 @@ PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
: refOperation(beforeOperationBase.getOperation().getRef()),
block((*refOperation)->getBlock()) {}

PyInsertionPoint::PyInsertionPoint(PyOperationRef beforeOperationRef)
: refOperation(beforeOperationRef), block((*refOperation)->getBlock()) {}

void PyInsertionPoint::insert(PyOperationBase &operationBase) {
PyOperation &operation = operationBase.getOperation();
if (operation.isAttached())
Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Bindings/Python/IRModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -841,6 +841,8 @@ class PyInsertionPoint {
PyInsertionPoint(const PyBlock &block);
/// Creates an insertion point positioned before a reference operation.
PyInsertionPoint(PyOperationBase &beforeOperationBase);
/// Creates an insertion point positioned before a reference operation.
PyInsertionPoint(PyOperationRef beforeOperationRef);

/// Shortcut to create an insertion point at the beginning of the block.
static PyInsertionPoint atBlockBegin(PyBlock &block);
Expand Down
34 changes: 31 additions & 3 deletions mlir/lib/Bindings/Python/Rewrite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,31 @@ using namespace mlir::python;

namespace {

class PyPatternRewriter {
public:
PyPatternRewriter(MlirPatternRewriter rewriter)
: rewriter(rewriter), base(mlirPatternRewriterAsBase(rewriter)),
ctx(PyMlirContext::forContext(mlirRewriterBaseGetContext(base))) {}

PyInsertionPoint getInsertionPoint() const {
MlirBlock block = mlirRewriterBaseGetInsertionBlock(base);
MlirOperation op = mlirRewriterBaseGetOperationAfterInsertion(base);

if (mlirOperationIsNull(op)) {
MlirOperation owner = mlirBlockGetParentOperation(block);
auto parent = PyOperation::forOperation(ctx, owner);
return PyInsertionPoint(PyBlock(parent, block));
}

return PyInsertionPoint(PyOperation::forOperation(ctx, op));
}

private:
MlirPatternRewriter rewriter [[maybe_unused]];
MlirRewriterBase base;
PyMlirContextRef ctx;
};

#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
static nb::object objectFromPDLValue(MlirPDLValue value) {
if (MlirValue v = mlirPDLValueAsValue(value); !mlirValueIsNull(v))
Expand Down Expand Up @@ -84,7 +109,8 @@ class PyPDLPatternModule {
void *userData) -> MlirLogicalResult {
nb::handle f = nb::handle(static_cast<PyObject *>(userData));
return logicalResultFromObject(
f(rewriter, results, objectsFromPDLValues(nValues, values)));
f(PyPatternRewriter(rewriter), results,
objectsFromPDLValues(nValues, values)));
},
fn.ptr());
}
Expand All @@ -98,7 +124,8 @@ class PyPDLPatternModule {
void *userData) -> MlirLogicalResult {
nb::handle f = nb::handle(static_cast<PyObject *>(userData));
return logicalResultFromObject(
f(rewriter, results, objectsFromPDLValues(nValues, values)));
f(PyPatternRewriter(rewriter), results,
objectsFromPDLValues(nValues, values)));
},
fn.ptr());
}
Expand Down Expand Up @@ -143,7 +170,8 @@ class PyFrozenRewritePatternSet {

/// Create the `mlir.rewrite` here.
void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
nb::class_<MlirPatternRewriter>(m, "PatternRewriter");
nb::class_<PyPatternRewriter>(m, "PyPatternRewriter")
.def("ip", &PyPatternRewriter::getInsertionPoint);
//----------------------------------------------------------------------------
// Mapping of the PDLResultList and PDLModule
//----------------------------------------------------------------------------
Expand Down
16 changes: 16 additions & 0 deletions mlir/lib/CAPI/Transforms/Rewrite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,18 @@ MlirBlock mlirRewriterBaseGetBlock(MlirRewriterBase rewriter) {
return wrap(unwrap(rewriter)->getBlock());
}

MlirOperation
mlirRewriterBaseGetOperationAfterInsertion(MlirRewriterBase rewriter) {
mlir::RewriterBase *base = unwrap(rewriter);
mlir::Block *block = base->getInsertionBlock();
auto it = base->getInsertionPoint();
if (it == block->end()) {
return {nullptr};
}

return wrap(std::addressof(*it));
}

//===----------------------------------------------------------------------===//
/// Block and operation creation/insertion/cloning
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -316,6 +328,10 @@ inline MlirPatternRewriter wrap(mlir::PatternRewriter *rewriter) {
return {rewriter};
}

MlirRewriterBase mlirPatternRewriterAsBase(MlirPatternRewriter rewriter) {
return wrap(static_cast<mlir::RewriterBase *>(unwrap(rewriter)));
}

//===----------------------------------------------------------------------===//
/// PDLPatternModule API
//===----------------------------------------------------------------------===//
Expand Down
91 changes: 89 additions & 2 deletions mlir/test/python/integration/dialects/pdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def construct_and_print_in_module(f):
print(module)
return f


def get_pdl_patterns():
# Create a rewrite from add to mul. This will match
# - operation name is arith.addi
Expand Down Expand Up @@ -121,8 +122,10 @@ def load_myint_dialect():


# This PDL pattern is to fold constant additions,
# i.e. add(constant0, constant1) -> constant2
# where constant2 = constant0 + constant1.
# including two patterns:
# 1. add(constant0, constant1) -> constant2
# where constant2 = constant0 + constant1;
# 2. add(x, 0) or add(0, x) -> x.
def get_pdl_pattern_fold():
m = Module.create()
i32 = IntegerType.get_signless(32)
Expand Down Expand Up @@ -237,3 +240,87 @@ def test_pdl_register_function_constraint(module_):
apply_patterns_and_fold_greedily(module_, frozen)

return module_


# This pattern is to expand constant to additions
# unless the constant is no more than 1,
# e.g. 3 -> 1 + 2 -> 1 + (1 + 1).
def get_pdl_pattern_expand():
m = Module.create()
i32 = IntegerType.get_signless(32)
with InsertionPoint(m.body):

@pdl.pattern(benefit=1, sym_name="myint_constant_expand")
def pat():
t = pdl.TypeOp(i32)
cst = pdl.AttributeOp()
pdl.apply_native_constraint([], "is_one", [cst])
op0 = pdl.OperationOp(
name="myint.constant", attributes={"value": cst}, types=[t]
)

@pdl.rewrite()
def rew():
expanded = pdl.apply_native_rewrite(
[pdl.OperationType.get()], "expand", [cst]
)
pdl.ReplaceOp(op0, with_op=expanded)

def is_one(rewriter, results, values):
cst = values[0].value
return cst <= 1

def expand(rewriter, results, values):
cst = values[0].value
c1 = cst // 2
c2 = cst - c1
with rewriter.ip():
op1 = Operation.create(
"myint.constant",
results=[i32],
attributes={"value": IntegerAttr.get(i32, c1)},
)
op2 = Operation.create(
"myint.constant",
results=[i32],
attributes={"value": IntegerAttr.get(i32, c2)},
)
res = Operation.create(
"myint.add", results=[i32], operands=[op1.result, op2.result]
)
results.append(res)
Comment on lines 273 to 291
Copy link
Member Author

@PragmaTwice PragmaTwice Oct 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function as an example of retrieving and using the insertion point of the rewriter.


pdl_module = PDLModule(m)
pdl_module.register_constraint_function("is_one", is_one)
pdl_module.register_rewrite_function("expand", expand)
return pdl_module.freeze()


# CHECK-LABEL: TEST: test_pdl_register_function_expand
# CHECK: %0 = "myint.constant"() {value = 1 : i32} : () -> i32
# CHECK: %1 = "myint.constant"() {value = 1 : i32} : () -> i32
# CHECK: %2 = "myint.add"(%0, %1) : (i32, i32) -> i32
# CHECK: %3 = "myint.constant"() {value = 1 : i32} : () -> i32
# CHECK: %4 = "myint.constant"() {value = 1 : i32} : () -> i32
# CHECK: %5 = "myint.constant"() {value = 1 : i32} : () -> i32
# CHECK: %6 = "myint.add"(%4, %5) : (i32, i32) -> i32
# CHECK: %7 = "myint.add"(%3, %6) : (i32, i32) -> i32
# CHECK: %8 = "myint.add"(%2, %7) : (i32, i32) -> i32
# CHECK: return %8 : i32
@construct_and_print_in_module
def test_pdl_register_function_expand(module_):
load_myint_dialect()

module_ = Module.parse(
"""
func.func @f() -> i32 {
%0 = "myint.constant"() { value = 5 }: () -> (i32)
return %0 : i32
}
"""
)

frozen = get_pdl_pattern_expand()
apply_patterns_and_fold_greedily(module_, frozen)

return module_