diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h index 2db1d84cd1d89..f42aaff4d3c19 100644 --- a/mlir/include/mlir-c/Rewrite.h +++ b/mlir/include/mlir-c/Rewrite.h @@ -39,6 +39,7 @@ DEFINE_C_API_STRUCT(MlirGreedyRewriteDriverConfig, void); DEFINE_C_API_STRUCT(MlirRewritePatternSet, void); DEFINE_C_API_STRUCT(MlirPatternRewriter, void); DEFINE_C_API_STRUCT(MlirRewritePattern, const void); +DEFINE_C_API_STRUCT(MlirRewriterBaseListener, void); //===----------------------------------------------------------------------===// /// RewriterBase API inherited from OpBuilder @@ -48,6 +49,15 @@ DEFINE_C_API_STRUCT(MlirRewritePattern, const void); MLIR_CAPI_EXPORTED MlirContext mlirRewriterBaseGetContext(MlirRewriterBase rewriter); +/// Get the listener of the rewriter. +MLIR_CAPI_EXPORTED MlirRewriterBaseListener +mlirRewriterBaseGetListener(MlirRewriterBase rewriter); + +/// Notify the listener that the specified operation was inserted. +MLIR_CAPI_EXPORTED void mlirRewriterBaseListenerNotifyOperationInserted( + MlirRewriterBaseListener listener, MlirOperation op, + MlirOperation insertionPoint); + //===----------------------------------------------------------------------===// /// Insertion points methods //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/CAPI/Rewrite.h b/mlir/include/mlir/CAPI/Rewrite.h index 9c96d354d4fc9..357f2ae5a4418 100644 --- a/mlir/include/mlir/CAPI/Rewrite.h +++ b/mlir/include/mlir/CAPI/Rewrite.h @@ -26,6 +26,7 @@ DEFINE_C_API_PTR_METHODS(MlirRewritePatternSet, mlir::RewritePatternSet) DEFINE_C_API_PTR_METHODS(MlirFrozenRewritePatternSet, mlir::FrozenRewritePatternSet) DEFINE_C_API_PTR_METHODS(MlirPatternRewriter, mlir::PatternRewriter) +DEFINE_C_API_PTR_METHODS(MlirRewriterBaseListener, mlir::RewriterBase::Listener) #if MLIR_ENABLE_PDL_IN_PATTERNMATCH DEFINE_C_API_PTR_METHODS(MlirPDLPatternModule, mlir::PDLPatternModule) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 7b1710656243a..66d008445227b 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -810,11 +810,11 @@ PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() { } void PyThreadContextEntry::push(FrameKind frameKind, nb::object context, - nb::object insertionPoint, - nb::object location) { + nb::object insertionPoint, nb::object location, + nb::object listener) { auto &stack = getStack(); stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint), - std::move(location)); + std::move(location), std::move(listener)); // If the new stack has more than one entry and the context of the new top // entry matches the previous, copy the insertionPoint and location from the // previous entry if missing from the new top entry. @@ -827,6 +827,8 @@ void PyThreadContextEntry::push(FrameKind frameKind, nb::object context, current.insertionPoint = prev.insertionPoint; if (!current.location) current.location = prev.location; + if (!current.listener) + current.listener = prev.listener; } } } @@ -849,6 +851,12 @@ PyLocation *PyThreadContextEntry::getLocation() { return nb::cast(location); } +PyRewriterBaseListener *PyThreadContextEntry::getListener() { + if (!listener) + return nullptr; + return nb::cast(listener); +} + PyMlirContext *PyThreadContextEntry::getDefaultContext() { auto *tos = getTopOfStack(); return tos ? tos->getContext() : nullptr; @@ -864,10 +872,16 @@ PyLocation *PyThreadContextEntry::getDefaultLocation() { return tos ? tos->getLocation() : nullptr; } +PyRewriterBaseListener *PyThreadContextEntry::getDefaultListener() { + auto *tos = getTopOfStack(); + return tos ? tos->getListener() : nullptr; +} + nb::object PyThreadContextEntry::pushContext(nb::object context) { push(FrameKind::Context, /*context=*/context, /*insertionPoint=*/nb::object(), - /*location=*/nb::object()); + /*location=*/nb::object(), + /*listener=*/nb::object()); return context; } @@ -890,7 +904,8 @@ PyThreadContextEntry::pushInsertionPoint(nb::object insertionPointObj) { push(FrameKind::InsertionPoint, /*context=*/contextObj, /*insertionPoint=*/insertionPointObj, - /*location=*/nb::object()); + /*location=*/nb::object(), + /*listener=*/nb::object()); return insertionPointObj; } @@ -910,7 +925,8 @@ nb::object PyThreadContextEntry::pushLocation(nb::object locationObj) { nb::object contextObj = location.getContext().getObject(); push(FrameKind::Location, /*context=*/contextObj, /*insertionPoint=*/nb::object(), - /*location=*/locationObj); + /*location=*/locationObj, + /*listener=*/nb::object()); return locationObj; } @@ -924,6 +940,27 @@ void PyThreadContextEntry::popLocation(PyLocation &location) { stack.pop_back(); } +nb::object PyThreadContextEntry::pushListener(nb::object listenerObj) { + PyRewriterBaseListener &listener = + nb::cast(listenerObj); + nb::object contextObj = listener.getContext().getObject(); + push(FrameKind::Location, /*context=*/contextObj, + /*insertionPoint=*/nb::object(), + /*location=*/nb::object(), + /*listener=*/listenerObj); + return listenerObj; +} + +void PyThreadContextEntry::popListener(PyRewriterBaseListener &listener) { + auto &stack = getStack(); + if (stack.empty()) + throw std::runtime_error("Unbalanced Listener enter/exit"); + auto &tos = stack.back(); + if (tos.frameKind != FrameKind::Listener && tos.getListener() != &listener) + throw std::runtime_error("Unbalanced Listener enter/exit"); + stack.pop_back(); +} + //------------------------------------------------------------------------------ // PyDiagnostic* //------------------------------------------------------------------------------ @@ -1417,6 +1454,10 @@ static void maybeInsertOperation(PyOperationRef &op, if (ip) ip->insert(*op.get()); } + if (PyRewriterBaseListener *listener = + PyThreadContextEntry::getDefaultListener()) { + listener->notifyOperationInserted(*op.get()); + } } nb::object PyOperation::create(std::string_view name, @@ -2036,6 +2077,19 @@ PyOpView::PyOpView(const nb::object &operationObject) : operation(nb::cast(operationObject).getOperation()), operationObject(operation.getRef().getObject()) {} +//------------------------------------------------------------------------------ +// PyRewriterBaseListener. +//------------------------------------------------------------------------------ + +nb::object PyRewriterBaseListener::contextEnter(nb::object listener) { + return PyThreadContextEntry::pushListener(std::move(listener)); +} + +void PyRewriterBaseListener::contextExit(nb::handle excType, nb::handle excVal, + nb::handle excTb) { + PyThreadContextEntry::popListener(*this); +} + //------------------------------------------------------------------------------ // PyInsertionPoint. //------------------------------------------------------------------------------ @@ -3961,6 +4015,15 @@ void mlir::python::populateIRCore(nb::module_ &m) { }, "Returns the list of Block predecessors."); + //---------------------------------------------------------------------------- + // Mapping of PyRewriterBaseListener. + //---------------------------------------------------------------------------- + nb::class_(m, "RewriterBaseListener") + .def("__enter__", &PyRewriterBaseListener::contextEnter) + .def("__exit__", &PyRewriterBaseListener::contextExit, + nb::arg("exc_type").none(), nb::arg("exc_value").none(), + nb::arg("traceback").none()); + //---------------------------------------------------------------------------- // Mapping of PyInsertionPoint. //---------------------------------------------------------------------------- diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index e706be3b4d32a..9985af0a448a6 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -46,6 +46,7 @@ class PyOperationBase; class PyType; class PySymbolTable; class PyValue; +class PyRewriterBaseListener; /// Template for a reference to a concrete type which captures a python /// reference to its underlying python object. @@ -115,13 +116,15 @@ class PyThreadContextEntry { Context, InsertionPoint, Location, + Listener, }; PyThreadContextEntry(FrameKind frameKind, nanobind::object context, nanobind::object insertionPoint, - nanobind::object location) + nanobind::object location, nanobind::object listener) : context(std::move(context)), insertionPoint(std::move(insertionPoint)), - location(std::move(location)), frameKind(frameKind) {} + location(std::move(location)), listener(std::move(listener)), + frameKind(frameKind) {} /// Gets the top of stack context and return nullptr if not defined. static PyMlirContext *getDefaultContext(); @@ -132,9 +135,12 @@ class PyThreadContextEntry { /// Gets the top of stack location and returns nullptr if not defined. static PyLocation *getDefaultLocation(); + static PyRewriterBaseListener *getDefaultListener(); + PyMlirContext *getContext(); PyInsertionPoint *getInsertionPoint(); PyLocation *getLocation(); + PyRewriterBaseListener *getListener(); FrameKind getFrameKind() { return frameKind; } /// Stack management. @@ -145,13 +151,16 @@ class PyThreadContextEntry { static void popInsertionPoint(PyInsertionPoint &insertionPoint); static nanobind::object pushLocation(nanobind::object location); static void popLocation(PyLocation &location); + static nanobind::object pushListener(nanobind::object listener); + static void popListener(PyRewriterBaseListener &listener); /// Gets the thread local stack. static std::vector &getStack(); private: static void push(FrameKind frameKind, nanobind::object context, - nanobind::object insertionPoint, nanobind::object location); + nanobind::object insertionPoint, nanobind::object location, + nanobind::object listener); /// An object reference to the PyContext. nanobind::object context; @@ -159,6 +168,8 @@ class PyThreadContextEntry { nanobind::object insertionPoint; /// An object reference to the current location. nanobind::object location; + /// An object reference to the current listener. + nanobind::object listener; // The kind of push that was performed. FrameKind frameKind; }; @@ -830,6 +841,31 @@ class PyBlock { MlirBlock block; }; +/// Wrapper around a MlirRewriterBaseListener. +class PyRewriterBaseListener { +public: + PyRewriterBaseListener(MlirRewriterBaseListener listener, + PyMlirContextRef ctx) + : listener(listener), ctx(std::move(ctx)) {} + + MlirRewriterBaseListener get() { return listener; } + + void notifyOperationInserted(PyOperationBase &op) { + mlirRewriterBaseListenerNotifyOperationInserted(get(), op.getOperation(), + MlirOperation{nullptr}); + } + + PyMlirContextRef getContext() { return ctx; } + + static nanobind::object contextEnter(nanobind::object listener); + void contextExit(nanobind::handle excType, nanobind::handle excVal, + nanobind::handle excTb); + +private: + MlirRewriterBaseListener listener; + PyMlirContextRef ctx; +}; + /// An insertion point maintains a pointer to a Block and a reference operation. /// Calls to insert() will insert a new operation before the /// reference operation. If the reference operation is null, then appends to diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp index 5ddb3fbbb1317..5512fb2377d60 100644 --- a/mlir/lib/Bindings/Python/Rewrite.cpp +++ b/mlir/lib/Bindings/Python/Rewrite.cpp @@ -18,6 +18,7 @@ // clang-format on #include "mlir/Config/mlir-config.h" #include "nanobind/nanobind.h" +#include "llvm/ADT/ScopeExit.h" namespace nb = nanobind; using namespace mlir; @@ -45,6 +46,10 @@ class PyPatternRewriter { return PyInsertionPoint(PyOperation::forOperation(ctx, op)); } + PyRewriterBaseListener getListener() { + return PyRewriterBaseListener(mlirRewriterBaseGetListener(base), ctx); + } + void replaceOp(MlirOperation op, MlirOperation newOp) { mlirRewriterBaseReplaceOpWithOperation(base, op, newOp); } @@ -202,7 +207,15 @@ class PyRewritePatternSet { PyMlirContext::forContext(mlirOperationGetContext(op)); nb::object opView = PyOperation::forOperation(ctx, op)->createOpView(); - nb::object res = f(opView, PyPatternRewriter(rewriter)); + PyPatternRewriter pyRewriter(rewriter); + nb::object listener = nb::cast(pyRewriter.getListener()); + + listener.attr("__enter__")(); + auto exit = llvm::make_scope_exit([listener] { + listener.attr("__exit__")(nb::none(), nb::none(), nb::none()); + }); + nb::object res = f(opView, pyRewriter); + return logicalResultFromObject(res); }; MlirRewritePattern pattern = mlirOpRewritePattenCreate( @@ -234,6 +247,8 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) { class_(m, "PatternRewriter") .def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint, "The current insertion point of the PatternRewriter.") + .def_prop_ro("listener", &PyPatternRewriter::getListener, + "The rewrite listener of the PatternRewriter.") .def( "replace_op", [](PyPatternRewriter &self, MlirOperation op, diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp index 46c329d8433b4..41b48bdfe5d6b 100644 --- a/mlir/lib/CAPI/Transforms/Rewrite.cpp +++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp @@ -29,6 +29,23 @@ MlirContext mlirRewriterBaseGetContext(MlirRewriterBase rewriter) { return wrap(unwrap(rewriter)->getContext()); } +MlirRewriterBaseListener +mlirRewriterBaseGetListener(MlirRewriterBase rewriter) { + return wrap( + dyn_cast(unwrap(rewriter)->getListener())); +} + +void mlirRewriterBaseListenerNotifyOperationInserted( + MlirRewriterBaseListener listener, MlirOperation op, + MlirOperation insertionPoint) { + OpBuilder::InsertPoint ip; + if (!mlirOperationIsNull(insertionPoint)) { + ip = OpBuilder::InsertPoint(unwrap(insertionPoint)->getBlock(), + Block::iterator(unwrap(insertionPoint))); + } + return unwrap(listener)->notifyOperationInserted(unwrap(op), ip); +} + //===----------------------------------------------------------------------===// /// Insertion points methods //===----------------------------------------------------------------------===//