|
11 | 11 | #include "flang/Optimizer/Builder/Runtime/RTBuilder.h" |
12 | 12 | #include "flang/Optimizer/Builder/Todo.h" |
13 | 13 | #include "flang/Optimizer/CodeGen/Target.h" |
| 14 | +#include "flang/Optimizer/CodeGen/TypeConverter.h" |
14 | 15 | #include "flang/Optimizer/Dialect/CUF/CUFOps.h" |
15 | 16 | #include "flang/Optimizer/Dialect/FIRAttr.h" |
16 | 17 | #include "flang/Optimizer/Dialect/FIRDialect.h" |
17 | 18 | #include "flang/Optimizer/Dialect/FIROps.h" |
18 | 19 | #include "flang/Optimizer/Dialect/FIROpsSupport.h" |
| 20 | +#include "flang/Optimizer/Dialect/FIRType.h" |
19 | 21 | #include "flang/Optimizer/Support/DataLayout.h" |
20 | 22 | #include "flang/Optimizer/Transforms/CUFCommon.h" |
21 | 23 | #include "flang/Runtime/CUDA/registration.h" |
@@ -84,6 +86,8 @@ struct CUFAddConstructor |
84 | 86 | auto registeredMod = builder.create<cuf::RegisterModuleOp>( |
85 | 87 | loc, llvmPtrTy, mlir::SymbolRefAttr::get(ctx, gpuMod.getName())); |
86 | 88 |
|
| 89 | + fir::LLVMTypeConverter typeConverter(mod, /*applyTBAA=*/false, |
| 90 | + /*forceUnifiedTBAATree=*/false, *dl); |
87 | 91 | // Register kernels |
88 | 92 | for (auto func : gpuMod.getOps<mlir::gpu::GPUFuncOp>()) { |
89 | 93 | if (func.isKernel()) { |
@@ -115,17 +119,25 @@ struct CUFAddConstructor |
115 | 119 | fir::factory::createStringLiteral(builder, loc, gblNameStr)); |
116 | 120 |
|
117 | 121 | // Global variable size |
118 | | - auto sizeAndAlign = fir::getTypeSizeAndAlignmentOrCrash( |
119 | | - loc, globalOp.getType(), *dl, kindMap); |
120 | | - auto size = |
121 | | - builder.createIntegerConstant(loc, idxTy, sizeAndAlign.first); |
| 122 | + std::optional<uint64_t> size; |
| 123 | + if (auto boxTy = |
| 124 | + mlir::dyn_cast<fir::BaseBoxType>(globalOp.getType())) { |
| 125 | + mlir::Type structTy = typeConverter.convertBoxTypeAsStruct(boxTy); |
| 126 | + size = dl->getTypeSizeInBits(structTy) / 8; |
| 127 | + } |
| 128 | + if (!size) { |
| 129 | + size = fir::getTypeSizeAndAlignmentOrCrash(loc, globalOp.getType(), |
| 130 | + *dl, kindMap) |
| 131 | + .first; |
| 132 | + } |
| 133 | + auto sizeVal = builder.createIntegerConstant(loc, idxTy, *size); |
122 | 134 |
|
123 | 135 | // Global variable address |
124 | 136 | mlir::Value addr = builder.create<fir::AddrOfOp>( |
125 | 137 | loc, globalOp.resultType(), globalOp.getSymbol()); |
126 | 138 |
|
127 | 139 | llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments( |
128 | | - builder, loc, fTy, registeredMod, addr, gblName, size)}; |
| 140 | + builder, loc, fTy, registeredMod, addr, gblName, sizeVal)}; |
129 | 141 | builder.create<fir::CallOp>(loc, func, args); |
130 | 142 | } break; |
131 | 143 | case cuf::DataAttribute::Managed: |
|
0 commit comments