Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions mlir/include/mlir/Pass/Pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <optional>

namespace mlir {
class PassInstrumentation;
namespace detail {
class OpToOpPassAdaptor;
struct OpPassManagerImpl;
Expand Down Expand Up @@ -334,6 +335,9 @@ class Pass {

/// Allow access to 'passOptions'.
friend class PassInfo;

/// Allow access to 'signalPassFailure'.
friend class PassInstrumentation;
};

//===----------------------------------------------------------------------===//
Expand Down
2 changes: 2 additions & 0 deletions mlir/include/mlir/Pass/PassInstrumentation.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ class PassInstrumentation {
/// name of the analysis that was computed, its TypeID, as well as the
/// current operation being analyzed.
virtual void runAfterAnalysis(StringRef name, TypeID id, Operation *op) {}

static void signalPassFailure(Pass *pass);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add documentation for this method, also should this be a static method or an instance method? I know it doesn't have to be an instance method but that would help keep the scope of API exposure slimmer (otherwise, should we just make signalPassFailure public?)

};

/// This class holds a collection of PassInstrumentation objects, and invokes
Expand Down
33 changes: 20 additions & 13 deletions mlir/lib/Pass/Pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -599,17 +599,20 @@ LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
if (pi)
pi->runBeforePass(pass, op);

bool passFailed = false;
op->getContext()->executeAction<PassExecutionAction>(
[&]() {
// Invoke the virtual runOnOperation method.
if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(pass))
adaptor->runOnOperation(verifyPasses);
else
pass->runOnOperation();
passFailed = pass->passState->irAndPassFailed.getInt();
},
{op}, *pass);
bool passFailed = pass->passState->irAndPassFailed.getInt();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is non-intuitive to me, should be documented.

if (!passFailed) {
op->getContext()->executeAction<PassExecutionAction>(
[&]() {
// Invoke the virtual runOnOperation method.
if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(pass))
adaptor->runOnOperation(verifyPasses);
else
pass->runOnOperation();
passFailed = pass->passState->irAndPassFailed.getInt();
},
{op}, *pass);
}


// Invalidate any non preserved analyses.
am.invalidate(pass->passState->preservedAnalyses);
Expand Down Expand Up @@ -640,10 +643,12 @@ LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,

// Instrument after the pass has run.
if (pi) {
if (passFailed)
if (passFailed) {
pi->runAfterPassFailed(pass, op);
else
} else {
pi->runAfterPass(pass, op);
passFailed = passFailed || pass->passState->irAndPassFailed.getInt();
}
}

// Return if the pass signaled a failure.
Expand Down Expand Up @@ -1198,6 +1203,8 @@ void PassInstrumentation::runBeforePipeline(
void PassInstrumentation::runAfterPipeline(
std::optional<OperationName> name, const PipelineParentInfo &parentInfo) {}

void PassInstrumentation::signalPassFailure(Pass *pass) { pass->signalPassFailure(); }

//===----------------------------------------------------------------------===//
// PassInstrumentor
//===----------------------------------------------------------------------===//
Expand Down
98 changes: 98 additions & 0 deletions mlir/unittests/Pass/PassManagerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassInstrumentation.h"
#include "gtest/gtest.h"

#include <memory>
Expand Down Expand Up @@ -117,6 +118,103 @@ struct AddSecondAttrFunctionPass
}
};

/// PassInstrumentation to count pass callbacks and signal pass failures.
struct TestPassInstrumentation : public PassInstrumentation {
int beforePassCallbackCount = 0;
int afterPassCallbackCount = 0;
int afterPassFailedCallbackCount = 0;

bool failBeforePass = false;
bool failAfterPass = false;

void runBeforePass(Pass *pass, Operation *op) override {
if (pass->getTypeID() != TypeID::get<AddAttrFunctionPass>()) return;

++beforePassCallbackCount;
if (failBeforePass)
signalPassFailure(pass);
}
void runAfterPass(Pass *pass, Operation *op) override {
if (pass->getTypeID() != TypeID::get<AddAttrFunctionPass>()) return;

++afterPassCallbackCount;
if (failAfterPass)
signalPassFailure(pass);
}
void runAfterPassFailed(Pass *pass, Operation *op) override {
if (pass->getTypeID() != TypeID::get<AddAttrFunctionPass>()) return;

++afterPassFailedCallbackCount;
}
};

TEST(PassManagerTest, PassInstrumentation) {
MLIRContext context;
context.loadDialect<func::FuncDialect>();
Builder b(&context);

// Create a module with 1 function.
OwningOpRef<ModuleOp> module(ModuleOp::create(UnknownLoc::get(&context)));
auto func = func::FuncOp::create(b.getUnknownLoc(), "test_func",
b.getFunctionType({}, {}));
func.setPrivate();
module->push_back(func);

struct InstrumentationCounts {
int beforePass;
int afterPass;
int afterPassFailed;
};

auto runInstrumentation =
[&](bool failBefore,
bool failAfter) -> std::pair<LogicalResult, InstrumentationCounts> {
// Instantiate and run our pass.
auto pm = PassManager::on<ModuleOp>(&context);
auto instrumentation = std::make_unique<TestPassInstrumentation>();
auto *instrumentationPtr = instrumentation.get();
instrumentation->failBeforePass = failBefore;
instrumentation->failAfterPass = failAfter;
pm.addInstrumentation(std::move(instrumentation));
pm.addNestedPass<func::FuncOp>(std::make_unique<AddAttrFunctionPass>());
LogicalResult result = pm.run(module.get());

InstrumentationCounts counts = {
.beforePass = instrumentationPtr->beforePassCallbackCount,
.afterPass = instrumentationPtr->afterPassCallbackCount,
.afterPassFailed = instrumentationPtr->afterPassFailedCallbackCount};
return {result, counts};
};

for (bool failBefore : {false, true}) {
for (bool failAfter : {false, true}) {
auto [result, counts] = runInstrumentation(failBefore, failAfter);

InstrumentationCounts expected;
if (failBefore) {
EXPECT_TRUE(failed(result))
<< "failBefore=" << failBefore << ", failAfter=" << failAfter;
expected = {.beforePass = 1, .afterPass = 0, .afterPassFailed = 1};
} else if (failAfter) {
EXPECT_TRUE(failed(result))
<< "failBefore=" << failBefore << ", failAfter=" << failAfter;
expected = {.beforePass = 1, .afterPass = 1, .afterPassFailed = 0};
} else {
EXPECT_TRUE(succeeded(result))
<< "failBefore=" << failBefore << ", failAfter=" << failAfter;
expected = {.beforePass = 1, .afterPass = 1, .afterPassFailed = 0};
}

EXPECT_EQ(counts.beforePass, expected.beforePass)
<< "failBefore=" << failBefore << ", failAfter=" << failAfter;
EXPECT_EQ(counts.afterPass, expected.afterPass)
<< "failBefore=" << failBefore << ", failAfter=" << failAfter;
EXPECT_EQ(counts.afterPassFailed, expected.afterPassFailed)
<< "failBefore=" << failBefore << ", failAfter=" << failAfter;
}
}
}

TEST(PassManagerTest, ExecutionAction) {
MLIRContext context;
context.loadDialect<func::FuncDialect>();
Expand Down
Loading