@@ -86,8 +86,8 @@ struct FunctionCallBuilder {
86
86
mlir::LLVM::LLVMFunctionType functionType;
87
87
};
88
88
89
- static const char * kEventCountAttrName = " gpu.event_count" ;
90
- static const char * kEventIndexAttrName = " gpu.event_index" ;
89
+ static constexpr llvm::StringLiteral kEventCountAttrName ( " gpu.event_count" ) ;
90
+ static constexpr llvm::StringLiteral kEventIndexAttrName ( " gpu.event_index" ) ;
91
91
92
92
template <typename OpTy>
93
93
class ConvertOpToGpuRuntimeCallPattern
@@ -559,12 +559,13 @@ class ConvertGpuKernelLaunchPattern
559
559
eventIndexVar,
560
560
// clang-format on
561
561
};
562
- auto res = launchKernelCallBuilder.create (loc, rewriter, params);
562
+ auto event =
563
+ launchKernelCallBuilder.create (loc, rewriter, params)->getResult (0 );
563
564
if (op.getNumResults () == 0 ) {
565
+ waitEventCallBuilder.create (loc, rewriter, event);
564
566
rewriter.eraseOp (op);
565
567
} else {
566
- assert (res.getNumResults () == op.getNumResults ());
567
- rewriter.replaceOp (op, res.getResults ());
568
+ rewriter.replaceOp (op, event);
568
569
}
569
570
return mlir::success ();
570
571
}
@@ -659,11 +660,12 @@ class ConvertGpuAllocPattern
659
660
}
660
661
661
662
mlir::Value resMemref = memrefDesc;
663
+ mlir::Value event = rewriter.create <mlir::LLVM::ExtractValueOp>(
664
+ loc, llvmPointerType, res, rewriter.getI64ArrayAttr (2 ));
662
665
if (op.getNumResults () == 1 ) {
666
+ waitEventCallBuilder.create (loc, rewriter, event);
663
667
rewriter.replaceOp (op, resMemref);
664
668
} else {
665
- auto event = rewriter.create <mlir::LLVM::ExtractValueOp>(
666
- loc, llvmPointerType, res, rewriter.getI64ArrayAttr (2 ));
667
669
mlir::Value vals[] = {
668
670
resMemref,
669
671
event,
@@ -782,16 +784,16 @@ struct EnumerateEventsPass
782
784
void runOnOperation () override {
783
785
auto mod = getOperation ();
784
786
int64_t eventCount = 0 ;
785
- auto intType = mlir::IntegerType::get (&getContext (), 64 );
787
+ auto *ctx = &getContext ();
788
+ auto intType = mlir::IntegerType::get (ctx, 64 );
789
+ auto indexAttrName = mlir::StringAttr::get (ctx, kEventIndexAttrName );
790
+ auto countAttrName = mlir::StringAttr::get (ctx, kEventCountAttrName );
786
791
mod.walk ([&](mlir::gpu::AsyncOpInterface op) {
787
- if (op.getAsyncToken ()) {
788
- op->setAttr (kEventIndexAttrName ,
789
- mlir::IntegerAttr::get (intType, eventCount));
790
- ++eventCount;
791
- }
792
+ // if (op.getAsyncToken())
793
+ op->setAttr (indexAttrName, mlir::IntegerAttr::get (intType, eventCount));
794
+ ++eventCount;
792
795
});
793
- mod->setAttr (kEventCountAttrName ,
794
- mlir::IntegerAttr::get (intType, eventCount));
796
+ mod->setAttr (countAttrName, mlir::IntegerAttr::get (intType, eventCount));
795
797
}
796
798
};
797
799
0 commit comments