|
23 | 23 | #include "flang/Optimizer/Support/InternalNames.h" |
24 | 24 | #include "flang/Optimizer/Support/TypeCode.h" |
25 | 25 | #include "flang/Optimizer/Support/Utils.h" |
| 26 | +#include "flang/Optimizer/Transforms/CUFCommon.h" |
26 | 27 | #include "flang/Runtime/CUDA/descriptor.h" |
| 28 | +#include "flang/Runtime/CUDA/memory.h" |
27 | 29 | #include "flang/Runtime/allocator-registry-consts.h" |
28 | 30 | #include "flang/Runtime/descriptor-consts.h" |
29 | 31 | #include "flang/Semantics/runtime-type-info.h" |
@@ -1135,6 +1137,93 @@ convertSubcomponentIndices(mlir::Location loc, mlir::Type eleTy, |
1135 | 1137 | return result; |
1136 | 1138 | } |
1137 | 1139 |
|
| 1140 | +static mlir::Value genSourceFile(mlir::Location loc, mlir::ModuleOp mod, |
| 1141 | + mlir::ConversionPatternRewriter &rewriter) { |
| 1142 | + auto ptrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext()); |
| 1143 | + if (auto flc = mlir::dyn_cast<mlir::FileLineColLoc>(loc)) { |
| 1144 | + auto fn = flc.getFilename().str() + '\0'; |
| 1145 | + std::string globalName = fir::factory::uniqueCGIdent("cl", fn); |
| 1146 | + |
| 1147 | + if (auto g = mod.lookupSymbol<fir::GlobalOp>(globalName)) { |
| 1148 | + return rewriter.create<mlir::LLVM::AddressOfOp>(loc, ptrTy, g.getName()); |
| 1149 | + } else if (auto g = mod.lookupSymbol<mlir::LLVM::GlobalOp>(globalName)) { |
| 1150 | + return rewriter.create<mlir::LLVM::AddressOfOp>(loc, ptrTy, g.getName()); |
| 1151 | + } |
| 1152 | + |
| 1153 | + auto crtInsPt = rewriter.saveInsertionPoint(); |
| 1154 | + rewriter.setInsertionPoint(mod.getBody(), mod.getBody()->end()); |
| 1155 | + auto arrayTy = mlir::LLVM::LLVMArrayType::get( |
| 1156 | + mlir::IntegerType::get(rewriter.getContext(), 8), fn.size()); |
| 1157 | + mlir::LLVM::GlobalOp globalOp = rewriter.create<mlir::LLVM::GlobalOp>( |
| 1158 | + loc, arrayTy, /*constant=*/true, mlir::LLVM::Linkage::Linkonce, |
| 1159 | + globalName, mlir::Attribute()); |
| 1160 | + |
| 1161 | + mlir::Region ®ion = globalOp.getInitializerRegion(); |
| 1162 | + mlir::Block *block = rewriter.createBlock(®ion); |
| 1163 | + rewriter.setInsertionPoint(block, block->begin()); |
| 1164 | + mlir::Value constValue = rewriter.create<mlir::LLVM::ConstantOp>( |
| 1165 | + loc, arrayTy, rewriter.getStringAttr(fn)); |
| 1166 | + rewriter.create<mlir::LLVM::ReturnOp>(loc, constValue); |
| 1167 | + rewriter.restoreInsertionPoint(crtInsPt); |
| 1168 | + return rewriter.create<mlir::LLVM::AddressOfOp>(loc, ptrTy, |
| 1169 | + globalOp.getName()); |
| 1170 | + } |
| 1171 | + return rewriter.create<mlir::LLVM::ZeroOp>(loc, ptrTy); |
| 1172 | +} |
| 1173 | + |
| 1174 | +static mlir::Value genSourceLine(mlir::Location loc, |
| 1175 | + mlir::ConversionPatternRewriter &rewriter) { |
| 1176 | + if (auto flc = mlir::dyn_cast<mlir::FileLineColLoc>(loc)) |
| 1177 | + return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(), |
| 1178 | + flc.getLine()); |
| 1179 | + return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(), 0); |
| 1180 | +} |
| 1181 | + |
| 1182 | +static mlir::Value |
| 1183 | +genCUFAllocDescriptor(mlir::Location loc, |
| 1184 | + mlir::ConversionPatternRewriter &rewriter, |
| 1185 | + mlir::ModuleOp mod, fir::BaseBoxType boxTy, |
| 1186 | + const fir::LLVMTypeConverter &typeConverter) { |
| 1187 | + std::optional<mlir::DataLayout> dl = |
| 1188 | + fir::support::getOrSetDataLayout(mod, /*allowDefaultLayout=*/true); |
| 1189 | + if (!dl) |
| 1190 | + mlir::emitError(mod.getLoc(), |
| 1191 | + "module operation must carry a data layout attribute " |
| 1192 | + "to generate llvm IR from FIR"); |
| 1193 | + |
| 1194 | + mlir::Value sourceFile = genSourceFile(loc, mod, rewriter); |
| 1195 | + mlir::Value sourceLine = genSourceLine(loc, rewriter); |
| 1196 | + |
| 1197 | + mlir::MLIRContext *ctx = mod.getContext(); |
| 1198 | + |
| 1199 | + mlir::LLVM::LLVMPointerType llvmPointerType = |
| 1200 | + mlir::LLVM::LLVMPointerType::get(ctx); |
| 1201 | + mlir::Type llvmInt32Type = mlir::IntegerType::get(ctx, 32); |
| 1202 | + mlir::Type llvmIntPtrType = |
| 1203 | + mlir::IntegerType::get(ctx, typeConverter.getPointerBitwidth(0)); |
| 1204 | + auto fctTy = mlir::LLVM::LLVMFunctionType::get( |
| 1205 | + llvmPointerType, {llvmIntPtrType, llvmPointerType, llvmInt32Type}); |
| 1206 | + |
| 1207 | + auto llvmFunc = mod.lookupSymbol<mlir::LLVM::LLVMFuncOp>( |
| 1208 | + RTNAME_STRING(CUFAllocDesciptor)); |
| 1209 | + auto funcFunc = |
| 1210 | + mod.lookupSymbol<mlir::func::FuncOp>(RTNAME_STRING(CUFAllocDesciptor)); |
| 1211 | + if (!llvmFunc && !funcFunc) |
| 1212 | + mlir::OpBuilder::atBlockEnd(mod.getBody()) |
| 1213 | + .create<mlir::LLVM::LLVMFuncOp>(loc, RTNAME_STRING(CUFAllocDesciptor), |
| 1214 | + fctTy); |
| 1215 | + |
| 1216 | + mlir::Type structTy = typeConverter.convertBoxTypeAsStruct(boxTy); |
| 1217 | + std::size_t boxSize = dl->getTypeSizeInBits(structTy) / 8; |
| 1218 | + mlir::Value sizeInBytes = |
| 1219 | + genConstantIndex(loc, llvmIntPtrType, rewriter, boxSize); |
| 1220 | + llvm::SmallVector args = {sizeInBytes, sourceFile, sourceLine}; |
| 1221 | + return rewriter |
| 1222 | + .create<mlir::LLVM::CallOp>(loc, fctTy, RTNAME_STRING(CUFAllocDesciptor), |
| 1223 | + args) |
| 1224 | + .getResult(); |
| 1225 | +} |
| 1226 | + |
1138 | 1227 | /// Common base class for embox to descriptor conversion. |
1139 | 1228 | template <typename OP> |
1140 | 1229 | struct EmboxCommonConversion : public fir::FIROpConversion<OP> { |
@@ -1548,15 +1637,24 @@ struct EmboxCommonConversion : public fir::FIROpConversion<OP> { |
1548 | 1637 | mlir::Value |
1549 | 1638 | placeInMemoryIfNotGlobalInit(mlir::ConversionPatternRewriter &rewriter, |
1550 | 1639 | mlir::Location loc, mlir::Type boxTy, |
1551 | | - mlir::Value boxValue) const { |
| 1640 | + mlir::Value boxValue, |
| 1641 | + bool needDeviceAllocation = false) const { |
1552 | 1642 | if (isInGlobalOp(rewriter)) |
1553 | 1643 | return boxValue; |
1554 | 1644 | mlir::Type llvmBoxTy = boxValue.getType(); |
1555 | | - auto alloca = this->genAllocaAndAddrCastWithType(loc, llvmBoxTy, |
1556 | | - defaultAlign, rewriter); |
1557 | | - auto storeOp = rewriter.create<mlir::LLVM::StoreOp>(loc, boxValue, alloca); |
| 1645 | + mlir::Value storage; |
| 1646 | + if (needDeviceAllocation) { |
| 1647 | + auto mod = boxValue.getDefiningOp()->getParentOfType<mlir::ModuleOp>(); |
| 1648 | + auto baseBoxTy = mlir::dyn_cast<fir::BaseBoxType>(boxTy); |
| 1649 | + storage = |
| 1650 | + genCUFAllocDescriptor(loc, rewriter, mod, baseBoxTy, this->lowerTy()); |
| 1651 | + } else { |
| 1652 | + storage = this->genAllocaAndAddrCastWithType(loc, llvmBoxTy, defaultAlign, |
| 1653 | + rewriter); |
| 1654 | + } |
| 1655 | + auto storeOp = rewriter.create<mlir::LLVM::StoreOp>(loc, boxValue, storage); |
1558 | 1656 | this->attachTBAATag(storeOp, boxTy, boxTy, nullptr); |
1559 | | - return alloca; |
| 1657 | + return storage; |
1560 | 1658 | } |
1561 | 1659 | }; |
1562 | 1660 |
|
@@ -1608,6 +1706,18 @@ struct EmboxOpConversion : public EmboxCommonConversion<fir::EmboxOp> { |
1608 | 1706 | } |
1609 | 1707 | }; |
1610 | 1708 |
|
| 1709 | +static bool isDeviceAllocation(mlir::Value val) { |
| 1710 | + if (auto convertOp = |
| 1711 | + mlir::dyn_cast_or_null<fir::ConvertOp>(val.getDefiningOp())) |
| 1712 | + val = convertOp.getValue(); |
| 1713 | + if (auto callOp = mlir::dyn_cast_or_null<fir::CallOp>(val.getDefiningOp())) |
| 1714 | + if (callOp.getCallee() && |
| 1715 | + callOp.getCallee().value().getRootReference().getValue().starts_with( |
| 1716 | + RTNAME_STRING(CUFMemAlloc))) |
| 1717 | + return true; |
| 1718 | + return false; |
| 1719 | +} |
| 1720 | + |
1611 | 1721 | /// Create a generic box on a memory reference. |
1612 | 1722 | struct XEmboxOpConversion : public EmboxCommonConversion<fir::cg::XEmboxOp> { |
1613 | 1723 | using EmboxCommonConversion::EmboxCommonConversion; |
@@ -1791,9 +1901,8 @@ struct XEmboxOpConversion : public EmboxCommonConversion<fir::cg::XEmboxOp> { |
1791 | 1901 | dest = insertBaseAddress(rewriter, loc, dest, base); |
1792 | 1902 | if (fir::isDerivedTypeWithLenParams(boxTy)) |
1793 | 1903 | TODO(loc, "fir.embox codegen of derived with length parameters"); |
1794 | | - |
1795 | | - mlir::Value result = |
1796 | | - placeInMemoryIfNotGlobalInit(rewriter, loc, boxTy, dest); |
| 1904 | + mlir::Value result = placeInMemoryIfNotGlobalInit( |
| 1905 | + rewriter, loc, boxTy, dest, isDeviceAllocation(xbox.getMemref())); |
1797 | 1906 | rewriter.replaceOp(xbox, result); |
1798 | 1907 | return mlir::success(); |
1799 | 1908 | } |
@@ -2971,93 +3080,6 @@ struct GlobalOpConversion : public fir::FIROpConversion<fir::GlobalOp> { |
2971 | 3080 | } |
2972 | 3081 | }; |
2973 | 3082 |
|
2974 | | -static mlir::Value genSourceFile(mlir::Location loc, mlir::ModuleOp mod, |
2975 | | - mlir::ConversionPatternRewriter &rewriter) { |
2976 | | - auto ptrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext()); |
2977 | | - if (auto flc = mlir::dyn_cast<mlir::FileLineColLoc>(loc)) { |
2978 | | - auto fn = flc.getFilename().str() + '\0'; |
2979 | | - std::string globalName = fir::factory::uniqueCGIdent("cl", fn); |
2980 | | - |
2981 | | - if (auto g = mod.lookupSymbol<fir::GlobalOp>(globalName)) { |
2982 | | - return rewriter.create<mlir::LLVM::AddressOfOp>(loc, ptrTy, g.getName()); |
2983 | | - } else if (auto g = mod.lookupSymbol<mlir::LLVM::GlobalOp>(globalName)) { |
2984 | | - return rewriter.create<mlir::LLVM::AddressOfOp>(loc, ptrTy, g.getName()); |
2985 | | - } |
2986 | | - |
2987 | | - auto crtInsPt = rewriter.saveInsertionPoint(); |
2988 | | - rewriter.setInsertionPoint(mod.getBody(), mod.getBody()->end()); |
2989 | | - auto arrayTy = mlir::LLVM::LLVMArrayType::get( |
2990 | | - mlir::IntegerType::get(rewriter.getContext(), 8), fn.size()); |
2991 | | - mlir::LLVM::GlobalOp globalOp = rewriter.create<mlir::LLVM::GlobalOp>( |
2992 | | - loc, arrayTy, /*constant=*/true, mlir::LLVM::Linkage::Linkonce, |
2993 | | - globalName, mlir::Attribute()); |
2994 | | - |
2995 | | - mlir::Region ®ion = globalOp.getInitializerRegion(); |
2996 | | - mlir::Block *block = rewriter.createBlock(®ion); |
2997 | | - rewriter.setInsertionPoint(block, block->begin()); |
2998 | | - mlir::Value constValue = rewriter.create<mlir::LLVM::ConstantOp>( |
2999 | | - loc, arrayTy, rewriter.getStringAttr(fn)); |
3000 | | - rewriter.create<mlir::LLVM::ReturnOp>(loc, constValue); |
3001 | | - rewriter.restoreInsertionPoint(crtInsPt); |
3002 | | - return rewriter.create<mlir::LLVM::AddressOfOp>(loc, ptrTy, |
3003 | | - globalOp.getName()); |
3004 | | - } |
3005 | | - return rewriter.create<mlir::LLVM::ZeroOp>(loc, ptrTy); |
3006 | | -} |
3007 | | - |
3008 | | -static mlir::Value genSourceLine(mlir::Location loc, |
3009 | | - mlir::ConversionPatternRewriter &rewriter) { |
3010 | | - if (auto flc = mlir::dyn_cast<mlir::FileLineColLoc>(loc)) |
3011 | | - return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(), |
3012 | | - flc.getLine()); |
3013 | | - return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(), 0); |
3014 | | -} |
3015 | | - |
3016 | | -static mlir::Value |
3017 | | -genCUFAllocDescriptor(mlir::Location loc, |
3018 | | - mlir::ConversionPatternRewriter &rewriter, |
3019 | | - mlir::ModuleOp mod, fir::BaseBoxType boxTy, |
3020 | | - const fir::LLVMTypeConverter &typeConverter) { |
3021 | | - std::optional<mlir::DataLayout> dl = |
3022 | | - fir::support::getOrSetDataLayout(mod, /*allowDefaultLayout=*/true); |
3023 | | - if (!dl) |
3024 | | - mlir::emitError(mod.getLoc(), |
3025 | | - "module operation must carry a data layout attribute " |
3026 | | - "to generate llvm IR from FIR"); |
3027 | | - |
3028 | | - mlir::Value sourceFile = genSourceFile(loc, mod, rewriter); |
3029 | | - mlir::Value sourceLine = genSourceLine(loc, rewriter); |
3030 | | - |
3031 | | - mlir::MLIRContext *ctx = mod.getContext(); |
3032 | | - |
3033 | | - mlir::LLVM::LLVMPointerType llvmPointerType = |
3034 | | - mlir::LLVM::LLVMPointerType::get(ctx); |
3035 | | - mlir::Type llvmInt32Type = mlir::IntegerType::get(ctx, 32); |
3036 | | - mlir::Type llvmIntPtrType = |
3037 | | - mlir::IntegerType::get(ctx, typeConverter.getPointerBitwidth(0)); |
3038 | | - auto fctTy = mlir::LLVM::LLVMFunctionType::get( |
3039 | | - llvmPointerType, {llvmIntPtrType, llvmPointerType, llvmInt32Type}); |
3040 | | - |
3041 | | - auto llvmFunc = mod.lookupSymbol<mlir::LLVM::LLVMFuncOp>( |
3042 | | - RTNAME_STRING(CUFAllocDesciptor)); |
3043 | | - auto funcFunc = |
3044 | | - mod.lookupSymbol<mlir::func::FuncOp>(RTNAME_STRING(CUFAllocDesciptor)); |
3045 | | - if (!llvmFunc && !funcFunc) |
3046 | | - mlir::OpBuilder::atBlockEnd(mod.getBody()) |
3047 | | - .create<mlir::LLVM::LLVMFuncOp>(loc, RTNAME_STRING(CUFAllocDesciptor), |
3048 | | - fctTy); |
3049 | | - |
3050 | | - mlir::Type structTy = typeConverter.convertBoxTypeAsStruct(boxTy); |
3051 | | - std::size_t boxSize = dl->getTypeSizeInBits(structTy) / 8; |
3052 | | - mlir::Value sizeInBytes = |
3053 | | - genConstantIndex(loc, llvmIntPtrType, rewriter, boxSize); |
3054 | | - llvm::SmallVector args = {sizeInBytes, sourceFile, sourceLine}; |
3055 | | - return rewriter |
3056 | | - .create<mlir::LLVM::CallOp>(loc, fctTy, RTNAME_STRING(CUFAllocDesciptor), |
3057 | | - args) |
3058 | | - .getResult(); |
3059 | | -} |
3060 | | - |
3061 | 3083 | /// `fir.load` --> `llvm.load` |
3062 | 3084 | struct LoadOpConversion : public fir::FIROpConversion<fir::LoadOp> { |
3063 | 3085 | using FIROpConversion::FIROpConversion; |
|
0 commit comments