Skip to content

Commit b3e02d1

Browse files
committed
[mlir] Add filtering callback to GenerateRuntimeVerification pass
Users would be able to create this pass and attach to it a custom callback function to filter out unwanted operations.
1 parent be75ded commit b3e02d1

File tree

2 files changed

+52
-3
lines changed

2 files changed

+52
-3
lines changed

mlir/include/mlir/Transforms/Passes.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
namespace mlir {
2727

2828
class GreedyRewriteConfig;
29+
class RuntimeVerifiableOpInterface;
2930

3031
//===----------------------------------------------------------------------===//
3132
// Passes
@@ -77,6 +78,13 @@ std::unique_ptr<Pass> createPrintIRPass(const PrintIRPassOptions & = {});
7778
/// Creates a pass that generates IR to verify ops at runtime.
7879
std::unique_ptr<Pass> createGenerateRuntimeVerificationPass();
7980

81+
/// Create an instance of the generate runtime verification pass, and
82+
/// use the provided filter function to skip certain verifiable ops.
83+
/// The default implementation does not filter any ops.
84+
std::unique_ptr<Pass> createGenerateRuntimeVerificationPass(
85+
std::function<bool(RuntimeVerifiableOpInterface)>
86+
shouldHandleVerifiableOpFn);
87+
8088
/// Creates a loop invariant code motion pass that hoists loop invariant
8189
/// instructions out of the loop.
8290
std::unique_ptr<Pass> createLoopInvariantCodeMotionPass();

mlir/lib/Transforms/GenerateRuntimeVerification.cpp

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,46 @@ namespace mlir {
1717
#include "mlir/Transforms/Passes.h.inc"
1818
} // namespace mlir
1919

20+
#define DEBUG_TYPE "generate-runtime-verification"
21+
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
22+
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
23+
2024
using namespace mlir;
2125

26+
static bool defaultShouldHandleVerifiableOpFn(RuntimeVerifiableOpInterface op) {
27+
// By default, all verifiable ops are considered
28+
return true;
29+
}
30+
2231
namespace {
2332
struct GenerateRuntimeVerificationPass
2433
: public impl::GenerateRuntimeVerificationBase<
2534
GenerateRuntimeVerificationPass> {
35+
36+
GenerateRuntimeVerificationPass();
37+
GenerateRuntimeVerificationPass(const GenerateRuntimeVerificationPass &) =
38+
default;
39+
GenerateRuntimeVerificationPass(
40+
std::function<bool(RuntimeVerifiableOpInterface)>
41+
shouldHandleVerifiableOpFn);
42+
2643
void runOnOperation() override;
44+
45+
private:
46+
// A filter function to select verifiable ops to generate verification for.
47+
// If empty, all verifiable ops are considered.
48+
std::function<bool(RuntimeVerifiableOpInterface)> shouldHandleVerifiableOpFn;
2749
};
2850
} // namespace
2951

52+
GenerateRuntimeVerificationPass::GenerateRuntimeVerificationPass()
53+
: shouldHandleVerifiableOpFn(defaultShouldHandleVerifiableOpFn) {}
54+
55+
GenerateRuntimeVerificationPass::GenerateRuntimeVerificationPass(
56+
std::function<bool(RuntimeVerifiableOpInterface)>
57+
shouldHandleVerifiableOpFn)
58+
: shouldHandleVerifiableOpFn(std::move(shouldHandleVerifiableOpFn)) {}
59+
3060
void GenerateRuntimeVerificationPass::runOnOperation() {
3161
// The implementation of the RuntimeVerifiableOpInterface may create ops that
3262
// can be verified. We don't want to generate verification for IR that
@@ -38,11 +68,22 @@ void GenerateRuntimeVerificationPass::runOnOperation() {
3868

3969
OpBuilder builder(getOperation()->getContext());
4070
for (RuntimeVerifiableOpInterface verifiableOp : ops) {
41-
builder.setInsertionPoint(verifiableOp);
42-
verifiableOp.generateRuntimeVerification(builder, verifiableOp.getLoc());
43-
};
71+
if (shouldHandleVerifiableOpFn(verifiableOp)) {
72+
builder.setInsertionPoint(verifiableOp);
73+
verifiableOp.generateRuntimeVerification(builder, verifiableOp.getLoc());
74+
} else {
75+
LDBG("Skipping operation: " << verifiableOp.getOperation());
76+
}
77+
}
4478
}
4579

4680
std::unique_ptr<Pass> mlir::createGenerateRuntimeVerificationPass() {
4781
return std::make_unique<GenerateRuntimeVerificationPass>();
4882
}
83+
84+
std::unique_ptr<Pass> mlir::createGenerateRuntimeVerificationPass(
85+
std::function<bool(RuntimeVerifiableOpInterface)>
86+
shouldHandleVerifiableOpFn) {
87+
return std::make_unique<GenerateRuntimeVerificationPass>(
88+
std::move(shouldHandleVerifiableOpFn));
89+
}

0 commit comments

Comments
 (0)