Skip to content
This repository was archived by the owner on Oct 11, 2025. It is now read-only.

Commit 28a845c

Browse files
authored
[mlir] Add PDL C & Python usage (#94714)
Following a rather direct approach to expose PDL usage from C and then Python. This doesn't yes plumb through adding support for custom matchers through this interface, so constrained to basics initially. This also exposes greedy rewrite driver. Only way currently to define patterns is via PDL (just to keep small). The creation of the PDL pattern module could be improved to avoid folks potentially accessing the module used to construct it post construction. No ergonomic work done yet. --------- Signed-off-by: Jacques Pienaar <[email protected]>
1 parent 7a6977a commit 28a845c

File tree

4 files changed

+137
-0
lines changed

4 files changed

+137
-0
lines changed

IRModule.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "mlir-c/Diagnostics.h"
2323
#include "mlir-c/IR.h"
2424
#include "mlir-c/IntegerSet.h"
25+
#include "mlir-c/Transforms.h"
2526
#include "mlir/Bindings/Python/PybindAdaptors.h"
2627
#include "llvm/ADT/DenseMap.h"
2728

MainModule.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "Globals.h"
1212
#include "IRModule.h"
1313
#include "Pass.h"
14+
#include "Rewrite.h"
1415

1516
namespace py = pybind11;
1617
using namespace mlir;
@@ -116,6 +117,9 @@ PYBIND11_MODULE(_mlir, m) {
116117
populateIRInterfaces(irModule);
117118
populateIRTypes(irModule);
118119

120+
auto rewriteModule = m.def_submodule("rewrite", "MLIR Rewrite Bindings");
121+
populateRewriteSubmodule(rewriteModule);
122+
119123
// Define and populate PassManager submodule.
120124
auto passModule =
121125
m.def_submodule("passmanager", "MLIR Pass Management Bindings");

Rewrite.cpp

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
//===- Rewrite.cpp - Rewrite ----------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "Rewrite.h"
10+
11+
#include "IRModule.h"
12+
#include "mlir-c/Bindings/Python/Interop.h"
13+
#include "mlir-c/Rewrite.h"
14+
#include "mlir/Config/mlir-config.h"
15+
16+
namespace py = pybind11;
17+
using namespace mlir;
18+
using namespace py::literals;
19+
using namespace mlir::python;
20+
21+
namespace {
22+
23+
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
24+
/// Owning Wrapper around a PDLPatternModule.
25+
class PyPDLPatternModule {
26+
public:
27+
PyPDLPatternModule(MlirPDLPatternModule module) : module(module) {}
28+
PyPDLPatternModule(PyPDLPatternModule &&other) noexcept
29+
: module(other.module) {
30+
other.module.ptr = nullptr;
31+
}
32+
~PyPDLPatternModule() {
33+
if (module.ptr != nullptr)
34+
mlirPDLPatternModuleDestroy(module);
35+
}
36+
MlirPDLPatternModule get() { return module; }
37+
38+
private:
39+
MlirPDLPatternModule module;
40+
};
41+
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
42+
43+
/// Owning Wrapper around a FrozenRewritePatternSet.
44+
class PyFrozenRewritePatternSet {
45+
public:
46+
PyFrozenRewritePatternSet(MlirFrozenRewritePatternSet set) : set(set) {}
47+
PyFrozenRewritePatternSet(PyFrozenRewritePatternSet &&other) noexcept
48+
: set(other.set) {
49+
other.set.ptr = nullptr;
50+
}
51+
~PyFrozenRewritePatternSet() {
52+
if (set.ptr != nullptr)
53+
mlirFrozenRewritePatternSetDestroy(set);
54+
}
55+
MlirFrozenRewritePatternSet get() { return set; }
56+
57+
pybind11::object getCapsule() {
58+
return py::reinterpret_steal<py::object>(
59+
mlirPythonFrozenRewritePatternSetToCapsule(get()));
60+
}
61+
62+
static pybind11::object createFromCapsule(pybind11::object capsule) {
63+
MlirFrozenRewritePatternSet rawPm =
64+
mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr());
65+
if (rawPm.ptr == nullptr)
66+
throw py::error_already_set();
67+
return py::cast(PyFrozenRewritePatternSet(rawPm),
68+
py::return_value_policy::move);
69+
}
70+
71+
private:
72+
MlirFrozenRewritePatternSet set;
73+
};
74+
75+
} // namespace
76+
77+
/// Create the `mlir.rewrite` here.
78+
void mlir::python::populateRewriteSubmodule(py::module &m) {
79+
//----------------------------------------------------------------------------
80+
// Mapping of the top-level PassManager
81+
//----------------------------------------------------------------------------
82+
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
83+
py::class_<PyPDLPatternModule>(m, "PDLModule", py::module_local())
84+
.def(py::init<>([](MlirModule module) {
85+
return mlirPDLPatternModuleFromModule(module);
86+
}),
87+
"module"_a, "Create a PDL module from the given module.")
88+
.def("freeze", [](PyPDLPatternModule &self) {
89+
return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
90+
mlirRewritePatternSetFromPDLPatternModule(self.get())));
91+
});
92+
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCg
93+
py::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet",
94+
py::module_local())
95+
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
96+
&PyFrozenRewritePatternSet::getCapsule)
97+
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR,
98+
&PyFrozenRewritePatternSet::createFromCapsule);
99+
m.def(
100+
"apply_patterns_and_fold_greedily",
101+
[](MlirModule module, MlirFrozenRewritePatternSet set) {
102+
auto status = mlirApplyPatternsAndFoldGreedily(module, set, {});
103+
if (mlirLogicalResultIsFailure(status))
104+
// FIXME: Not sure this is the right error to throw here.
105+
throw py::value_error("pattern application failed to converge");
106+
},
107+
"module"_a, "set"_a,
108+
"Applys the given patterns to the given module greedily while folding "
109+
"results.");
110+
}

Rewrite.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
//===- Rewrite.h - Rewrite Submodules of pybind module --------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_BINDINGS_PYTHON_REWRITE_H
10+
#define MLIR_BINDINGS_PYTHON_REWRITE_H
11+
12+
#include "PybindUtils.h"
13+
14+
namespace mlir {
15+
namespace python {
16+
17+
void populateRewriteSubmodule(pybind11::module &m);
18+
19+
} // namespace python
20+
} // namespace mlir
21+
22+
#endif // MLIR_BINDINGS_PYTHON_REWRITE_H

0 commit comments

Comments
 (0)