Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions mlir/include/mlir/Pass/PassManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,10 @@ class PassManager : public OpPassManager {
/// Runs the verifier after each individual pass.
void enableVerifier(bool enabled = true);

/// Sets whether an error containing the failing pass name should be emitted
/// upon failure.
void enableErrorOnFailure(bool enabled = true);

//===--------------------------------------------------------------------===//
// Instrumentations
//===--------------------------------------------------------------------===//
Expand Down Expand Up @@ -497,6 +501,10 @@ class PassManager : public OpPassManager {

/// A flag that indicates if the IR should be verified in between passes.
bool verifyPasses : 1;

/// A flag that indicates if an error containing the pass name should be
/// emitted upon failure.
bool emitErrorOnFailure : 1;
};

/// Register a set of useful command-line options that can be used to configure
Expand Down
12 changes: 12 additions & 0 deletions mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,15 @@ class MlirOptMainConfig {
}
bool shouldVerifyPasses() const { return verifyPassesFlag; }

/// Set whether to emit and error upon pass failure.
MlirOptMainConfig &emitErrorOnPassFailure(bool emit) {
emitErrorOnPassFailureFlag = emit;
return *this;
}
bool shouldEmitErrorOnPassFailure() const {
return emitErrorOnPassFailureFlag;
}

/// Set whether to run the verifier on parsing.
MlirOptMainConfig &verifyOnParsing(bool verify) {
disableVerifierOnParsingFlag = !verify;
Expand Down Expand Up @@ -291,6 +300,9 @@ class MlirOptMainConfig {
/// Run the verifier after each transformation pass.
bool verifyPassesFlag = true;

/// Emit an error upon a pass failure.
bool emitErrorOnPassFailureFlag = false;

/// Disable the verifier on parsing.
bool disableVerifierOnParsingFlag = false;

Expand Down
45 changes: 30 additions & 15 deletions mlir/lib/Pass/Pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,7 @@ llvm::hash_code OpPassManager::hash() {

LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
AnalysisManager am, bool verifyPasses,
bool emitErrorOnFailure,
unsigned parentInitGeneration) {
std::optional<RegisteredOperationName> opInfo = op->getRegisteredInfo();
if (!opInfo)
Expand Down Expand Up @@ -533,9 +534,9 @@ LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
if (failed(pipeline.initialize(root->getContext(), parentInitGeneration)))
return failure();
AnalysisManager nestedAm = root == op ? am : am.nest(root);
return OpToOpPassAdaptor::runPipeline(pipeline, root, nestedAm,
verifyPasses, parentInitGeneration,
pi, &parentInfo);
return OpToOpPassAdaptor::runPipeline(
pipeline, root, nestedAm, verifyPasses, emitErrorOnFailure,
parentInitGeneration, pi, &parentInfo);
};
pass->passState.emplace(op, am, dynamicPipelineCallback);

Expand All @@ -548,7 +549,7 @@ LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
[&]() {
// Invoke the virtual runOnOperation method.
if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(pass))
adaptor->runOnOperation(verifyPasses);
adaptor->runOnOperation(verifyPasses, emitErrorOnFailure);
else
pass->runOnOperation();
passFailed = pass->passState->irAndPassFailed.getInt();
Expand Down Expand Up @@ -597,7 +598,8 @@ LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
/// Run the given operation and analysis manager on a provided op pass manager.
LogicalResult OpToOpPassAdaptor::runPipeline(
OpPassManager &pm, Operation *op, AnalysisManager am, bool verifyPasses,
unsigned parentInitGeneration, PassInstrumentor *instrumentor,
bool emitErrorOnFailure, unsigned parentInitGeneration,
PassInstrumentor *instrumentor,
const PassInstrumentation::PipelineParentInfo *parentInfo) {
assert((!instrumentor || parentInfo) &&
"expected parent info if instrumentor is provided");
Expand All @@ -616,8 +618,12 @@ LogicalResult OpToOpPassAdaptor::runPipeline(
}

for (Pass &pass : pm.getPasses())
if (failed(run(&pass, op, am, verifyPasses, parentInitGeneration)))
if (failed(run(&pass, op, am, verifyPasses, emitErrorOnFailure,
parentInitGeneration))) {
if (emitErrorOnFailure)
return op->emitError("failed to run pass: ") << pass.getName();
return failure();
}

if (instrumentor) {
instrumentor->runAfterPipeline(pm.getOpName(*op->getContext()),
Expand Down Expand Up @@ -735,15 +741,17 @@ void OpToOpPassAdaptor::runOnOperation() {
}

/// Run the held pipeline over all nested operations.
void OpToOpPassAdaptor::runOnOperation(bool verifyPasses) {
void OpToOpPassAdaptor::runOnOperation(bool verifyPasses,
bool emitErrorOnFailure) {
if (getContext().isMultithreadingEnabled())
runOnOperationAsyncImpl(verifyPasses);
runOnOperationAsyncImpl(verifyPasses, emitErrorOnFailure);
else
runOnOperationImpl(verifyPasses);
runOnOperationImpl(verifyPasses, emitErrorOnFailure);
}

/// Run this pass adaptor synchronously.
void OpToOpPassAdaptor::runOnOperationImpl(bool verifyPasses) {
void OpToOpPassAdaptor::runOnOperationImpl(bool verifyPasses,
bool emitErrorOnFailure) {
auto am = getAnalysisManager();
PassInstrumentation::PipelineParentInfo parentInfo = {llvm::get_threadid(),
this};
Expand All @@ -758,7 +766,8 @@ void OpToOpPassAdaptor::runOnOperationImpl(bool verifyPasses) {
// Run the held pipeline over the current operation.
unsigned initGeneration = mgr->impl->initializationGeneration;
if (failed(runPipeline(*mgr, &op, am.nest(&op), verifyPasses,
initGeneration, instrumentor, &parentInfo)))
emitErrorOnFailure, initGeneration, instrumentor,
&parentInfo)))
signalPassFailure();
}
}
Expand All @@ -775,7 +784,8 @@ static bool hasSizeMismatch(ArrayRef<OpPassManager> lhs,
}

/// Run this pass adaptor synchronously.
void OpToOpPassAdaptor::runOnOperationAsyncImpl(bool verifyPasses) {
void OpToOpPassAdaptor::runOnOperationAsyncImpl(bool verifyPasses,
bool emitErrorOnFailure) {
AnalysisManager am = getAnalysisManager();
MLIRContext *context = &getContext();

Expand Down Expand Up @@ -838,7 +848,7 @@ void OpToOpPassAdaptor::runOnOperationAsyncImpl(bool verifyPasses) {
// Get the pass manager for this operation and execute it.
OpPassManager &pm = asyncExecutors[pmIndex][opInfo.passManagerIdx];
LogicalResult pipelineResult = runPipeline(
pm, opInfo.op, opInfo.am, verifyPasses,
pm, opInfo.op, opInfo.am, verifyPasses, emitErrorOnFailure,
pm.impl->initializationGeneration, instrumentor, &parentInfo);
if (failed(pipelineResult))
hasFailure.store(true);
Expand All @@ -859,17 +869,21 @@ void OpToOpPassAdaptor::runOnOperationAsyncImpl(bool verifyPasses) {
PassManager::PassManager(MLIRContext *ctx, StringRef operationName,
Nesting nesting)
: OpPassManager(operationName, nesting), context(ctx), passTiming(false),
verifyPasses(true) {}
verifyPasses(true), emitErrorOnFailure(false) {}

PassManager::PassManager(OperationName operationName, Nesting nesting)
: OpPassManager(operationName, nesting),
context(operationName.getContext()), passTiming(false),
verifyPasses(true) {}
verifyPasses(true), emitErrorOnFailure(false) {}

PassManager::~PassManager() = default;

void PassManager::enableVerifier(bool enabled) { verifyPasses = enabled; }

void PassManager::enableErrorOnFailure(bool enabled) {
emitErrorOnFailure = enabled;
}

/// Run the passes within this manager on the provided operation.
LogicalResult PassManager::run(Operation *op) {
MLIRContext *context = getContext();
Expand Down Expand Up @@ -931,6 +945,7 @@ void PassManager::addInstrumentation(std::unique_ptr<PassInstrumentation> pi) {

LogicalResult PassManager::runPasses(Operation *op, AnalysisManager am) {
return OpToOpPassAdaptor::runPipeline(*this, op, am, verifyPasses,
emitErrorOnFailure,
impl->initializationGeneration);
}

Expand Down
18 changes: 11 additions & 7 deletions mlir/lib/Pass/PassDetail.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class OpToOpPassAdaptor
OpToOpPassAdaptor(const OpToOpPassAdaptor &rhs) = default;

/// Run the held pipeline over all operations.
void runOnOperation(bool verifyPasses);
void runOnOperation(bool verifyPasses, bool emitErrorOnFailure);
void runOnOperation() override;

/// Try to merge the current pass adaptor into 'rhs'. This will try to append
Expand Down Expand Up @@ -60,25 +60,29 @@ class OpToOpPassAdaptor

private:
/// Run this pass adaptor synchronously.
void runOnOperationImpl(bool verifyPasses);
void runOnOperationImpl(bool verifyPasses, bool emitErrorOnFailure);

/// Run this pass adaptor asynchronously.
void runOnOperationAsyncImpl(bool verifyPasses);
void runOnOperationAsyncImpl(bool verifyPasses, bool emitErrorOnFailure);

/// Run the given operation and analysis manager on a single pass.
/// `parentInitGeneration` is the initialization generation of the parent pass
/// manager, and is used to initialize any dynamic pass pipelines run by the
/// given pass.
/// given pass. If `emitErrorOnFailure` is set, when a pass in
/// the pipeline fails its name will be emitted in an error.
static LogicalResult run(Pass *pass, Operation *op, AnalysisManager am,
bool verifyPasses, unsigned parentInitGeneration);
bool verifyPasses, bool emitErrorOnFailure,
unsigned parentInitGeneration);

/// Run the given operation and analysis manager on a provided op pass
/// manager. `parentInitGeneration` is the initialization generation of the
/// parent pass manager, and is used to initialize any dynamic pass pipelines
/// run by the given passes.
/// run by the given passes. If `emitErrorOnFailure` is set, when a pass in
/// the pipeline fails its name will be emitted in an error.
static LogicalResult runPipeline(
OpPassManager &pm, Operation *op, AnalysisManager am, bool verifyPasses,
unsigned parentInitGeneration, PassInstrumentor *instrumentor = nullptr,
bool emitErrorOnFailure, unsigned parentInitGeneration,
PassInstrumentor *instrumentor = nullptr,
const PassInstrumentation::PipelineParentInfo *parentInfo = nullptr);

/// A set of adaptors to run.
Expand Down
6 changes: 6 additions & 0 deletions mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,11 @@ struct MlirOptMainConfigCLOptions : public MlirOptMainConfig {
cl::desc("Run the verifier after each transformation pass"),
cl::location(verifyPassesFlag), cl::init(true));

static cl::opt<bool, /*ExternalStorage=*/true> emitPassErrorOnFailure(
"emit-pass-error-on-failure",
cl::desc("Emit an error with the pass name when a pass fails"),
cl::location(emitErrorOnPassFailureFlag), cl::init(false));

static cl::opt<bool, /*ExternalStorage=*/true> disableVerifyOnParsing(
"mlir-very-unsafe-disable-verifier-on-parsing",
cl::desc("Disable the verifier on parsing (very unsafe)"),
Expand Down Expand Up @@ -465,6 +470,7 @@ performActions(raw_ostream &os,
// Prepare the pass manager, applying command-line and reproducer options.
PassManager pm(op.get()->getName(), PassManager::Nesting::Implicit);
pm.enableVerifier(config.shouldVerifyPasses());
pm.enableErrorOnFailure(config.shouldEmitErrorOnPassFailure());
if (failed(applyPassManagerCLOptions(pm)))
return failure();
pm.enableTiming(timing);
Expand Down
7 changes: 7 additions & 0 deletions mlir/test/Pass/pass-name-diagnostics.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(test-pass-failure{gen-diagnostics}))' -emit-pass-error-on-failure -verify-diagnostics=only-expected

// expected-error@+2 {{failed to run pass}}
// expected-error@+1 {{illegal operation}}
func.func @TestAlwaysIllegalOperationPass1() {
return
}