|
24 | 24 | #include "flang/Optimizer/Support/TypeCode.h" |
25 | 25 | #include "flang/Optimizer/Support/Utils.h" |
26 | 26 | #include "flang/Runtime/CUDA/descriptor.h" |
| 27 | +#include "flang/Runtime/CUDA/memory.h" |
27 | 28 | #include "flang/Runtime/allocator-registry-consts.h" |
28 | 29 | #include "flang/Runtime/descriptor-consts.h" |
29 | 30 | #include "flang/Semantics/runtime-type-info.h" |
@@ -1141,6 +1142,93 @@ convertSubcomponentIndices(mlir::Location loc, mlir::Type eleTy, |
1141 | 1142 | return result; |
1142 | 1143 | } |
1143 | 1144 |
|
| 1145 | +static mlir::Value genSourceFile(mlir::Location loc, mlir::ModuleOp mod, |
| 1146 | + mlir::ConversionPatternRewriter &rewriter) { |
| 1147 | + auto ptrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext()); |
| 1148 | + if (auto flc = mlir::dyn_cast<mlir::FileLineColLoc>(loc)) { |
| 1149 | + auto fn = flc.getFilename().str() + '\0'; |
| 1150 | + std::string globalName = fir::factory::uniqueCGIdent("cl", fn); |
| 1151 | + |
| 1152 | + if (auto g = mod.lookupSymbol<fir::GlobalOp>(globalName)) { |
| 1153 | + return rewriter.create<mlir::LLVM::AddressOfOp>(loc, ptrTy, g.getName()); |
| 1154 | + } else if (auto g = mod.lookupSymbol<mlir::LLVM::GlobalOp>(globalName)) { |
| 1155 | + return rewriter.create<mlir::LLVM::AddressOfOp>(loc, ptrTy, g.getName()); |
| 1156 | + } |
| 1157 | + |
| 1158 | + auto crtInsPt = rewriter.saveInsertionPoint(); |
| 1159 | + rewriter.setInsertionPoint(mod.getBody(), mod.getBody()->end()); |
| 1160 | + auto arrayTy = mlir::LLVM::LLVMArrayType::get( |
| 1161 | + mlir::IntegerType::get(rewriter.getContext(), 8), fn.size()); |
| 1162 | + mlir::LLVM::GlobalOp globalOp = rewriter.create<mlir::LLVM::GlobalOp>( |
| 1163 | + loc, arrayTy, /*constant=*/true, mlir::LLVM::Linkage::Linkonce, |
| 1164 | + globalName, mlir::Attribute()); |
| 1165 | + |
| 1166 | + mlir::Region ®ion = globalOp.getInitializerRegion(); |
| 1167 | + mlir::Block *block = rewriter.createBlock(®ion); |
| 1168 | + rewriter.setInsertionPoint(block, block->begin()); |
| 1169 | + mlir::Value constValue = rewriter.create<mlir::LLVM::ConstantOp>( |
| 1170 | + loc, arrayTy, rewriter.getStringAttr(fn)); |
| 1171 | + rewriter.create<mlir::LLVM::ReturnOp>(loc, constValue); |
| 1172 | + rewriter.restoreInsertionPoint(crtInsPt); |
| 1173 | + return rewriter.create<mlir::LLVM::AddressOfOp>(loc, ptrTy, |
| 1174 | + globalOp.getName()); |
| 1175 | + } |
| 1176 | + return rewriter.create<mlir::LLVM::ZeroOp>(loc, ptrTy); |
| 1177 | +} |
| 1178 | + |
| 1179 | +static mlir::Value genSourceLine(mlir::Location loc, |
| 1180 | + mlir::ConversionPatternRewriter &rewriter) { |
| 1181 | + if (auto flc = mlir::dyn_cast<mlir::FileLineColLoc>(loc)) |
| 1182 | + return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(), |
| 1183 | + flc.getLine()); |
| 1184 | + return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(), 0); |
| 1185 | +} |
| 1186 | + |
| 1187 | +static mlir::Value |
| 1188 | +genCUFAllocDescriptor(mlir::Location loc, |
| 1189 | + mlir::ConversionPatternRewriter &rewriter, |
| 1190 | + mlir::ModuleOp mod, fir::BaseBoxType boxTy, |
| 1191 | + const fir::LLVMTypeConverter &typeConverter) { |
| 1192 | + std::optional<mlir::DataLayout> dl = |
| 1193 | + fir::support::getOrSetDataLayout(mod, /*allowDefaultLayout=*/true); |
| 1194 | + if (!dl) |
| 1195 | + mlir::emitError(mod.getLoc(), |
| 1196 | + "module operation must carry a data layout attribute " |
| 1197 | + "to generate llvm IR from FIR"); |
| 1198 | + |
| 1199 | + mlir::Value sourceFile = genSourceFile(loc, mod, rewriter); |
| 1200 | + mlir::Value sourceLine = genSourceLine(loc, rewriter); |
| 1201 | + |
| 1202 | + mlir::MLIRContext *ctx = mod.getContext(); |
| 1203 | + |
| 1204 | + mlir::LLVM::LLVMPointerType llvmPointerType = |
| 1205 | + mlir::LLVM::LLVMPointerType::get(ctx); |
| 1206 | + mlir::Type llvmInt32Type = mlir::IntegerType::get(ctx, 32); |
| 1207 | + mlir::Type llvmIntPtrType = |
| 1208 | + mlir::IntegerType::get(ctx, typeConverter.getPointerBitwidth(0)); |
| 1209 | + auto fctTy = mlir::LLVM::LLVMFunctionType::get( |
| 1210 | + llvmPointerType, {llvmIntPtrType, llvmPointerType, llvmInt32Type}); |
| 1211 | + |
| 1212 | + auto llvmFunc = mod.lookupSymbol<mlir::LLVM::LLVMFuncOp>( |
| 1213 | + RTNAME_STRING(CUFAllocDesciptor)); |
| 1214 | + auto funcFunc = |
| 1215 | + mod.lookupSymbol<mlir::func::FuncOp>(RTNAME_STRING(CUFAllocDesciptor)); |
| 1216 | + if (!llvmFunc && !funcFunc) |
| 1217 | + mlir::OpBuilder::atBlockEnd(mod.getBody()) |
| 1218 | + .create<mlir::LLVM::LLVMFuncOp>(loc, RTNAME_STRING(CUFAllocDesciptor), |
| 1219 | + fctTy); |
| 1220 | + |
| 1221 | + mlir::Type structTy = typeConverter.convertBoxTypeAsStruct(boxTy); |
| 1222 | + std::size_t boxSize = dl->getTypeSizeInBits(structTy) / 8; |
| 1223 | + mlir::Value sizeInBytes = |
| 1224 | + genConstantIndex(loc, llvmIntPtrType, rewriter, boxSize); |
| 1225 | + llvm::SmallVector args = {sizeInBytes, sourceFile, sourceLine}; |
| 1226 | + return rewriter |
| 1227 | + .create<mlir::LLVM::CallOp>(loc, fctTy, RTNAME_STRING(CUFAllocDesciptor), |
| 1228 | + args) |
| 1229 | + .getResult(); |
| 1230 | +} |
| 1231 | + |
1144 | 1232 | /// Common base class for embox to descriptor conversion. |
1145 | 1233 | template <typename OP> |
1146 | 1234 | struct EmboxCommonConversion : public fir::FIROpConversion<OP> { |
@@ -1554,15 +1642,24 @@ struct EmboxCommonConversion : public fir::FIROpConversion<OP> { |
1554 | 1642 | mlir::Value |
1555 | 1643 | placeInMemoryIfNotGlobalInit(mlir::ConversionPatternRewriter &rewriter, |
1556 | 1644 | mlir::Location loc, mlir::Type boxTy, |
1557 | | - mlir::Value boxValue) const { |
| 1645 | + mlir::Value boxValue, |
| 1646 | + bool needDeviceAllocation = false) const { |
1558 | 1647 | if (isInGlobalOp(rewriter)) |
1559 | 1648 | return boxValue; |
1560 | 1649 | mlir::Type llvmBoxTy = boxValue.getType(); |
1561 | | - auto alloca = this->genAllocaAndAddrCastWithType(loc, llvmBoxTy, |
1562 | | - defaultAlign, rewriter); |
1563 | | - auto storeOp = rewriter.create<mlir::LLVM::StoreOp>(loc, boxValue, alloca); |
| 1650 | + mlir::Value storage; |
| 1651 | + if (needDeviceAllocation) { |
| 1652 | + auto mod = boxValue.getDefiningOp()->getParentOfType<mlir::ModuleOp>(); |
| 1653 | + auto baseBoxTy = mlir::dyn_cast<fir::BaseBoxType>(boxTy); |
| 1654 | + storage = |
| 1655 | + genCUFAllocDescriptor(loc, rewriter, mod, baseBoxTy, this->lowerTy()); |
| 1656 | + } else { |
| 1657 | + storage = this->genAllocaAndAddrCastWithType(loc, llvmBoxTy, defaultAlign, |
| 1658 | + rewriter); |
| 1659 | + } |
| 1660 | + auto storeOp = rewriter.create<mlir::LLVM::StoreOp>(loc, boxValue, storage); |
1564 | 1661 | this->attachTBAATag(storeOp, boxTy, boxTy, nullptr); |
1565 | | - return alloca; |
| 1662 | + return storage; |
1566 | 1663 | } |
1567 | 1664 | }; |
1568 | 1665 |
|
@@ -1614,6 +1711,18 @@ struct EmboxOpConversion : public EmboxCommonConversion<fir::EmboxOp> { |
1614 | 1711 | } |
1615 | 1712 | }; |
1616 | 1713 |
|
| 1714 | +static bool isDeviceAllocation(mlir::Value val) { |
| 1715 | + if (auto convertOp = |
| 1716 | + mlir::dyn_cast_or_null<fir::ConvertOp>(val.getDefiningOp())) |
| 1717 | + val = convertOp.getValue(); |
| 1718 | + if (auto callOp = mlir::dyn_cast_or_null<fir::CallOp>(val.getDefiningOp())) |
| 1719 | + if (callOp.getCallee() && |
| 1720 | + callOp.getCallee().value().getRootReference().getValue().starts_with( |
| 1721 | + RTNAME_STRING(CUFMemAlloc))) |
| 1722 | + return true; |
| 1723 | + return false; |
| 1724 | +} |
| 1725 | + |
1617 | 1726 | /// Create a generic box on a memory reference. |
1618 | 1727 | struct XEmboxOpConversion : public EmboxCommonConversion<fir::cg::XEmboxOp> { |
1619 | 1728 | using EmboxCommonConversion::EmboxCommonConversion; |
@@ -1797,9 +1906,8 @@ struct XEmboxOpConversion : public EmboxCommonConversion<fir::cg::XEmboxOp> { |
1797 | 1906 | dest = insertBaseAddress(rewriter, loc, dest, base); |
1798 | 1907 | if (fir::isDerivedTypeWithLenParams(boxTy)) |
1799 | 1908 | TODO(loc, "fir.embox codegen of derived with length parameters"); |
1800 | | - |
1801 | | - mlir::Value result = |
1802 | | - placeInMemoryIfNotGlobalInit(rewriter, loc, boxTy, dest); |
| 1909 | + mlir::Value result = placeInMemoryIfNotGlobalInit( |
| 1910 | + rewriter, loc, boxTy, dest, isDeviceAllocation(xbox.getMemref())); |
1803 | 1911 | rewriter.replaceOp(xbox, result); |
1804 | 1912 | return mlir::success(); |
1805 | 1913 | } |
@@ -2977,93 +3085,6 @@ struct GlobalOpConversion : public fir::FIROpConversion<fir::GlobalOp> { |
2977 | 3085 | } |
2978 | 3086 | }; |
2979 | 3087 |
|
2980 | | -static mlir::Value genSourceFile(mlir::Location loc, mlir::ModuleOp mod, |
2981 | | - mlir::ConversionPatternRewriter &rewriter) { |
2982 | | - auto ptrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext()); |
2983 | | - if (auto flc = mlir::dyn_cast<mlir::FileLineColLoc>(loc)) { |
2984 | | - auto fn = flc.getFilename().str() + '\0'; |
2985 | | - std::string globalName = fir::factory::uniqueCGIdent("cl", fn); |
2986 | | - |
2987 | | - if (auto g = mod.lookupSymbol<fir::GlobalOp>(globalName)) { |
2988 | | - return rewriter.create<mlir::LLVM::AddressOfOp>(loc, ptrTy, g.getName()); |
2989 | | - } else if (auto g = mod.lookupSymbol<mlir::LLVM::GlobalOp>(globalName)) { |
2990 | | - return rewriter.create<mlir::LLVM::AddressOfOp>(loc, ptrTy, g.getName()); |
2991 | | - } |
2992 | | - |
2993 | | - auto crtInsPt = rewriter.saveInsertionPoint(); |
2994 | | - rewriter.setInsertionPoint(mod.getBody(), mod.getBody()->end()); |
2995 | | - auto arrayTy = mlir::LLVM::LLVMArrayType::get( |
2996 | | - mlir::IntegerType::get(rewriter.getContext(), 8), fn.size()); |
2997 | | - mlir::LLVM::GlobalOp globalOp = rewriter.create<mlir::LLVM::GlobalOp>( |
2998 | | - loc, arrayTy, /*constant=*/true, mlir::LLVM::Linkage::Linkonce, |
2999 | | - globalName, mlir::Attribute()); |
3000 | | - |
3001 | | - mlir::Region ®ion = globalOp.getInitializerRegion(); |
3002 | | - mlir::Block *block = rewriter.createBlock(®ion); |
3003 | | - rewriter.setInsertionPoint(block, block->begin()); |
3004 | | - mlir::Value constValue = rewriter.create<mlir::LLVM::ConstantOp>( |
3005 | | - loc, arrayTy, rewriter.getStringAttr(fn)); |
3006 | | - rewriter.create<mlir::LLVM::ReturnOp>(loc, constValue); |
3007 | | - rewriter.restoreInsertionPoint(crtInsPt); |
3008 | | - return rewriter.create<mlir::LLVM::AddressOfOp>(loc, ptrTy, |
3009 | | - globalOp.getName()); |
3010 | | - } |
3011 | | - return rewriter.create<mlir::LLVM::ZeroOp>(loc, ptrTy); |
3012 | | -} |
3013 | | - |
3014 | | -static mlir::Value genSourceLine(mlir::Location loc, |
3015 | | - mlir::ConversionPatternRewriter &rewriter) { |
3016 | | - if (auto flc = mlir::dyn_cast<mlir::FileLineColLoc>(loc)) |
3017 | | - return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(), |
3018 | | - flc.getLine()); |
3019 | | - return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(), 0); |
3020 | | -} |
3021 | | - |
3022 | | -static mlir::Value |
3023 | | -genCUFAllocDescriptor(mlir::Location loc, |
3024 | | - mlir::ConversionPatternRewriter &rewriter, |
3025 | | - mlir::ModuleOp mod, fir::BaseBoxType boxTy, |
3026 | | - const fir::LLVMTypeConverter &typeConverter) { |
3027 | | - std::optional<mlir::DataLayout> dl = |
3028 | | - fir::support::getOrSetDataLayout(mod, /*allowDefaultLayout=*/true); |
3029 | | - if (!dl) |
3030 | | - mlir::emitError(mod.getLoc(), |
3031 | | - "module operation must carry a data layout attribute " |
3032 | | - "to generate llvm IR from FIR"); |
3033 | | - |
3034 | | - mlir::Value sourceFile = genSourceFile(loc, mod, rewriter); |
3035 | | - mlir::Value sourceLine = genSourceLine(loc, rewriter); |
3036 | | - |
3037 | | - mlir::MLIRContext *ctx = mod.getContext(); |
3038 | | - |
3039 | | - mlir::LLVM::LLVMPointerType llvmPointerType = |
3040 | | - mlir::LLVM::LLVMPointerType::get(ctx); |
3041 | | - mlir::Type llvmInt32Type = mlir::IntegerType::get(ctx, 32); |
3042 | | - mlir::Type llvmIntPtrType = |
3043 | | - mlir::IntegerType::get(ctx, typeConverter.getPointerBitwidth(0)); |
3044 | | - auto fctTy = mlir::LLVM::LLVMFunctionType::get( |
3045 | | - llvmPointerType, {llvmIntPtrType, llvmPointerType, llvmInt32Type}); |
3046 | | - |
3047 | | - auto llvmFunc = mod.lookupSymbol<mlir::LLVM::LLVMFuncOp>( |
3048 | | - RTNAME_STRING(CUFAllocDesciptor)); |
3049 | | - auto funcFunc = |
3050 | | - mod.lookupSymbol<mlir::func::FuncOp>(RTNAME_STRING(CUFAllocDesciptor)); |
3051 | | - if (!llvmFunc && !funcFunc) |
3052 | | - mlir::OpBuilder::atBlockEnd(mod.getBody()) |
3053 | | - .create<mlir::LLVM::LLVMFuncOp>(loc, RTNAME_STRING(CUFAllocDesciptor), |
3054 | | - fctTy); |
3055 | | - |
3056 | | - mlir::Type structTy = typeConverter.convertBoxTypeAsStruct(boxTy); |
3057 | | - std::size_t boxSize = dl->getTypeSizeInBits(structTy) / 8; |
3058 | | - mlir::Value sizeInBytes = |
3059 | | - genConstantIndex(loc, llvmIntPtrType, rewriter, boxSize); |
3060 | | - llvm::SmallVector args = {sizeInBytes, sourceFile, sourceLine}; |
3061 | | - return rewriter |
3062 | | - .create<mlir::LLVM::CallOp>(loc, fctTy, RTNAME_STRING(CUFAllocDesciptor), |
3063 | | - args) |
3064 | | - .getResult(); |
3065 | | -} |
3066 | | - |
3067 | 3088 | /// `fir.load` --> `llvm.load` |
3068 | 3089 | struct LoadOpConversion : public fir::FIROpConversion<fir::LoadOp> { |
3069 | 3090 | using FIROpConversion::FIROpConversion; |
|
0 commit comments