Skip to content

Commit 36aae61

Browse files
committed
sketch
1 parent 2ca028c commit 36aae61

File tree

3 files changed

+78
-4
lines changed

3 files changed

+78
-4
lines changed

mlir/include/mlir-c/Pass.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,15 +153,15 @@ struct MlirExternalPassCallbacks {
153153
/// The callback is called before the pass is run, allowing a chance to
154154
/// initialize any complex state necessary for running the pass.
155155
/// See Pass::initialize(MLIRContext *).
156-
MlirLogicalResult (*initialize)(MlirContext ctx, void *userData);
156+
MlirLogicalResult (*initialize)(void *userData, MlirContext ctx);
157157

158158
/// This callback is called when the pass is cloned.
159159
/// See Pass::clonePass().
160160
void *(*clone)(void *userData);
161161

162162
/// This callback is called when the pass is run.
163163
/// See Pass::runOnOperation().
164-
void (*run)(MlirOperation op, MlirExternalPass pass, void *userData);
164+
void (*run)(void *userData, MlirOperation op, MlirExternalPass pass);
165165
};
166166
typedef struct MlirExternalPassCallbacks MlirExternalPassCallbacks;
167167

mlir/lib/Bindings/Python/Pass.cpp

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,55 @@ class PyPassManager {
5151

5252
} // namespace
5353

54+
class PythonPass {
55+
public:
56+
explicit PythonPass(py::object passObj) : passObj(std::move(passObj)) {}
57+
58+
void *construct() {}
59+
void *destruct() {}
60+
61+
MlirLogicalResult *initialize(MlirContext ctx) {}
62+
void *clone() {}
63+
64+
void run(MlirOperation op, MlirExternalPass pass) {}
65+
66+
py::object passObj;
67+
};
68+
69+
template <typename T, typename R>
70+
void *void_cast(R (T::*f)()) {
71+
union {
72+
R (T::*pf)();
73+
void *p;
74+
};
75+
pf = f;
76+
return p;
77+
}
78+
79+
template <typename classT, typename memberT>
80+
union u_ptm_cast {
81+
memberT pmember;
82+
void *pvoid;
83+
};
84+
85+
MlirExternalPassCallbacks makeTestExternalPassCallbacks() {
86+
return (MlirExternalPassCallbacks){
87+
reinterpret_cast<decltype(MlirExternalPassCallbacks::construct)>(
88+
void_cast(&PythonPass::construct)),
89+
reinterpret_cast<decltype(MlirExternalPassCallbacks::destruct)>(
90+
void_cast(&PythonPass::destruct)),
91+
nullptr,
92+
reinterpret_cast<decltype(MlirExternalPassCallbacks::clone)>(
93+
void_cast(&PythonPass::clone)),
94+
reinterpret_cast<decltype(MlirExternalPassCallbacks::run)>(
95+
u_ptm_cast<PythonPass,
96+
void (PythonPass::*)(MlirOperation, MlirExternalPass)>{
97+
&PythonPass::run}
98+
.pvoid),
99+
100+
};
101+
}
102+
54103
/// Create the `mlir.passmanager` here.
55104
void mlir::python::populatePassManagerSubmodule(py::module &m) {
56105
//----------------------------------------------------------------------------
@@ -114,6 +163,27 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
114163
"pipeline"_a,
115164
"Add textual pipeline elements to the pass manager. Throws a "
116165
"ValueError if the pipeline can't be parsed.")
166+
.def_static(
167+
"create_external_pass",
168+
[](py::object &passObj) {
169+
PythonPass pass = PythonPass(passObj);
170+
171+
MlirTypeIDAllocator typeIDAllocator = mlirTypeIDAllocatorCreate();
172+
MlirTypeID passID =
173+
mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
174+
MlirStringRef name =
175+
mlirStringRefCreateFromCString("TestExternalPass");
176+
MlirStringRef description = mlirStringRefCreateFromCString("");
177+
MlirStringRef emptyOpName = mlirStringRefCreateFromCString("");
178+
MlirStringRef argument =
179+
mlirStringRefCreateFromCString("test-external-pass");
180+
181+
auto cbs = makeTestExternalPassCallbacks();
182+
183+
MlirPass externalPass =
184+
mlirCreateExternalPass(passID, name, argument, description,
185+
emptyOpName, 0, NULL, cbs, &pass);
186+
})
117187
.def(
118188
"run",
119189
[](PyPassManager &passManager, PyOperationBase &op,
@@ -144,4 +214,8 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
144214
},
145215
"Print the textual representation for this PassManager, suitable to "
146216
"be passed to `parse` for round-tripping.");
217+
218+
py::class_<PythonPass>(m, "PythonPass", py::module_local())
219+
.def(py::init<>(
220+
[](py::object pass) { return PythonPass(std::move(pass)); }));
147221
}

mlir/lib/CAPI/IR/Pass.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ class ExternalPass : public Pass {
138138
protected:
139139
LogicalResult initialize(MLIRContext *ctx) override {
140140
if (callbacks.initialize)
141-
return unwrap(callbacks.initialize(wrap(ctx), userData));
141+
return unwrap(callbacks.initialize(userData, wrap(ctx)));
142142
return success();
143143
}
144144

@@ -149,7 +149,7 @@ class ExternalPass : public Pass {
149149
}
150150

151151
void runOnOperation() override {
152-
callbacks.run(wrap(getOperation()), wrap(this), userData);
152+
callbacks.run(userData, wrap(getOperation()), wrap(this));
153153
}
154154

155155
std::unique_ptr<Pass> clonePass() const override {

0 commit comments

Comments
 (0)