@@ -17,16 +17,46 @@ namespace mlir {
1717#include " mlir/Transforms/Passes.h.inc"
1818} // namespace mlir
1919
20+ #define DEBUG_TYPE " generate-runtime-verification"
21+ #define DBGS () (llvm::dbgs() << ' [' << DEBUG_TYPE << " ] " )
22+ #define LDBG (X ) LLVM_DEBUG(DBGS() << X << " \n " )
23+
2024using namespace mlir ;
2125
26+ static bool defaultShouldHandleVerifiableOpFn (RuntimeVerifiableOpInterface op) {
27+ // By default, all verifiable ops are considered
28+ return true ;
29+ }
30+
2231namespace {
2332struct GenerateRuntimeVerificationPass
2433 : public impl::GenerateRuntimeVerificationBase<
2534 GenerateRuntimeVerificationPass> {
35+
36+ GenerateRuntimeVerificationPass ();
37+ GenerateRuntimeVerificationPass (const GenerateRuntimeVerificationPass &) =
38+ default ;
39+ GenerateRuntimeVerificationPass (
40+ std::function<bool (RuntimeVerifiableOpInterface)>
41+ shouldHandleVerifiableOpFn);
42+
2643 void runOnOperation () override ;
44+
45+ private:
46+ // A filter function to select verifiable ops to generate verification for.
47+ // If empty, all verifiable ops are considered.
48+ std::function<bool (RuntimeVerifiableOpInterface)> shouldHandleVerifiableOpFn;
2749};
2850} // namespace
2951
52+ GenerateRuntimeVerificationPass::GenerateRuntimeVerificationPass ()
53+ : shouldHandleVerifiableOpFn(defaultShouldHandleVerifiableOpFn) {}
54+
55+ GenerateRuntimeVerificationPass::GenerateRuntimeVerificationPass (
56+ std::function<bool (RuntimeVerifiableOpInterface)>
57+ shouldHandleVerifiableOpFn)
58+ : shouldHandleVerifiableOpFn(std::move(shouldHandleVerifiableOpFn)) {}
59+
3060void GenerateRuntimeVerificationPass::runOnOperation () {
3161 // The implementation of the RuntimeVerifiableOpInterface may create ops that
3262 // can be verified. We don't want to generate verification for IR that
@@ -38,11 +68,22 @@ void GenerateRuntimeVerificationPass::runOnOperation() {
3868
3969 OpBuilder builder (getOperation ()->getContext ());
4070 for (RuntimeVerifiableOpInterface verifiableOp : ops) {
41- builder.setInsertionPoint (verifiableOp);
42- verifiableOp.generateRuntimeVerification (builder, verifiableOp.getLoc ());
43- };
71+ if (shouldHandleVerifiableOpFn (verifiableOp)) {
72+ builder.setInsertionPoint (verifiableOp);
73+ verifiableOp.generateRuntimeVerification (builder, verifiableOp.getLoc ());
74+ } else {
75+ LDBG (" Skipping operation: " << verifiableOp.getOperation ());
76+ }
77+ }
4478}
4579
4680std::unique_ptr<Pass> mlir::createGenerateRuntimeVerificationPass () {
4781 return std::make_unique<GenerateRuntimeVerificationPass>();
4882}
83+
84+ std::unique_ptr<Pass> mlir::createGenerateRuntimeVerificationPass (
85+ std::function<bool (RuntimeVerifiableOpInterface)>
86+ shouldHandleVerifiableOpFn) {
87+ return std::make_unique<GenerateRuntimeVerificationPass>(
88+ std::move (shouldHandleVerifiableOpFn));
89+ }
0 commit comments