1111#include " flang/Optimizer/Dialect/FIRAttr.h"
1212#include " flang/Optimizer/Dialect/FIRDialect.h"
1313#include " flang/Optimizer/Dialect/FIROpsSupport.h"
14+ #include " flang/Optimizer/Transforms/CUFCommon.h"
1415#include " flang/Runtime/entry-names.h"
1516#include " mlir/Dialect/GPU/IR/GPUDialect.h"
1617#include " mlir/Dialect/LLVMIR/LLVMDialect.h"
@@ -24,8 +25,6 @@ namespace fir {
2425
2526namespace {
2627
27- static constexpr llvm::StringRef cudaModName{" cuda_device_mod" };
28-
2928static constexpr llvm::StringRef cudaFortranCtorName{
3029 " __cudaFortranConstructor" };
3130
@@ -60,15 +59,15 @@ struct CUFAddConstructor
6059 builder.create <mlir::LLVM::CallOp>(loc, funcTy, cufRegisterAllocatorRef);
6160
6261 // Register kernels
63- auto gpuMod = symTab.lookup <mlir::gpu::GPUModuleOp>(cudaModName );
62+ auto gpuMod = symTab.lookup <mlir::gpu::GPUModuleOp>(cudaDeviceModuleName );
6463 if (gpuMod) {
6564 auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get (ctx);
6665 auto registeredMod = builder.create <cuf::RegisterModuleOp>(
6766 loc, llvmPtrTy, mlir::SymbolRefAttr::get (ctx, gpuMod.getName ()));
6867 for (auto func : gpuMod.getOps <mlir::gpu::GPUFuncOp>()) {
6968 if (func.isKernel ()) {
7069 auto kernelName = mlir::SymbolRefAttr::get (
71- builder.getStringAttr (cudaModName ),
70+ builder.getStringAttr (cudaDeviceModuleName ),
7271 {mlir::SymbolRefAttr::get (builder.getContext (), func.getName ())});
7372 builder.create <cuf::RegisterKernelOp>(loc, kernelName, registeredMod);
7473 }
0 commit comments