Skip to content

Commit b5daf76

Browse files
[MLIR][Python] Add bindings for PDL native rewrite function registering (llvm#159926)
In the MLIR Python bindings, we can currently use PDL to define simple patterns and then execute them with the greedy rewrite driver. However, when dealing with more complex patterns—such as constant folding for integer addition—we find that we need `apply_native_rewrite` to actually perform arithmetic (i.e., compute the sum of two constants). For example, consider the following PDL pseudocode: ```mlir pdl.pattern : benefit(1) { %a0 = pdl.attribute %a1 = pdl.attribute %c0 = pdl.operation "arith.constant" {value = %a0} %c1 = pdl.operation "arith.constant" {value = %a1} %op = pdl.operation "arith.addi"(%c0, %c1) %sum = pdl.apply_native_rewrite "addIntegers"(%a0, %a1) %new_cst = pdl.operation "arith.constant" {value = %sum} pdl.replace %op with %new_cst } ``` Here, `addIntegers` cannot be expressed in PDL alone—it requires a *native rewrite function*. This PR introduces a mechanism to support exactly that, allowing complex rewrite patterns to be expressed in Python and enabling many passes to be implemented directly in Python as well. As a test case, we defined two new operations (`myint.constant` and `myint.add`) in Python and implemented a constant-folding rewrite pattern for them. The core code looks like this: ```python m = Module.create() with InsertionPoint(m.body): @pdl.pattern(benefit=1, sym_name="myint_add_fold") def pat(): ... op0 = pdl.OperationOp(name="myint.add", args=[v0, v1], types=[t]) @pdl.rewrite() def rew(): sum = pdl.apply_native_rewrite( [pdl.AttributeType.get()], "add_fold", [a0, a1] ) newOp = pdl.OperationOp( name="myint.constant", attributes={"value": sum}, types=[t] ) pdl.ReplaceOp(op0, with_op=newOp) def add_fold(rewriter, results, values): a0, a1 = values results.push_back(IntegerAttr.get(i32, a0.value + a1.value)) pdl_module = PDLModule(m) pdl_module.register_rewrite_function("add_fold", add_fold) ``` The idea is previously discussed in Discord #mlir-python channel with @makslevental. --------- Co-authored-by: Maksim Levental <[email protected]>
1 parent 79ad1bf commit b5daf76

File tree

5 files changed

+330
-6
lines changed

5 files changed

+330
-6
lines changed

mlir/include/mlir-c/Rewrite.h

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ DEFINE_C_API_STRUCT(MlirRewriterBase, void);
3737
DEFINE_C_API_STRUCT(MlirFrozenRewritePatternSet, void);
3838
DEFINE_C_API_STRUCT(MlirGreedyRewriteDriverConfig, void);
3939
DEFINE_C_API_STRUCT(MlirRewritePatternSet, void);
40+
DEFINE_C_API_STRUCT(MlirPatternRewriter, void);
4041

4142
//===----------------------------------------------------------------------===//
4243
/// RewriterBase API inherited from OpBuilder
@@ -315,6 +316,8 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
315316

316317
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
317318
DEFINE_C_API_STRUCT(MlirPDLPatternModule, void);
319+
DEFINE_C_API_STRUCT(MlirPDLValue, const void);
320+
DEFINE_C_API_STRUCT(MlirPDLResultList, void);
318321

319322
MLIR_CAPI_EXPORTED MlirPDLPatternModule
320323
mlirPDLPatternModuleFromModule(MlirModule op);
@@ -323,6 +326,55 @@ MLIR_CAPI_EXPORTED void mlirPDLPatternModuleDestroy(MlirPDLPatternModule op);
323326

324327
MLIR_CAPI_EXPORTED MlirRewritePatternSet
325328
mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op);
329+
330+
/// Cast the MlirPDLValue to an MlirValue.
331+
/// Return a null value if the cast fails, just like llvm::dyn_cast.
332+
MLIR_CAPI_EXPORTED MlirValue mlirPDLValueAsValue(MlirPDLValue value);
333+
334+
/// Cast the MlirPDLValue to an MlirType.
335+
/// Return a null value if the cast fails, just like llvm::dyn_cast.
336+
MLIR_CAPI_EXPORTED MlirType mlirPDLValueAsType(MlirPDLValue value);
337+
338+
/// Cast the MlirPDLValue to an MlirOperation.
339+
/// Return a null value if the cast fails, just like llvm::dyn_cast.
340+
MLIR_CAPI_EXPORTED MlirOperation mlirPDLValueAsOperation(MlirPDLValue value);
341+
342+
/// Cast the MlirPDLValue to an MlirAttribute.
343+
/// Return a null value if the cast fails, just like llvm::dyn_cast.
344+
MLIR_CAPI_EXPORTED MlirAttribute mlirPDLValueAsAttribute(MlirPDLValue value);
345+
346+
/// Push the MlirValue into the given MlirPDLResultList.
347+
MLIR_CAPI_EXPORTED void
348+
mlirPDLResultListPushBackValue(MlirPDLResultList results, MlirValue value);
349+
350+
/// Push the MlirType into the given MlirPDLResultList.
351+
MLIR_CAPI_EXPORTED void mlirPDLResultListPushBackType(MlirPDLResultList results,
352+
MlirType value);
353+
354+
/// Push the MlirOperation into the given MlirPDLResultList.
355+
MLIR_CAPI_EXPORTED void
356+
mlirPDLResultListPushBackOperation(MlirPDLResultList results,
357+
MlirOperation value);
358+
359+
/// Push the MlirAttribute into the given MlirPDLResultList.
360+
MLIR_CAPI_EXPORTED void
361+
mlirPDLResultListPushBackAttribute(MlirPDLResultList results,
362+
MlirAttribute value);
363+
364+
/// This function type is used as callbacks for PDL native rewrite functions.
365+
/// Input values can be accessed by `values` with its size `nValues`;
366+
/// output values can be added into `results` by `mlirPDLResultListPushBack*`
367+
/// APIs. And the return value indicates whether the rewrite succeeds.
368+
typedef MlirLogicalResult (*MlirPDLRewriteFunction)(
369+
MlirPatternRewriter rewriter, MlirPDLResultList results, size_t nValues,
370+
MlirPDLValue *values, void *userData);
371+
372+
/// Register a rewrite function into the given PDL pattern module.
373+
/// `userData` will be provided as an argument to the rewrite function.
374+
MLIR_CAPI_EXPORTED void mlirPDLPatternModuleRegisterRewriteFunction(
375+
MlirPDLPatternModule pdlModule, MlirStringRef name,
376+
MlirPDLRewriteFunction rewriteFn, void *userData);
377+
326378
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
327379

328380
#undef DEFINE_C_API_STRUCT

mlir/include/mlir/IR/PDLPatternMatch.h.inc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ public:
5353
/// value is not an instance of `T`.
5454
template <typename T,
5555
typename ResultT = std::conditional_t<
56-
std::is_convertible<T, bool>::value, T, std::optional<T>>>
56+
std::is_constructible_v<bool, T>, T, std::optional<T>>>
5757
ResultT dyn_cast() const {
5858
return isa<T>() ? castImpl<T>() : ResultT();
5959
}

mlir/lib/Bindings/Python/Rewrite.cpp

Lines changed: 98 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,15 @@
99
#include "Rewrite.h"
1010

1111
#include "IRModule.h"
12+
#include "mlir-c/IR.h"
1213
#include "mlir-c/Rewrite.h"
14+
#include "mlir-c/Support.h"
1315
// clang-format off
1416
#include "mlir/Bindings/Python/Nanobind.h"
1517
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
1618
// clang-format on
1719
#include "mlir/Config/mlir-config.h"
20+
#include "nanobind/nanobind.h"
1821

1922
namespace nb = nanobind;
2023
using namespace mlir;
@@ -24,6 +27,31 @@ using namespace mlir::python;
2427
namespace {
2528

2629
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
30+
static nb::object objectFromPDLValue(MlirPDLValue value) {
31+
if (MlirValue v = mlirPDLValueAsValue(value); !mlirValueIsNull(v))
32+
return nb::cast(v);
33+
if (MlirOperation v = mlirPDLValueAsOperation(value); !mlirOperationIsNull(v))
34+
return nb::cast(v);
35+
if (MlirAttribute v = mlirPDLValueAsAttribute(value); !mlirAttributeIsNull(v))
36+
return nb::cast(v);
37+
if (MlirType v = mlirPDLValueAsType(value); !mlirTypeIsNull(v))
38+
return nb::cast(v);
39+
40+
throw std::runtime_error("unsupported PDL value type");
41+
}
42+
43+
// Convert the Python object to a boolean.
44+
// If it evaluates to False, treat it as success;
45+
// otherwise, treat it as failure.
46+
// Note that None is considered success.
47+
static MlirLogicalResult logicalResultFromObject(const nb::object &obj) {
48+
if (obj.is_none())
49+
return mlirLogicalResultSuccess();
50+
51+
return nb::cast<bool>(obj) ? mlirLogicalResultFailure()
52+
: mlirLogicalResultSuccess();
53+
}
54+
2755
/// Owning Wrapper around a PDLPatternModule.
2856
class PyPDLPatternModule {
2957
public:
@@ -38,6 +66,23 @@ class PyPDLPatternModule {
3866
}
3967
MlirPDLPatternModule get() { return module; }
4068

69+
void registerRewriteFunction(const std::string &name,
70+
const nb::callable &fn) {
71+
mlirPDLPatternModuleRegisterRewriteFunction(
72+
get(), mlirStringRefCreate(name.data(), name.size()),
73+
[](MlirPatternRewriter rewriter, MlirPDLResultList results,
74+
size_t nValues, MlirPDLValue *values,
75+
void *userData) -> MlirLogicalResult {
76+
nb::handle f = nb::handle(static_cast<PyObject *>(userData));
77+
std::vector<nb::object> args;
78+
args.reserve(nValues);
79+
for (size_t i = 0; i < nValues; ++i)
80+
args.push_back(objectFromPDLValue(values[i]));
81+
return logicalResultFromObject(f(rewriter, results, args));
82+
},
83+
fn.ptr());
84+
}
85+
4186
private:
4287
MlirPDLPatternModule module;
4388
};
@@ -78,10 +123,48 @@ class PyFrozenRewritePatternSet {
78123

79124
/// Create the `mlir.rewrite` here.
80125
void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
126+
nb::class_<MlirPatternRewriter>(m, "PatternRewriter");
81127
//----------------------------------------------------------------------------
82-
// Mapping of the top-level PassManager
128+
// Mapping of the PDLResultList and PDLModule
83129
//----------------------------------------------------------------------------
84130
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
131+
nb::class_<MlirPDLResultList>(m, "PDLResultList")
132+
.def(
133+
"append",
134+
[](MlirPDLResultList results, const PyValue &value) {
135+
mlirPDLResultListPushBackValue(results, value);
136+
},
137+
// clang-format off
138+
nb::sig("def append(self, " MAKE_MLIR_PYTHON_QUALNAME("ir.Value") ")")
139+
// clang-format on
140+
)
141+
.def(
142+
"append",
143+
[](MlirPDLResultList results, const PyOperation &op) {
144+
mlirPDLResultListPushBackOperation(results, op);
145+
},
146+
// clang-format off
147+
nb::sig("def append(self, " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ")")
148+
// clang-format on
149+
)
150+
.def(
151+
"append",
152+
[](MlirPDLResultList results, const PyType &type) {
153+
mlirPDLResultListPushBackType(results, type);
154+
},
155+
// clang-format off
156+
nb::sig("def append(self, " MAKE_MLIR_PYTHON_QUALNAME("ir.Type") ")")
157+
// clang-format on
158+
)
159+
.def(
160+
"append",
161+
[](MlirPDLResultList results, const PyAttribute &attr) {
162+
mlirPDLResultListPushBackAttribute(results, attr);
163+
},
164+
// clang-format off
165+
nb::sig("def append(self, " MAKE_MLIR_PYTHON_QUALNAME("ir.Attribute") ")")
166+
// clang-format on
167+
);
85168
nb::class_<PyPDLPatternModule>(m, "PDLModule")
86169
.def(
87170
"__init__",
@@ -103,10 +186,20 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
103186
nb::sig("def __init__(self, module: " MAKE_MLIR_PYTHON_QUALNAME("ir.Module") ") -> None"),
104187
// clang-format on
105188
"module"_a, "Create a PDL module from the given module.")
106-
.def("freeze", [](PyPDLPatternModule &self) {
107-
return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
108-
mlirRewritePatternSetFromPDLPatternModule(self.get())));
109-
});
189+
.def(
190+
"freeze",
191+
[](PyPDLPatternModule &self) {
192+
return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
193+
mlirRewritePatternSetFromPDLPatternModule(self.get())));
194+
},
195+
nb::keep_alive<0, 1>())
196+
.def(
197+
"register_rewrite_function",
198+
[](PyPDLPatternModule &self, const std::string &name,
199+
const nb::callable &fn) {
200+
self.registerRewriteFunction(name, fn);
201+
},
202+
nb::keep_alive<1, 3>());
110203
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
111204
nb::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet")
112205
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR,

mlir/lib/CAPI/Transforms/Rewrite.cpp

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
#include "mlir/CAPI/Rewrite.h"
1414
#include "mlir/CAPI/Support.h"
1515
#include "mlir/CAPI/Wrap.h"
16+
#include "mlir/IR/Attributes.h"
17+
#include "mlir/IR/PDLPatternMatch.h.inc"
1618
#include "mlir/IR/PatternMatch.h"
1719
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
1820
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -301,6 +303,19 @@ mlirApplyPatternsAndFoldGreedilyWithOp(MlirOperation op,
301303
return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
302304
}
303305

306+
//===----------------------------------------------------------------------===//
307+
/// PatternRewriter API
308+
//===----------------------------------------------------------------------===//
309+
310+
inline mlir::PatternRewriter *unwrap(MlirPatternRewriter rewriter) {
311+
assert(rewriter.ptr && "unexpected null rewriter");
312+
return static_cast<mlir::PatternRewriter *>(rewriter.ptr);
313+
}
314+
315+
inline MlirPatternRewriter wrap(mlir::PatternRewriter *rewriter) {
316+
return {rewriter};
317+
}
318+
304319
//===----------------------------------------------------------------------===//
305320
/// PDLPatternModule API
306321
//===----------------------------------------------------------------------===//
@@ -331,4 +346,73 @@ mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op) {
331346
op.ptr = nullptr;
332347
return wrap(m);
333348
}
349+
350+
inline const mlir::PDLValue *unwrap(MlirPDLValue value) {
351+
assert(value.ptr && "unexpected null PDL value");
352+
return static_cast<const mlir::PDLValue *>(value.ptr);
353+
}
354+
355+
inline MlirPDLValue wrap(const mlir::PDLValue *value) { return {value}; }
356+
357+
inline mlir::PDLResultList *unwrap(MlirPDLResultList results) {
358+
assert(results.ptr && "unexpected null PDL results");
359+
return static_cast<mlir::PDLResultList *>(results.ptr);
360+
}
361+
362+
inline MlirPDLResultList wrap(mlir::PDLResultList *results) {
363+
return {results};
364+
}
365+
366+
MlirValue mlirPDLValueAsValue(MlirPDLValue value) {
367+
return wrap(unwrap(value)->dyn_cast<mlir::Value>());
368+
}
369+
370+
MlirType mlirPDLValueAsType(MlirPDLValue value) {
371+
return wrap(unwrap(value)->dyn_cast<mlir::Type>());
372+
}
373+
374+
MlirOperation mlirPDLValueAsOperation(MlirPDLValue value) {
375+
return wrap(unwrap(value)->dyn_cast<mlir::Operation *>());
376+
}
377+
378+
MlirAttribute mlirPDLValueAsAttribute(MlirPDLValue value) {
379+
return wrap(unwrap(value)->dyn_cast<mlir::Attribute>());
380+
}
381+
382+
void mlirPDLResultListPushBackValue(MlirPDLResultList results,
383+
MlirValue value) {
384+
unwrap(results)->push_back(unwrap(value));
385+
}
386+
387+
void mlirPDLResultListPushBackType(MlirPDLResultList results, MlirType value) {
388+
unwrap(results)->push_back(unwrap(value));
389+
}
390+
391+
void mlirPDLResultListPushBackOperation(MlirPDLResultList results,
392+
MlirOperation value) {
393+
unwrap(results)->push_back(unwrap(value));
394+
}
395+
396+
void mlirPDLResultListPushBackAttribute(MlirPDLResultList results,
397+
MlirAttribute value) {
398+
unwrap(results)->push_back(unwrap(value));
399+
}
400+
401+
void mlirPDLPatternModuleRegisterRewriteFunction(
402+
MlirPDLPatternModule pdlModule, MlirStringRef name,
403+
MlirPDLRewriteFunction rewriteFn, void *userData) {
404+
unwrap(pdlModule)->registerRewriteFunction(
405+
unwrap(name),
406+
[userData, rewriteFn](PatternRewriter &rewriter, PDLResultList &results,
407+
ArrayRef<PDLValue> values) -> LogicalResult {
408+
std::vector<MlirPDLValue> mlirValues;
409+
mlirValues.reserve(values.size());
410+
for (auto &value : values) {
411+
mlirValues.push_back(wrap(&value));
412+
}
413+
return unwrap(rewriteFn(wrap(&rewriter), wrap(&results),
414+
mlirValues.size(), mlirValues.data(),
415+
userData));
416+
});
417+
}
334418
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH

0 commit comments

Comments
 (0)