diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h index 16893c6db87b1..f0b0979a81ee3 100644 --- a/mlir/include/mlir/Pass/Pass.h +++ b/mlir/include/mlir/Pass/Pass.h @@ -17,6 +17,7 @@ #include namespace mlir { +class PassInstrumentation; namespace detail { class OpToOpPassAdaptor; struct OpPassManagerImpl; @@ -334,6 +335,9 @@ class Pass { /// Allow access to 'passOptions'. friend class PassInfo; + + /// Allow access to 'signalPassFailure'. + friend class PassInstrumentation; }; //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Pass/PassInstrumentation.h b/mlir/include/mlir/Pass/PassInstrumentation.h index 917bac4b22288..25a8e77be75ee 100644 --- a/mlir/include/mlir/Pass/PassInstrumentation.h +++ b/mlir/include/mlir/Pass/PassInstrumentation.h @@ -80,6 +80,8 @@ class PassInstrumentation { /// name of the analysis that was computed, its TypeID, as well as the /// current operation being analyzed. virtual void runAfterAnalysis(StringRef name, TypeID id, Operation *op) {} + + static void signalPassFailure(Pass *pass); }; /// This class holds a collection of PassInstrumentation objects, and invokes diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp index 521c7c6be17b6..17ac475b42f4b 100644 --- a/mlir/lib/Pass/Pass.cpp +++ b/mlir/lib/Pass/Pass.cpp @@ -599,17 +599,20 @@ LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op, if (pi) pi->runBeforePass(pass, op); - bool passFailed = false; - op->getContext()->executeAction( - [&]() { - // Invoke the virtual runOnOperation method. - if (auto *adaptor = dyn_cast(pass)) - adaptor->runOnOperation(verifyPasses); - else - pass->runOnOperation(); - passFailed = pass->passState->irAndPassFailed.getInt(); - }, - {op}, *pass); + bool passFailed = pass->passState->irAndPassFailed.getInt(); + if (!passFailed) { + op->getContext()->executeAction( + [&]() { + // Invoke the virtual runOnOperation method. + if (auto *adaptor = dyn_cast(pass)) + adaptor->runOnOperation(verifyPasses); + else + pass->runOnOperation(); + passFailed = pass->passState->irAndPassFailed.getInt(); + }, + {op}, *pass); + } + // Invalidate any non preserved analyses. am.invalidate(pass->passState->preservedAnalyses); @@ -640,10 +643,12 @@ LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op, // Instrument after the pass has run. if (pi) { - if (passFailed) + if (passFailed) { pi->runAfterPassFailed(pass, op); - else + } else { pi->runAfterPass(pass, op); + passFailed = passFailed || pass->passState->irAndPassFailed.getInt(); + } } // Return if the pass signaled a failure. @@ -1198,6 +1203,8 @@ void PassInstrumentation::runBeforePipeline( void PassInstrumentation::runAfterPipeline( std::optional name, const PipelineParentInfo &parentInfo) {} +void PassInstrumentation::signalPassFailure(Pass *pass) { pass->signalPassFailure(); } + //===----------------------------------------------------------------------===// // PassInstrumentor //===----------------------------------------------------------------------===// diff --git a/mlir/unittests/Pass/PassManagerTest.cpp b/mlir/unittests/Pass/PassManagerTest.cpp index 7e618811eabf4..86c793384db11 100644 --- a/mlir/unittests/Pass/PassManagerTest.cpp +++ b/mlir/unittests/Pass/PassManagerTest.cpp @@ -14,6 +14,7 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Diagnostics.h" #include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassInstrumentation.h" #include "gtest/gtest.h" #include @@ -117,6 +118,103 @@ struct AddSecondAttrFunctionPass } }; +/// PassInstrumentation to count pass callbacks and signal pass failures. +struct TestPassInstrumentation : public PassInstrumentation { + int beforePassCallbackCount = 0; + int afterPassCallbackCount = 0; + int afterPassFailedCallbackCount = 0; + + bool failBeforePass = false; + bool failAfterPass = false; + + void runBeforePass(Pass *pass, Operation *op) override { + if (pass->getTypeID() != TypeID::get()) return; + + ++beforePassCallbackCount; + if (failBeforePass) + signalPassFailure(pass); + } + void runAfterPass(Pass *pass, Operation *op) override { + if (pass->getTypeID() != TypeID::get()) return; + + ++afterPassCallbackCount; + if (failAfterPass) + signalPassFailure(pass); + } + void runAfterPassFailed(Pass *pass, Operation *op) override { + if (pass->getTypeID() != TypeID::get()) return; + + ++afterPassFailedCallbackCount; + } +}; + +TEST(PassManagerTest, PassInstrumentation) { + MLIRContext context; + context.loadDialect(); + Builder b(&context); + + // Create a module with 1 function. + OwningOpRef module(ModuleOp::create(UnknownLoc::get(&context))); + auto func = func::FuncOp::create(b.getUnknownLoc(), "test_func", + b.getFunctionType({}, {})); + func.setPrivate(); + module->push_back(func); + + struct InstrumentationCounts { + int beforePass; + int afterPass; + int afterPassFailed; + }; + + auto runInstrumentation = + [&](bool failBefore, + bool failAfter) -> std::pair { + // Instantiate and run our pass. + auto pm = PassManager::on(&context); + auto instrumentation = std::make_unique(); + auto *instrumentationPtr = instrumentation.get(); + instrumentation->failBeforePass = failBefore; + instrumentation->failAfterPass = failAfter; + pm.addInstrumentation(std::move(instrumentation)); + pm.addNestedPass(std::make_unique()); + LogicalResult result = pm.run(module.get()); + + InstrumentationCounts counts = { + .beforePass = instrumentationPtr->beforePassCallbackCount, + .afterPass = instrumentationPtr->afterPassCallbackCount, + .afterPassFailed = instrumentationPtr->afterPassFailedCallbackCount}; + return {result, counts}; + }; + + for (bool failBefore : {false, true}) { + for (bool failAfter : {false, true}) { + auto [result, counts] = runInstrumentation(failBefore, failAfter); + + InstrumentationCounts expected; + if (failBefore) { + EXPECT_TRUE(failed(result)) + << "failBefore=" << failBefore << ", failAfter=" << failAfter; + expected = {.beforePass = 1, .afterPass = 0, .afterPassFailed = 1}; + } else if (failAfter) { + EXPECT_TRUE(failed(result)) + << "failBefore=" << failBefore << ", failAfter=" << failAfter; + expected = {.beforePass = 1, .afterPass = 1, .afterPassFailed = 0}; + } else { + EXPECT_TRUE(succeeded(result)) + << "failBefore=" << failBefore << ", failAfter=" << failAfter; + expected = {.beforePass = 1, .afterPass = 1, .afterPassFailed = 0}; + } + + EXPECT_EQ(counts.beforePass, expected.beforePass) + << "failBefore=" << failBefore << ", failAfter=" << failAfter; + EXPECT_EQ(counts.afterPass, expected.afterPass) + << "failBefore=" << failBefore << ", failAfter=" << failAfter; + EXPECT_EQ(counts.afterPassFailed, expected.afterPassFailed) + << "failBefore=" << failBefore << ", failAfter=" << failAfter; + } + } +} + TEST(PassManagerTest, ExecutionAction) { MLIRContext context; context.loadDialect();