1212#include " flang/Optimizer/Dialect/FIRDialect.h"
1313#include " flang/Optimizer/Dialect/FIROpsSupport.h"
1414#include " flang/Runtime/entry-names.h"
15+ #include " mlir/Dialect/GPU/IR/GPUDialect.h"
1516#include " mlir/Dialect/LLVMIR/LLVMDialect.h"
1617#include " mlir/Pass/Pass.h"
1718#include " llvm/ADT/SmallVector.h"
@@ -23,6 +24,8 @@ namespace fir {
2324
2425namespace {
2526
// Symbol name under which this pass looks up the GPU module whose kernel
// functions it registers.
27+ static constexpr llvm::StringRef cudaModName{" cuda_device_mod" };
28+
// Name given to the LLVM function this pass creates to call
// CUFRegisterAllocator.
2629static constexpr llvm::StringRef cudaFortranCtorName{
2730 " __cudaFortranConstructor" };
2831
@@ -31,6 +34,7 @@ struct CUFAddConstructor
3134
3235 void runOnOperation () override {
3336 mlir::ModuleOp mod = getOperation ();
37+ mlir::SymbolTable symTab (mod);
3438 mlir::OpBuilder builder{mod.getBodyRegion ()};
3539 builder.setInsertionPointToEnd (mod.getBody ());
3640 mlir::Location loc = mod.getLoc ();
@@ -48,13 +52,25 @@ struct CUFAddConstructor
4852 mod.getContext (), RTNAME_STRING (CUFRegisterAllocator));
4953 builder.setInsertionPointToEnd (mod.getBody ());
5054
51- // Create the constructor function that cal CUFRegisterAllocator.
52- builder.setInsertionPointToEnd (mod.getBody ());
55+ // Create the constructor function that calls CUFRegisterAllocator.
5356 auto func = builder.create <mlir::LLVM::LLVMFuncOp>(loc, cudaFortranCtorName,
5457 funcTy);
5558 func.setLinkage (mlir::LLVM::Linkage::Internal);
5659 builder.setInsertionPointToStart (func.addEntryBlock (builder));
5760 builder.create <mlir::LLVM::CallOp>(loc, funcTy, cufRegisterAllocatorRef);
61+
62+ // Register kernels
63+ auto gpuMod = symTab.lookup <mlir::gpu::GPUModuleOp>(cudaModName);
64+ if (gpuMod) {
65+ for (auto func : gpuMod.getOps <mlir::gpu::GPUFuncOp>()) {
66+ if (func.isKernel ()) {
67+ auto kernelName = mlir::SymbolRefAttr::get (
68+ builder.getStringAttr (cudaModName),
69+ {mlir::SymbolRefAttr::get (builder.getContext (), func.getName ())});
70+ builder.create <cuf::RegisterKernelOp>(loc, kernelName);
71+ }
72+ }
73+ }
5874 builder.create <mlir::LLVM::ReturnOp>(loc, mlir::ValueRange{});
5975
6076 // Create the llvm.global_ctor with the function.
0 commit comments