Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,11 @@ def RuntimeVerifiableOpInterface : OpInterface<"RuntimeVerifiableOpInterface"> {
/*retTy=*/"void",
/*methodName=*/"generateRuntimeVerification",
/*args=*/(ins "::mlir::OpBuilder &":$builder,
"::mlir::Location":$loc)
"::mlir::Location":$loc,
"function_ref<std::string(Operation *, StringRef)>":$generateErrorMessage)
>,
];

let extraClassDeclaration = [{
/// Generate the error message that will be printed to the user when
/// verification fails.
static std::string generateErrorMessage(Operation *op, const std::string &msg);
}];
}

#endif // MLIR_INTERFACES_RUNTIMEVERIFIABLEOPINTERFACE
1 change: 1 addition & 0 deletions mlir/include/mlir/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class GreedyRewriteConfig;
#define GEN_PASS_DECL_TOPOLOGICALSORT
#define GEN_PASS_DECL_COMPOSITEFIXEDPOINTPASS
#define GEN_PASS_DECL_BUBBLEDOWNMEMORYSPACECASTS
#define GEN_PASS_DECL_GENERATERUNTIMEVERIFICATION
#include "mlir/Transforms/Passes.h.inc"

/// Creates an instance of the Canonicalizer pass, configured with default
Expand Down
7 changes: 7 additions & 0 deletions mlir/include/mlir/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -270,8 +270,15 @@ def GenerateRuntimeVerification : Pass<"generate-runtime-verification"> {
passes that are suspected to introduce faulty IR.
}];
let constructor = "mlir::createGenerateRuntimeVerificationPass()";
let options = [
Option<"verboseLevel", "verbose-level", "unsigned", /*default=*/"1",
"Verbosity level for runtime verification messages: "
"0 = Minimum (only source location), "
"1 = Detailed (include full operation details, names, types, shapes, etc.)">
];
}


def Inliner : Pass<"inline"> {
let summary = "Inline function calls";
let constructor = "mlir::createInlinerPass()";
Expand Down
10 changes: 6 additions & 4 deletions mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@ template <typename T>
struct StructuredOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<
StructuredOpInterface<T>, T> {
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
Location loc) const {
void
generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
function_ref<std::string(Operation *, StringRef)>
generateErrorMessage) const {
auto linalgOp = llvm::cast<LinalgOp>(op);

SmallVector<Range> loopRanges = linalgOp.createLoopRanges(builder, loc);
Expand Down Expand Up @@ -70,7 +72,7 @@ struct StructuredOpInterface
builder.createOrFold<index::MinSOp>(loc, startIndex, endIndex);
auto cmpOp = builder.createOrFold<index::CmpOp>(
loc, index::IndexCmpPredicate::SGE, min, zero);
auto msg = RuntimeVerifiableOpInterface::generateErrorMessage(
auto msg = generateErrorMessage(
linalgOp, "unexpected negative result on dimension #" +
std::to_string(dim) + " of input/output operand #" +
std::to_string(opOperand.getOperandNumber()));
Expand Down Expand Up @@ -100,7 +102,7 @@ struct StructuredOpInterface

cmpOp = builder.createOrFold<index::CmpOp>(
loc, predicate, inferredDimSize, actualDimSize);
msg = RuntimeVerifiableOpInterface::generateErrorMessage(
msg = generateErrorMessage(
linalgOp, "dimension #" + std::to_string(dim) +
" of input/output operand #" +
std::to_string(opOperand.getOperandNumber()) +
Expand Down
94 changes: 52 additions & 42 deletions mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,10 @@ Value generateInBoundsCheck(OpBuilder &builder, Location loc, Value value,
struct AssumeAlignmentOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<
AssumeAlignmentOpInterface, AssumeAlignmentOp> {
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
Location loc) const {
void
generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
function_ref<std::string(Operation *, StringRef)>
generateErrorMessage) const {
auto assumeOp = cast<AssumeAlignmentOp>(op);
Value ptr = ExtractAlignedPointerAsIndexOp::create(builder, loc,
assumeOp.getMemref());
Expand All @@ -48,18 +50,20 @@ struct AssumeAlignmentOpInterface
Value isAligned =
arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq, rest,
arith::ConstantIndexOp::create(builder, loc, 0));
cf::AssertOp::create(builder, loc, isAligned,
RuntimeVerifiableOpInterface::generateErrorMessage(
op, "memref is not aligned to " +
cf::AssertOp::create(
builder, loc, isAligned,
generateErrorMessage(op, "memref is not aligned to " +
std::to_string(assumeOp.getAlignment())));
}
};

struct CastOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
CastOp> {
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
Location loc) const {
void
generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
function_ref<std::string(Operation *, StringRef)>
generateErrorMessage) const {
auto castOp = cast<CastOp>(op);
auto srcType = cast<BaseMemRefType>(castOp.getSource().getType());

Expand All @@ -76,8 +80,7 @@ struct CastOpInterface
Value isSameRank = arith::CmpIOp::create(
builder, loc, arith::CmpIPredicate::eq, srcRank, resultRank);
cf::AssertOp::create(builder, loc, isSameRank,
RuntimeVerifiableOpInterface::generateErrorMessage(
op, "rank mismatch"));
generateErrorMessage(op, "rank mismatch"));
}

// Get source offset and strides. We do not have an op to get offsets and
Expand Down Expand Up @@ -116,8 +119,8 @@ struct CastOpInterface
builder, loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz);
cf::AssertOp::create(
builder, loc, isSameSz,
RuntimeVerifiableOpInterface::generateErrorMessage(
op, "size mismatch of dim " + std::to_string(it.index())));
generateErrorMessage(op, "size mismatch of dim " +
std::to_string(it.index())));
}

// Get result offset and strides.
Expand All @@ -135,8 +138,7 @@ struct CastOpInterface
Value isSameOffset = arith::CmpIOp::create(
builder, loc, arith::CmpIPredicate::eq, srcOffset, resultOffsetVal);
cf::AssertOp::create(builder, loc, isSameOffset,
RuntimeVerifiableOpInterface::generateErrorMessage(
op, "offset mismatch"));
generateErrorMessage(op, "offset mismatch"));
}

// Check strides.
Expand All @@ -153,17 +155,19 @@ struct CastOpInterface
builder, loc, arith::CmpIPredicate::eq, srcStride, resultStrideVal);
cf::AssertOp::create(
builder, loc, isSameStride,
RuntimeVerifiableOpInterface::generateErrorMessage(
op, "stride mismatch of dim " + std::to_string(it.index())));
generateErrorMessage(op, "stride mismatch of dim " +
std::to_string(it.index())));
}
}
};

struct CopyOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<CopyOpInterface,
CopyOp> {
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
Location loc) const {
void
generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
function_ref<std::string(Operation *, StringRef)>
generateErrorMessage) const {
auto copyOp = cast<CopyOp>(op);
BaseMemRefType sourceType = copyOp.getSource().getType();
BaseMemRefType targetType = copyOp.getTarget().getType();
Expand Down Expand Up @@ -193,9 +197,9 @@ struct CopyOpInterface
Value targetDim = getDimSize(copyOp.getTarget(), rankedTargetType, i);
Value sameDimSize = arith::CmpIOp::create(
builder, loc, arith::CmpIPredicate::eq, sourceDim, targetDim);
cf::AssertOp::create(builder, loc, sameDimSize,
RuntimeVerifiableOpInterface::generateErrorMessage(
op, "size of " + std::to_string(i) +
cf::AssertOp::create(
builder, loc, sameDimSize,
generateErrorMessage(op, "size of " + std::to_string(i) +
"-th source/target dim does not match"));
}
}
Expand All @@ -204,16 +208,17 @@ struct CopyOpInterface
struct DimOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<DimOpInterface,
DimOp> {
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
Location loc) const {
void
generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
function_ref<std::string(Operation *, StringRef)>
generateErrorMessage) const {
auto dimOp = cast<DimOp>(op);
Value rank = RankOp::create(builder, loc, dimOp.getSource());
Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
cf::AssertOp::create(
builder, loc,
generateInBoundsCheck(builder, loc, dimOp.getIndex(), zero, rank),
RuntimeVerifiableOpInterface::generateErrorMessage(
op, "index is out of bounds"));
generateErrorMessage(op, "index is out of bounds"));
}
};

Expand All @@ -223,8 +228,10 @@ template <typename LoadStoreOp>
struct LoadStoreOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<
LoadStoreOpInterface<LoadStoreOp>, LoadStoreOp> {
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
Location loc) const {
void
generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
function_ref<std::string(Operation *, StringRef)>
generateErrorMessage) const {
auto loadStoreOp = cast<LoadStoreOp>(op);

auto memref = loadStoreOp.getMemref();
Expand All @@ -245,16 +252,17 @@ struct LoadStoreOpInterface
: inBounds;
}
cf::AssertOp::create(builder, loc, assertCond,
RuntimeVerifiableOpInterface::generateErrorMessage(
op, "out-of-bounds access"));
generateErrorMessage(op, "out-of-bounds access"));
}
};

struct SubViewOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<SubViewOpInterface,
SubViewOp> {
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
Location loc) const {
void
generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
function_ref<std::string(Operation *, StringRef)>
generateErrorMessage) const {
auto subView = cast<SubViewOp>(op);
MemRefType sourceType = subView.getSource().getType();

Expand All @@ -277,10 +285,10 @@ struct SubViewOpInterface
Value dimSize = metadataOp.getSizes()[i];
Value offsetInBounds =
generateInBoundsCheck(builder, loc, offset, zero, dimSize);
cf::AssertOp::create(
builder, loc, offsetInBounds,
RuntimeVerifiableOpInterface::generateErrorMessage(
op, "offset " + std::to_string(i) + " is out-of-bounds"));
cf::AssertOp::create(builder, loc, offsetInBounds,
generateErrorMessage(op, "offset " +
std::to_string(i) +
" is out-of-bounds"));

// Verify that slice does not run out-of-bounds.
Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one);
Expand All @@ -292,18 +300,20 @@ struct SubViewOpInterface
generateInBoundsCheck(builder, loc, lastPos, zero, dimSize);
cf::AssertOp::create(
builder, loc, lastPosInBounds,
RuntimeVerifiableOpInterface::generateErrorMessage(
op, "subview runs out-of-bounds along dimension " +
std::to_string(i)));
generateErrorMessage(op,
"subview runs out-of-bounds along dimension " +
std::to_string(i)));
}
}
};

struct ExpandShapeOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<ExpandShapeOpInterface,
ExpandShapeOp> {
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
Location loc) const {
void
generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
function_ref<std::string(Operation *, StringRef)>
generateErrorMessage) const {
auto expandShapeOp = cast<ExpandShapeOp>(op);

// Verify that the expanded dim sizes are a product of the collapsed dim
Expand Down Expand Up @@ -333,9 +343,9 @@ struct ExpandShapeOpInterface
Value isModZero = arith::CmpIOp::create(
builder, loc, arith::CmpIPredicate::eq, mod,
arith::ConstantIndexOp::create(builder, loc, 0));
cf::AssertOp::create(builder, loc, isModZero,
RuntimeVerifiableOpInterface::generateErrorMessage(
op, "static result dims in reassoc group do not "
cf::AssertOp::create(
builder, loc, isModZero,
generateErrorMessage(op, "static result dims in reassoc group do not "
"divide src dim evenly"));
}
}
Expand Down
Loading