Skip to content

Commit 4689bc2

Browse files
committed
[MLIR][Python] Support Python-defined rewrite patterns
1 parent 0c2e900 commit 4689bc2

File tree

4 files changed

+245
-5
lines changed

4 files changed

+245
-5
lines changed

mlir/include/mlir-c/Rewrite.h

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ DEFINE_C_API_STRUCT(MlirFrozenRewritePatternSet, void);
3838
DEFINE_C_API_STRUCT(MlirGreedyRewriteDriverConfig, void);
3939
DEFINE_C_API_STRUCT(MlirRewritePatternSet, void);
4040
DEFINE_C_API_STRUCT(MlirPatternRewriter, void);
41+
DEFINE_C_API_STRUCT(MlirRewritePattern, const void);
4142

4243
//===----------------------------------------------------------------------===//
4344
/// RewriterBase API inherited from OpBuilder
@@ -324,6 +325,38 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
324325
MLIR_CAPI_EXPORTED MlirRewriterBase
325326
mlirPatternRewriterAsBase(MlirPatternRewriter rewriter);
326327

328+
//===----------------------------------------------------------------------===//
329+
/// RewritePattern API
330+
//===----------------------------------------------------------------------===//
331+
332+
typedef unsigned short MlirPatternBenefit;
333+
334+
typedef struct {
335+
void (*construct)(void *userData);
336+
void (*destruct)(void *userData);
337+
MlirLogicalResult (*matchAndRewrite)(MlirRewritePattern pattern,
338+
MlirOperation op,
339+
MlirPatternRewriter rewriter,
340+
void *userData);
341+
} MlirRewritePatternCallbacks;
342+
343+
MLIR_CAPI_EXPORTED MlirRewritePattern mlirOpRewritePattenCreate(
344+
MlirStringRef rootName, MlirPatternBenefit benefit, MlirContext context,
345+
MlirRewritePatternCallbacks callbacks, void *userData,
346+
size_t nGeneratedNames, MlirStringRef *generatedNames);
347+
348+
//===----------------------------------------------------------------------===//
349+
/// RewritePatternSet API
350+
//===----------------------------------------------------------------------===//
351+
352+
MLIR_CAPI_EXPORTED MlirRewritePatternSet
353+
mlirRewritePatternSetCreate(MlirContext context);
354+
355+
MLIR_CAPI_EXPORTED void mlirRewritePatternSetDestroy(MlirRewritePatternSet set);
356+
357+
MLIR_CAPI_EXPORTED void mlirRewritePatternSetAdd(MlirRewritePatternSet set,
358+
MlirRewritePattern pattern);
359+
327360
//===----------------------------------------------------------------------===//
328361
/// PDLPatternModule API
329362
//===----------------------------------------------------------------------===//

mlir/lib/Bindings/Python/Rewrite.cpp

Lines changed: 79 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,14 @@ class PyPatternRewriter {
4545
return PyInsertionPoint(PyOperation::forOperation(ctx, op));
4646
}
4747

48+
void replaceOp(MlirOperation op, MlirOperation newOp) {
49+
mlirRewriterBaseReplaceOpWithOperation(base, op, newOp);
50+
}
51+
52+
void replaceOp(MlirOperation op, const std::vector<MlirValue> &values) {
53+
mlirRewriterBaseReplaceOpWithValues(base, op, values.size(), values.data());
54+
}
55+
4856
private:
4957
MlirRewriterBase base;
5058
PyMlirContextRef ctx;
@@ -165,13 +173,82 @@ class PyFrozenRewritePatternSet {
165173
MlirFrozenRewritePatternSet set;
166174
};
167175

176+
class PyRewritePatternSet {
177+
public:
178+
PyRewritePatternSet(MlirContext ctx)
179+
: set(mlirRewritePatternSetCreate(ctx)), ctx(ctx) {}
180+
~PyRewritePatternSet() { mlirRewritePatternSetDestroy(set); }
181+
182+
void add(MlirStringRef rootName, MlirPatternBenefit benefit,
183+
const nb::callable &matchAndRewrite) {
184+
MlirRewritePatternCallbacks callbacks;
185+
callbacks.construct = [](void *userData) {
186+
nb::handle(static_cast<PyObject *>(userData)).inc_ref();
187+
};
188+
callbacks.destruct = [](void *userData) {
189+
nb::handle(static_cast<PyObject *>(userData)).dec_ref();
190+
};
191+
callbacks.matchAndRewrite = [](MlirRewritePattern pattern, MlirOperation op,
192+
MlirPatternRewriter rewriter,
193+
void *userData) -> MlirLogicalResult {
194+
nb::handle f(static_cast<PyObject *>(userData));
195+
nb::object res = f(op, PyPatternRewriter(rewriter), pattern);
196+
return logicalResultFromObject(res);
197+
};
198+
MlirRewritePattern pattern = mlirOpRewritePattenCreate(
199+
rootName, benefit, ctx, callbacks, matchAndRewrite.ptr(),
200+
/* nGeneratedNames */ 0,
201+
/* generatedNames */ nullptr);
202+
mlirRewritePatternSetAdd(set, pattern);
203+
}
204+
205+
PyFrozenRewritePatternSet freeze() { return mlirFreezeRewritePattern(set); }
206+
207+
private:
208+
MlirRewritePatternSet set;
209+
MlirContext ctx;
210+
};
211+
168212
} // namespace
169213

170214
/// Create the `mlir.rewrite` here.
171215
void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
216+
//----------------------------------------------------------------------------
217+
// Mapping of the PatternRewriter
218+
//----------------------------------------------------------------------------
172219
nb::class_<PyPatternRewriter>(m, "PatternRewriter")
173220
.def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint,
174-
"The current insertion point of the PatternRewriter.");
221+
"The current insertion point of the PatternRewriter.")
222+
.def("replace_op", [](PyPatternRewriter &self, MlirOperation op,
223+
MlirOperation newOp) { self.replaceOp(op, newOp); })
224+
.def("replace_op", [](PyPatternRewriter &self, MlirOperation op,
225+
const std::vector<MlirValue> &values) {
226+
self.replaceOp(op, values);
227+
});
228+
229+
//----------------------------------------------------------------------------
230+
// Mapping of the RewritePatternSet
231+
//----------------------------------------------------------------------------
232+
nb::class_<MlirRewritePattern>(m, "RewritePattern");
233+
nb::class_<PyRewritePatternSet>(m, "RewritePatternSet")
234+
.def(
235+
"__init__",
236+
[](PyRewritePatternSet &self, DefaultingPyMlirContext context) {
237+
new (&self) PyRewritePatternSet(context.get()->get());
238+
},
239+
"context"_a = nb::none())
240+
.def(
241+
"add",
242+
[](PyRewritePatternSet &self, nb::handle root, const nb::callable &fn,
243+
unsigned benefit) {
244+
std::string opName =
245+
nb::cast<std::string>(root.attr("OPERATION_NAME"));
246+
self.add(mlirStringRefCreate(opName.data(), opName.size()), benefit,
247+
fn);
248+
},
249+
"root"_a, "fn"_a, "benefit"_a = 1)
250+
.def("freeze", &PyRewritePatternSet::freeze);
251+
175252
//----------------------------------------------------------------------------
176253
// Mapping of the PDLResultList and PDLModule
177254
//----------------------------------------------------------------------------
@@ -237,7 +314,7 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
237314
.def(
238315
"freeze",
239316
[](PyPDLPatternModule &self) {
240-
return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
317+
return PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
241318
mlirRewritePatternSetFromPDLPatternModule(self.get())));
242319
},
243320
nb::keep_alive<0, 1>())

mlir/lib/CAPI/Transforms/Rewrite.cpp

Lines changed: 84 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "mlir/IR/PDLPatternMatch.h.inc"
1818
#include "mlir/IR/PatternMatch.h"
1919
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
20+
#include "mlir/Support/LLVM.h"
2021
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2122

2223
using namespace mlir;
@@ -270,9 +271,9 @@ void mlirIRRewriterDestroy(MlirRewriterBase rewriter) {
270271
/// RewritePatternSet and FrozenRewritePatternSet API
271272
//===----------------------------------------------------------------------===//
272273

273-
static inline mlir::RewritePatternSet &unwrap(MlirRewritePatternSet module) {
274+
static inline mlir::RewritePatternSet *unwrap(MlirRewritePatternSet module) {
274275
assert(module.ptr && "unexpected null module");
275-
return *(static_cast<mlir::RewritePatternSet *>(module.ptr));
276+
return static_cast<mlir::RewritePatternSet *>(module.ptr);
276277
}
277278

278279
static inline MlirRewritePatternSet wrap(mlir::RewritePatternSet *module) {
@@ -291,7 +292,7 @@ wrap(mlir::FrozenRewritePatternSet *module) {
291292
}
292293

293294
MlirFrozenRewritePatternSet mlirFreezeRewritePattern(MlirRewritePatternSet op) {
294-
auto *m = new mlir::FrozenRewritePatternSet(std::move(unwrap(op)));
295+
auto *m = new mlir::FrozenRewritePatternSet(std::move(*unwrap(op)));
295296
op.ptr = nullptr;
296297
return wrap(m);
297298
}
@@ -332,6 +333,86 @@ MlirRewriterBase mlirPatternRewriterAsBase(MlirPatternRewriter rewriter) {
332333
return wrap(static_cast<mlir::RewriterBase *>(unwrap(rewriter)));
333334
}
334335

336+
//===----------------------------------------------------------------------===//
337+
/// RewritePattern API
338+
//===----------------------------------------------------------------------===//
339+
340+
inline const mlir::RewritePattern *unwrap(MlirRewritePattern pattern) {
341+
assert(pattern.ptr && "unexpected null pattern");
342+
return static_cast<const mlir::RewritePattern *>(pattern.ptr);
343+
}
344+
345+
inline MlirRewritePattern wrap(const mlir::RewritePattern *pattern) {
346+
return {pattern};
347+
}
348+
349+
namespace mlir {
350+
351+
class ExternalRewritePattern : public mlir::RewritePattern {
352+
public:
353+
ExternalRewritePattern(MlirRewritePatternCallbacks callbacks, void *userData,
354+
StringRef rootName, PatternBenefit benefit,
355+
MLIRContext *context,
356+
ArrayRef<StringRef> generatedNames)
357+
: RewritePattern(rootName, benefit, context, generatedNames),
358+
callbacks(callbacks), userData(userData) {
359+
if (callbacks.construct)
360+
callbacks.construct(userData);
361+
}
362+
363+
~ExternalRewritePattern() {
364+
if (callbacks.destruct)
365+
callbacks.destruct(userData);
366+
}
367+
368+
LogicalResult matchAndRewrite(Operation *op,
369+
PatternRewriter &rewriter) const override {
370+
return unwrap(callbacks.matchAndRewrite(
371+
wrap(static_cast<const mlir::RewritePattern *>(this)), wrap(op),
372+
wrap(&rewriter), userData));
373+
}
374+
375+
private:
376+
MlirRewritePatternCallbacks callbacks;
377+
void *userData;
378+
};
379+
380+
} // namespace mlir
381+
382+
MlirRewritePattern mlirOpRewritePattenCreate(
383+
MlirStringRef rootName, MlirPatternBenefit benefit, MlirContext context,
384+
MlirRewritePatternCallbacks callbacks, void *userData,
385+
size_t nGeneratedNames, MlirStringRef *generatedNames) {
386+
std::vector<mlir::StringRef> generatedNamesVec;
387+
generatedNamesVec.reserve(nGeneratedNames);
388+
for (size_t i = 0; i < nGeneratedNames; ++i) {
389+
generatedNamesVec.push_back(unwrap(generatedNames[i]));
390+
}
391+
return wrap(new mlir::ExternalRewritePattern(
392+
callbacks, userData, unwrap(rootName), PatternBenefit(benefit),
393+
unwrap(context), generatedNamesVec));
394+
}
395+
396+
//===----------------------------------------------------------------------===//
397+
/// RewritePatternSet API
398+
//===----------------------------------------------------------------------===//
399+
400+
MlirRewritePatternSet mlirRewritePatternSetCreate(MlirContext context) {
401+
return wrap(new mlir::RewritePatternSet(unwrap(context)));
402+
}
403+
404+
void mlirRewritePatternSetDestroy(MlirRewritePatternSet set) {
405+
delete unwrap(set);
406+
}
407+
408+
void mlirRewritePatternSetAdd(MlirRewritePatternSet set,
409+
MlirRewritePattern pattern) {
410+
std::unique_ptr<mlir::RewritePattern> patternPtr(
411+
const_cast<mlir::RewritePattern *>(unwrap(pattern)));
412+
pattern.ptr = nullptr;
413+
unwrap(set)->add(std::move(patternPtr));
414+
}
415+
335416
//===----------------------------------------------------------------------===//
336417
/// PDLPatternModule API
337418
//===----------------------------------------------------------------------===//

mlir/test/python/rewrite.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# RUN: %PYTHON %s 2>&1 | FileCheck %s
2+
3+
import gc, sys
4+
from mlir.ir import *
5+
from mlir.passmanager import *
6+
from mlir.dialects.builtin import ModuleOp
7+
from mlir.dialects import arith
8+
from mlir.rewrite import *
9+
10+
11+
def log(*args):
12+
print(*args, file=sys.stderr)
13+
sys.stderr.flush()
14+
15+
16+
def run(f):
17+
log("\nTEST:", f.__name__)
18+
f()
19+
gc.collect()
20+
assert Context._get_live_count() == 0
21+
22+
# CHECK-LABEL: TEST: testRewritePattern
23+
@run
24+
def testRewritePattern():
25+
def to_muli(op, rewriter, pattern):
26+
with rewriter.ip:
27+
new_op = arith.muli(op.operands[0], op.operands[1], loc=op.location)
28+
rewriter.replace_op(op, new_op.owner)
29+
30+
with Context():
31+
patterns = RewritePatternSet()
32+
patterns.add(arith.AddIOp, to_muli)
33+
frozen = patterns.freeze()
34+
35+
module = ModuleOp.parse(
36+
r"""
37+
module {
38+
func.func @add(%a: i64, %b: i64) -> i64 {
39+
%sum = arith.addi %a, %b : i64
40+
return %sum : i64
41+
}
42+
}
43+
"""
44+
)
45+
46+
apply_patterns_and_fold_greedily(module, frozen)
47+
# CHECK: %0 = arith.muli %arg0, %arg1 : i64
48+
# CHECK: return %0 : i64
49+
print(module)

0 commit comments

Comments
 (0)