66//
77// ===----------------------------------------------------------------------===//
88
9+ #include " flang/Optimizer/Builder/BoxValue.h"
910#include " flang/Optimizer/Builder/FIRBuilder.h"
11+ #include " flang/Optimizer/Builder/Runtime/RTBuilder.h"
12+ #include " flang/Optimizer/Builder/Todo.h"
13+ #include " flang/Optimizer/CodeGen/Target.h"
1014#include " flang/Optimizer/Dialect/CUF/CUFOps.h"
1115#include " flang/Optimizer/Dialect/FIRAttr.h"
1216#include " flang/Optimizer/Dialect/FIRDialect.h"
17+ #include " flang/Optimizer/Dialect/FIROps.h"
1318#include " flang/Optimizer/Dialect/FIROpsSupport.h"
19+ #include " flang/Optimizer/Support/DataLayout.h"
1420#include " flang/Optimizer/Transforms/CUFCommon.h"
21+ #include " flang/Runtime/CUDA/registration.h"
1522#include " flang/Runtime/entry-names.h"
1623#include " mlir/Dialect/GPU/IR/GPUDialect.h"
1724#include " mlir/Dialect/LLVMIR/LLVMDialect.h"
25+ #include " mlir/IR/Value.h"
1826#include " mlir/Pass/Pass.h"
1927#include " llvm/ADT/SmallVector.h"
2028
@@ -23,6 +31,8 @@ namespace fir {
2331#include " flang/Optimizer/Transforms/Passes.h.inc"
2432} // namespace fir
2533
34+ using namespace Fortran ::runtime::cuda;
35+
2636namespace {
2737
2838static constexpr llvm::StringRef cudaFortranCtorName{
@@ -34,13 +44,23 @@ struct CUFAddConstructor
3444 void runOnOperation () override {
3545 mlir::ModuleOp mod = getOperation ();
3646 mlir::SymbolTable symTab (mod);
37- mlir::OpBuilder builder{mod.getBodyRegion ()};
47+ mlir::OpBuilder opBuilder{mod.getBodyRegion ()};
48+ fir::FirOpBuilder builder (opBuilder, mod);
49+ fir::KindMapping kindMap{fir::getKindMapping (mod)};
3850 builder.setInsertionPointToEnd (mod.getBody ());
3951 mlir::Location loc = mod.getLoc ();
4052 auto *ctx = mod.getContext ();
4153 auto voidTy = mlir::LLVM::LLVMVoidType::get (ctx);
54+ auto idxTy = builder.getIndexType ();
4255 auto funcTy =
4356 mlir::LLVM::LLVMFunctionType::get (voidTy, {}, /* isVarArg=*/ false );
57+ std::optional<mlir::DataLayout> dl =
58+ fir::support::getOrSetDataLayout (mod, /* allowDefaultLayout=*/ false );
59+ if (!dl) {
60+ mlir::emitError (mod.getLoc (),
61+ " data layout attribute is required to perform " +
62+ getName () + " pass" );
63+ }
4464
4565 // Symbol reference to CUFRegisterAllocator.
4666 builder.setInsertionPointToEnd (mod.getBody ());
@@ -58,12 +78,13 @@ struct CUFAddConstructor
5878 builder.setInsertionPointToStart (func.addEntryBlock (builder));
5979 builder.create <mlir::LLVM::CallOp>(loc, funcTy, cufRegisterAllocatorRef);
6080
61- // Register kernels
6281 auto gpuMod = symTab.lookup <mlir::gpu::GPUModuleOp>(cudaDeviceModuleName);
6382 if (gpuMod) {
6483 auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get (ctx);
6584 auto registeredMod = builder.create <cuf::RegisterModuleOp>(
6685 loc, llvmPtrTy, mlir::SymbolRefAttr::get (ctx, gpuMod.getName ()));
86+
87+ // Register kernels
6788 for (auto func : gpuMod.getOps <mlir::gpu::GPUFuncOp>()) {
6889 if (func.isKernel ()) {
6990 auto kernelName = mlir::SymbolRefAttr::get (
@@ -72,12 +93,55 @@ struct CUFAddConstructor
7293 builder.create <cuf::RegisterKernelOp>(loc, kernelName, registeredMod);
7394 }
7495 }
96+
97+ // Register variables
98+ for (fir::GlobalOp globalOp : mod.getOps <fir::GlobalOp>()) {
99+ auto attr = globalOp.getDataAttrAttr ();
100+ if (!attr)
101+ continue ;
102+
103+ mlir::func::FuncOp func;
104+ switch (attr.getValue ()) {
105+ case cuf::DataAttribute::Device:
106+ case cuf::DataAttribute::Constant: {
107+ func = fir::runtime::getRuntimeFunc<mkRTKey (CUFRegisterVariable)>(
108+ loc, builder);
109+ auto fTy = func.getFunctionType ();
110+
111+ // Global variable name
112+ std::string gblNameStr = globalOp.getSymbol ().getValue ().str ();
113+ gblNameStr += ' \0 ' ;
114+ mlir::Value gblName = fir::getBase (
115+ fir::factory::createStringLiteral (builder, loc, gblNameStr));
116+
117+ // Global variable size
118+ auto sizeAndAlign = fir::getTypeSizeAndAlignmentOrCrash (
119+ loc, globalOp.getType (), *dl, kindMap);
120+ auto size =
121+ builder.createIntegerConstant (loc, idxTy, sizeAndAlign.first );
122+
123+ // Global variable address
124+ mlir::Value addr = builder.create <fir::AddrOfOp>(
125+ loc, globalOp.resultType (), globalOp.getSymbol ());
126+
127+ llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments (
128+ builder, loc, fTy , registeredMod, addr, gblName, size)};
129+ builder.create <fir::CallOp>(loc, func, args);
130+ } break ;
131+ case cuf::DataAttribute::Managed:
132+ TODO (loc, " registration of managed variables" );
133+ default :
134+ break ;
135+ }
136+ if (!func)
137+ continue ;
138+ }
75139 }
76140 builder.create <mlir::LLVM::ReturnOp>(loc, mlir::ValueRange{});
77141
78142 // Create the llvm.global_ctor with the function.
79- // TODO: We might want to have a utility that retrieve it if already created
80- // and adds new functions.
143+ // TODO: We might want to have a utility that retrieves it if already
144+ // created and adds new functions.
81145 builder.setInsertionPointToEnd (mod.getBody ());
82146 llvm::SmallVector<mlir::Attribute> funcs;
83147 funcs.push_back (
0 commit comments