Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 50 additions & 35 deletions mlir/lib/IR/AsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1941,12 +1941,43 @@ void FallbackAsmResourceMap::ResourceCollection::buildResources(

namespace mlir {
namespace detail {

/// Verifies the operation and switches to generic op printing if verification
/// fails. We need to do this because custom print functions may fail/crash for
/// invalid ops.
static void verifyOpAndAdjustFlags(Operation *op, OpPrintingFlags &printerFlags,
bool &failedVerification) {
if (printerFlags.shouldPrintGenericOpForm() ||
printerFlags.shouldAssumeVerified())
return;

// Ignore errors emitted by the verifier. We check the thread id to avoid
// consuming other threads' errors.
auto parentThreadId = llvm::get_threadid();
ScopedDiagnosticHandler diagHandler(op->getContext(), [&](Diagnostic &diag) {
if (parentThreadId == llvm::get_threadid()) {
LLVM_DEBUG({
diag.print(llvm::dbgs());
llvm::dbgs() << "\n";
});
return success();
}
return failure();
});
if (failed(verify(op))) {
printerFlags.printGenericOpForm();
failedVerification = true;
}
}

class AsmStateImpl {
public:
explicit AsmStateImpl(Operation *op, const OpPrintingFlags &printerFlags,
AsmState::LocationMap *locationMap)
: interfaces(op->getContext()), nameState(op, printerFlags),
printerFlags(printerFlags), locationMap(locationMap) {}
printerFlags(printerFlags), locationMap(locationMap) {
verifyOpAndAdjustFlags(op, this->printerFlags, failedVerification);
}
explicit AsmStateImpl(MLIRContext *ctx, const OpPrintingFlags &printerFlags,
AsmState::LocationMap *locationMap)
: interfaces(ctx), printerFlags(printerFlags), locationMap(locationMap) {}
Expand Down Expand Up @@ -1998,6 +2029,8 @@ class AsmStateImpl {

void popCyclicPrinting() { cyclicPrintingStack.pop_back(); }

bool verificationFailed() const { return failedVerification; }

private:
/// Collection of OpAsm interfaces implemented in the context.
DialectInterfaceCollection<OpAsmDialectInterface> interfaces;
Expand All @@ -2020,6 +2053,10 @@ class AsmStateImpl {
/// Flags that control op output.
OpPrintingFlags printerFlags;

/// Whether the operation from which the AsmState was created, failed
/// verification.
bool failedVerification = false;

/// An optional location map to be populated.
AsmState::LocationMap *locationMap;

Expand Down Expand Up @@ -2047,41 +2084,9 @@ void printDimensionList(raw_ostream &stream, Range &&shape) {
} // namespace detail
} // namespace mlir

/// Verifies the operation and switches to generic op printing if verification
/// fails. We need to do this because custom print functions may fail for
/// invalid ops.
static OpPrintingFlags verifyOpAndAdjustFlags(Operation *op,
OpPrintingFlags printerFlags) {
if (printerFlags.shouldPrintGenericOpForm() ||
printerFlags.shouldAssumeVerified())
return printerFlags;

// Ignore errors emitted by the verifier. We check the thread id to avoid
// consuming other threads' errors.
auto parentThreadId = llvm::get_threadid();
ScopedDiagnosticHandler diagHandler(op->getContext(), [&](Diagnostic &diag) {
if (parentThreadId == llvm::get_threadid()) {
LLVM_DEBUG({
diag.print(llvm::dbgs());
llvm::dbgs() << "\n";
});
return success();
}
return failure();
});
if (failed(verify(op))) {
LDBG() << op->getName()
<< "' failed to verify and will be printed in generic form";
printerFlags.printGenericOpForm();
}

return printerFlags;
}

AsmState::AsmState(Operation *op, const OpPrintingFlags &printerFlags,
LocationMap *locationMap, FallbackAsmResourceMap *map)
: impl(std::make_unique<AsmStateImpl>(
op, verifyOpAndAdjustFlags(op, printerFlags), locationMap)) {
: impl(std::make_unique<AsmStateImpl>(op, printerFlags, locationMap)) {
if (map)
attachFallbackResourcePrinter(*map);
}
Expand Down Expand Up @@ -3245,7 +3250,8 @@ class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter {
using Impl::printType;

explicit OperationPrinter(raw_ostream &os, AsmStateImpl &state)
: Impl(os, state), OpAsmPrinter(static_cast<Impl &>(*this)) {}
: Impl(os, state), OpAsmPrinter(static_cast<Impl &>(*this)),
verificationFailed(state.verificationFailed()) {}

/// Print the given top-level operation.
void printTopLevelOperation(Operation *op);
Expand Down Expand Up @@ -3433,10 +3439,19 @@ class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter {

// This is the current indentation level for nested structures.
unsigned currentIndent = 0;

/// Whether the operation from which the AsmState was created, failed
/// verification.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An AsmState isn't necessarily created from an operation, this is incomplete right now.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was true before too - only AsmState created via AsmState::AsmState(Operation *op, had flags adjusted. So it is as complete as before. If the AsmState was not created from an operation, then no operation failed verification and this is set to false.

One could potentially take it to beyond where it is today and in this change, by computing it when one has an Operation (in top level printing?). But one can't cache it then (as one could be printing multiple different ops, some verified and some not). The only usage for that constructor though is inside AsmPrinter and Bytecode IRNumbering and all just when printing Attributes and Types (no ops printed, so no coverage added). So one could also track how AsmState was constructed and reverify on print, or document it as that constructor only being used for Attribute & Type printing inside AsmPrinter and IRNumbering.

bool verificationFailed = false;
};
} // namespace

void OperationPrinter::printTopLevelOperation(Operation *op) {
if (verificationFailed) {
os << "// '" << op->getName()
<< "' failed to verify and will be printed in generic form\n";
}

// Output the aliases at the top level that can't be deferred.
state.getAliasState().printNonDeferredAliases(*this, newLine);

Expand Down
4 changes: 4 additions & 0 deletions mlir/test/IR/invalid-warning-comment.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
// RUN: mlir-opt --mlir-very-unsafe-disable-verifier-on-parsing %s | FileCheck %s

// CHECK: // 'builtin.module' failed to verify and will be printed in generic form
func.func @foo() -> tensor<10xi32> { return }
Loading