From 3532e160e87cf8304c1e8c1f9c1ecec061b75f52 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Fri, 3 Oct 2025 05:52:07 +0000 Subject: [PATCH] [mlir] Add comment for failed verification in print. While we have debug output explaining verification failure, many users are confused when they first encounter this/most folks don't run with --debug. Move the checking such that we can emit a comment explaining why/make it more discoverable. --- mlir/lib/IR/AsmPrinter.cpp | 85 +++++++++++++---------- mlir/test/IR/invalid-warning-comment.mlir | 4 ++ 2 files changed, 54 insertions(+), 35 deletions(-) create mode 100644 mlir/test/IR/invalid-warning-comment.mlir diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 3d19c5ad8fbca..7e49bfc9da64d 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1941,12 +1941,43 @@ void FallbackAsmResourceMap::ResourceCollection::buildResources( namespace mlir { namespace detail { + +/// Verifies the operation and switches to generic op printing if verification +/// fails. We need to do this because custom print functions may fail/crash for +/// invalid ops. +static void verifyOpAndAdjustFlags(Operation *op, OpPrintingFlags &printerFlags, + bool &failedVerification) { + if (printerFlags.shouldPrintGenericOpForm() || + printerFlags.shouldAssumeVerified()) + return; + + // Ignore errors emitted by the verifier. We check the thread id to avoid + // consuming other threads' errors. + auto parentThreadId = llvm::get_threadid(); + ScopedDiagnosticHandler diagHandler(op->getContext(), [&](Diagnostic &diag) { + if (parentThreadId == llvm::get_threadid()) { + LLVM_DEBUG({ + diag.print(llvm::dbgs()); + llvm::dbgs() << "\n"; + }); + return success(); + } + return failure(); + }); + if (failed(verify(op))) { + printerFlags.printGenericOpForm(); + failedVerification = true; + } +} + class AsmStateImpl { public: explicit AsmStateImpl(Operation *op, const OpPrintingFlags &printerFlags, AsmState::LocationMap *locationMap) : interfaces(op->getContext()), nameState(op, printerFlags), - printerFlags(printerFlags), locationMap(locationMap) {} + printerFlags(printerFlags), locationMap(locationMap) { + verifyOpAndAdjustFlags(op, this->printerFlags, failedVerification); + } explicit AsmStateImpl(MLIRContext *ctx, const OpPrintingFlags &printerFlags, AsmState::LocationMap *locationMap) : interfaces(ctx), printerFlags(printerFlags), locationMap(locationMap) {} @@ -1998,6 +2029,8 @@ class AsmStateImpl { void popCyclicPrinting() { cyclicPrintingStack.pop_back(); } + bool verificationFailed() const { return failedVerification; } + private: /// Collection of OpAsm interfaces implemented in the context. DialectInterfaceCollection interfaces; @@ -2020,6 +2053,10 @@ class AsmStateImpl { /// Flags that control op output. OpPrintingFlags printerFlags; + /// Whether the operation from which the AsmState was created, failed + /// verification. + bool failedVerification = false; + /// An optional location map to be populated. AsmState::LocationMap *locationMap; @@ -2047,41 +2084,9 @@ void printDimensionList(raw_ostream &stream, Range &&shape) { } // namespace detail } // namespace mlir -/// Verifies the operation and switches to generic op printing if verification -/// fails. We need to do this because custom print functions may fail for -/// invalid ops. -static OpPrintingFlags verifyOpAndAdjustFlags(Operation *op, - OpPrintingFlags printerFlags) { - if (printerFlags.shouldPrintGenericOpForm() || - printerFlags.shouldAssumeVerified()) - return printerFlags; - - // Ignore errors emitted by the verifier. We check the thread id to avoid - // consuming other threads' errors. - auto parentThreadId = llvm::get_threadid(); - ScopedDiagnosticHandler diagHandler(op->getContext(), [&](Diagnostic &diag) { - if (parentThreadId == llvm::get_threadid()) { - LLVM_DEBUG({ - diag.print(llvm::dbgs()); - llvm::dbgs() << "\n"; - }); - return success(); - } - return failure(); - }); - if (failed(verify(op))) { - LDBG() << op->getName() - << "' failed to verify and will be printed in generic form"; - printerFlags.printGenericOpForm(); - } - - return printerFlags; -} - AsmState::AsmState(Operation *op, const OpPrintingFlags &printerFlags, LocationMap *locationMap, FallbackAsmResourceMap *map) - : impl(std::make_unique( - op, verifyOpAndAdjustFlags(op, printerFlags), locationMap)) { + : impl(std::make_unique(op, printerFlags, locationMap)) { if (map) attachFallbackResourcePrinter(*map); } @@ -3245,7 +3250,8 @@ class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter { using Impl::printType; explicit OperationPrinter(raw_ostream &os, AsmStateImpl &state) - : Impl(os, state), OpAsmPrinter(static_cast(*this)) {} + : Impl(os, state), OpAsmPrinter(static_cast(*this)), + verificationFailed(state.verificationFailed()) {} /// Print the given top-level operation. void printTopLevelOperation(Operation *op); @@ -3433,10 +3439,19 @@ class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter { // This is the current indentation level for nested structures. unsigned currentIndent = 0; + + /// Whether the operation from which the AsmState was created, failed + /// verification. + bool verificationFailed = false; }; } // namespace void OperationPrinter::printTopLevelOperation(Operation *op) { + if (verificationFailed) { + os << "// '" << op->getName() + << "' failed to verify and will be printed in generic form\n"; + } + // Output the aliases at the top level that can't be deferred. state.getAliasState().printNonDeferredAliases(*this, newLine); diff --git a/mlir/test/IR/invalid-warning-comment.mlir b/mlir/test/IR/invalid-warning-comment.mlir new file mode 100644 index 0000000000000..08ace2f8139e4 --- /dev/null +++ b/mlir/test/IR/invalid-warning-comment.mlir @@ -0,0 +1,4 @@ +// RUN: mlir-opt --mlir-very-unsafe-disable-verifier-on-parsing %s | FileCheck %s + +// CHECK: // 'builtin.module' failed to verify and will be printed in generic form +func.func @foo() -> tensor<10xi32> { return }