Skip to content

Commit 387177b

Browse files
committed
[flang][cuda] Allocate descriptor in managed memory when emboxing device memory
1 parent 99c2e3b commit 387177b

File tree

2 files changed

+147
-96
lines changed

2 files changed

+147
-96
lines changed

flang/lib/Optimizer/CodeGen/CodeGen.cpp

Lines changed: 117 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@
2323
#include "flang/Optimizer/Support/InternalNames.h"
2424
#include "flang/Optimizer/Support/TypeCode.h"
2525
#include "flang/Optimizer/Support/Utils.h"
26+
#include "flang/Optimizer/Transforms/CUFCommon.h"
2627
#include "flang/Runtime/CUDA/descriptor.h"
28+
#include "flang/Runtime/CUDA/memory.h"
2729
#include "flang/Runtime/allocator-registry-consts.h"
2830
#include "flang/Runtime/descriptor-consts.h"
2931
#include "flang/Semantics/runtime-type-info.h"
@@ -1135,6 +1137,93 @@ convertSubcomponentIndices(mlir::Location loc, mlir::Type eleTy,
11351137
return result;
11361138
}
11371139

1140+
static mlir::Value genSourceFile(mlir::Location loc, mlir::ModuleOp mod,
1141+
mlir::ConversionPatternRewriter &rewriter) {
1142+
auto ptrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext());
1143+
if (auto flc = mlir::dyn_cast<mlir::FileLineColLoc>(loc)) {
1144+
auto fn = flc.getFilename().str() + '\0';
1145+
std::string globalName = fir::factory::uniqueCGIdent("cl", fn);
1146+
1147+
if (auto g = mod.lookupSymbol<fir::GlobalOp>(globalName)) {
1148+
return rewriter.create<mlir::LLVM::AddressOfOp>(loc, ptrTy, g.getName());
1149+
} else if (auto g = mod.lookupSymbol<mlir::LLVM::GlobalOp>(globalName)) {
1150+
return rewriter.create<mlir::LLVM::AddressOfOp>(loc, ptrTy, g.getName());
1151+
}
1152+
1153+
auto crtInsPt = rewriter.saveInsertionPoint();
1154+
rewriter.setInsertionPoint(mod.getBody(), mod.getBody()->end());
1155+
auto arrayTy = mlir::LLVM::LLVMArrayType::get(
1156+
mlir::IntegerType::get(rewriter.getContext(), 8), fn.size());
1157+
mlir::LLVM::GlobalOp globalOp = rewriter.create<mlir::LLVM::GlobalOp>(
1158+
loc, arrayTy, /*constant=*/true, mlir::LLVM::Linkage::Linkonce,
1159+
globalName, mlir::Attribute());
1160+
1161+
mlir::Region &region = globalOp.getInitializerRegion();
1162+
mlir::Block *block = rewriter.createBlock(&region);
1163+
rewriter.setInsertionPoint(block, block->begin());
1164+
mlir::Value constValue = rewriter.create<mlir::LLVM::ConstantOp>(
1165+
loc, arrayTy, rewriter.getStringAttr(fn));
1166+
rewriter.create<mlir::LLVM::ReturnOp>(loc, constValue);
1167+
rewriter.restoreInsertionPoint(crtInsPt);
1168+
return rewriter.create<mlir::LLVM::AddressOfOp>(loc, ptrTy,
1169+
globalOp.getName());
1170+
}
1171+
return rewriter.create<mlir::LLVM::ZeroOp>(loc, ptrTy);
1172+
}
1173+
1174+
static mlir::Value genSourceLine(mlir::Location loc,
1175+
mlir::ConversionPatternRewriter &rewriter) {
1176+
if (auto flc = mlir::dyn_cast<mlir::FileLineColLoc>(loc))
1177+
return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(),
1178+
flc.getLine());
1179+
return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(), 0);
1180+
}
1181+
1182+
static mlir::Value
1183+
genCUFAllocDescriptor(mlir::Location loc,
1184+
mlir::ConversionPatternRewriter &rewriter,
1185+
mlir::ModuleOp mod, fir::BaseBoxType boxTy,
1186+
const fir::LLVMTypeConverter &typeConverter) {
1187+
std::optional<mlir::DataLayout> dl =
1188+
fir::support::getOrSetDataLayout(mod, /*allowDefaultLayout=*/true);
1189+
if (!dl)
1190+
mlir::emitError(mod.getLoc(),
1191+
"module operation must carry a data layout attribute "
1192+
"to generate llvm IR from FIR");
1193+
1194+
mlir::Value sourceFile = genSourceFile(loc, mod, rewriter);
1195+
mlir::Value sourceLine = genSourceLine(loc, rewriter);
1196+
1197+
mlir::MLIRContext *ctx = mod.getContext();
1198+
1199+
mlir::LLVM::LLVMPointerType llvmPointerType =
1200+
mlir::LLVM::LLVMPointerType::get(ctx);
1201+
mlir::Type llvmInt32Type = mlir::IntegerType::get(ctx, 32);
1202+
mlir::Type llvmIntPtrType =
1203+
mlir::IntegerType::get(ctx, typeConverter.getPointerBitwidth(0));
1204+
auto fctTy = mlir::LLVM::LLVMFunctionType::get(
1205+
llvmPointerType, {llvmIntPtrType, llvmPointerType, llvmInt32Type});
1206+
1207+
auto llvmFunc = mod.lookupSymbol<mlir::LLVM::LLVMFuncOp>(
1208+
RTNAME_STRING(CUFAllocDesciptor));
1209+
auto funcFunc =
1210+
mod.lookupSymbol<mlir::func::FuncOp>(RTNAME_STRING(CUFAllocDesciptor));
1211+
if (!llvmFunc && !funcFunc)
1212+
mlir::OpBuilder::atBlockEnd(mod.getBody())
1213+
.create<mlir::LLVM::LLVMFuncOp>(loc, RTNAME_STRING(CUFAllocDesciptor),
1214+
fctTy);
1215+
1216+
mlir::Type structTy = typeConverter.convertBoxTypeAsStruct(boxTy);
1217+
std::size_t boxSize = dl->getTypeSizeInBits(structTy) / 8;
1218+
mlir::Value sizeInBytes =
1219+
genConstantIndex(loc, llvmIntPtrType, rewriter, boxSize);
1220+
llvm::SmallVector args = {sizeInBytes, sourceFile, sourceLine};
1221+
return rewriter
1222+
.create<mlir::LLVM::CallOp>(loc, fctTy, RTNAME_STRING(CUFAllocDesciptor),
1223+
args)
1224+
.getResult();
1225+
}
1226+
11381227
/// Common base class for embox to descriptor conversion.
11391228
template <typename OP>
11401229
struct EmboxCommonConversion : public fir::FIROpConversion<OP> {
@@ -1548,15 +1637,24 @@ struct EmboxCommonConversion : public fir::FIROpConversion<OP> {
15481637
mlir::Value
15491638
placeInMemoryIfNotGlobalInit(mlir::ConversionPatternRewriter &rewriter,
15501639
mlir::Location loc, mlir::Type boxTy,
1551-
mlir::Value boxValue) const {
1640+
mlir::Value boxValue,
1641+
bool needDeviceAllocation = false) const {
15521642
if (isInGlobalOp(rewriter))
15531643
return boxValue;
15541644
mlir::Type llvmBoxTy = boxValue.getType();
1555-
auto alloca = this->genAllocaAndAddrCastWithType(loc, llvmBoxTy,
1556-
defaultAlign, rewriter);
1557-
auto storeOp = rewriter.create<mlir::LLVM::StoreOp>(loc, boxValue, alloca);
1645+
mlir::Value storage;
1646+
if (needDeviceAllocation) {
1647+
auto mod = boxValue.getDefiningOp()->getParentOfType<mlir::ModuleOp>();
1648+
auto baseBoxTy = mlir::dyn_cast<fir::BaseBoxType>(boxTy);
1649+
storage =
1650+
genCUFAllocDescriptor(loc, rewriter, mod, baseBoxTy, this->lowerTy());
1651+
} else {
1652+
storage = this->genAllocaAndAddrCastWithType(loc, llvmBoxTy, defaultAlign,
1653+
rewriter);
1654+
}
1655+
auto storeOp = rewriter.create<mlir::LLVM::StoreOp>(loc, boxValue, storage);
15581656
this->attachTBAATag(storeOp, boxTy, boxTy, nullptr);
1559-
return alloca;
1657+
return storage;
15601658
}
15611659
};
15621660

@@ -1608,6 +1706,18 @@ struct EmboxOpConversion : public EmboxCommonConversion<fir::EmboxOp> {
16081706
}
16091707
};
16101708

1709+
static bool isDeviceAllocation(mlir::Value val) {
1710+
if (auto convertOp =
1711+
mlir::dyn_cast_or_null<fir::ConvertOp>(val.getDefiningOp()))
1712+
val = convertOp.getValue();
1713+
if (auto callOp = mlir::dyn_cast_or_null<fir::CallOp>(val.getDefiningOp()))
1714+
if (callOp.getCallee() &&
1715+
callOp.getCallee().value().getRootReference().getValue().starts_with(
1716+
RTNAME_STRING(CUFMemAlloc)))
1717+
return true;
1718+
return false;
1719+
}
1720+
16111721
/// Create a generic box on a memory reference.
16121722
struct XEmboxOpConversion : public EmboxCommonConversion<fir::cg::XEmboxOp> {
16131723
using EmboxCommonConversion::EmboxCommonConversion;
@@ -1791,9 +1901,8 @@ struct XEmboxOpConversion : public EmboxCommonConversion<fir::cg::XEmboxOp> {
17911901
dest = insertBaseAddress(rewriter, loc, dest, base);
17921902
if (fir::isDerivedTypeWithLenParams(boxTy))
17931903
TODO(loc, "fir.embox codegen of derived with length parameters");
1794-
1795-
mlir::Value result =
1796-
placeInMemoryIfNotGlobalInit(rewriter, loc, boxTy, dest);
1904+
mlir::Value result = placeInMemoryIfNotGlobalInit(
1905+
rewriter, loc, boxTy, dest, isDeviceAllocation(xbox.getMemref()));
17971906
rewriter.replaceOp(xbox, result);
17981907
return mlir::success();
17991908
}
@@ -2971,93 +3080,6 @@ struct GlobalOpConversion : public fir::FIROpConversion<fir::GlobalOp> {
29713080
}
29723081
};
29733082

2974-
static mlir::Value genSourceFile(mlir::Location loc, mlir::ModuleOp mod,
2975-
mlir::ConversionPatternRewriter &rewriter) {
2976-
auto ptrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext());
2977-
if (auto flc = mlir::dyn_cast<mlir::FileLineColLoc>(loc)) {
2978-
auto fn = flc.getFilename().str() + '\0';
2979-
std::string globalName = fir::factory::uniqueCGIdent("cl", fn);
2980-
2981-
if (auto g = mod.lookupSymbol<fir::GlobalOp>(globalName)) {
2982-
return rewriter.create<mlir::LLVM::AddressOfOp>(loc, ptrTy, g.getName());
2983-
} else if (auto g = mod.lookupSymbol<mlir::LLVM::GlobalOp>(globalName)) {
2984-
return rewriter.create<mlir::LLVM::AddressOfOp>(loc, ptrTy, g.getName());
2985-
}
2986-
2987-
auto crtInsPt = rewriter.saveInsertionPoint();
2988-
rewriter.setInsertionPoint(mod.getBody(), mod.getBody()->end());
2989-
auto arrayTy = mlir::LLVM::LLVMArrayType::get(
2990-
mlir::IntegerType::get(rewriter.getContext(), 8), fn.size());
2991-
mlir::LLVM::GlobalOp globalOp = rewriter.create<mlir::LLVM::GlobalOp>(
2992-
loc, arrayTy, /*constant=*/true, mlir::LLVM::Linkage::Linkonce,
2993-
globalName, mlir::Attribute());
2994-
2995-
mlir::Region &region = globalOp.getInitializerRegion();
2996-
mlir::Block *block = rewriter.createBlock(&region);
2997-
rewriter.setInsertionPoint(block, block->begin());
2998-
mlir::Value constValue = rewriter.create<mlir::LLVM::ConstantOp>(
2999-
loc, arrayTy, rewriter.getStringAttr(fn));
3000-
rewriter.create<mlir::LLVM::ReturnOp>(loc, constValue);
3001-
rewriter.restoreInsertionPoint(crtInsPt);
3002-
return rewriter.create<mlir::LLVM::AddressOfOp>(loc, ptrTy,
3003-
globalOp.getName());
3004-
}
3005-
return rewriter.create<mlir::LLVM::ZeroOp>(loc, ptrTy);
3006-
}
3007-
3008-
static mlir::Value genSourceLine(mlir::Location loc,
3009-
mlir::ConversionPatternRewriter &rewriter) {
3010-
if (auto flc = mlir::dyn_cast<mlir::FileLineColLoc>(loc))
3011-
return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(),
3012-
flc.getLine());
3013-
return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(), 0);
3014-
}
3015-
3016-
static mlir::Value
3017-
genCUFAllocDescriptor(mlir::Location loc,
3018-
mlir::ConversionPatternRewriter &rewriter,
3019-
mlir::ModuleOp mod, fir::BaseBoxType boxTy,
3020-
const fir::LLVMTypeConverter &typeConverter) {
3021-
std::optional<mlir::DataLayout> dl =
3022-
fir::support::getOrSetDataLayout(mod, /*allowDefaultLayout=*/true);
3023-
if (!dl)
3024-
mlir::emitError(mod.getLoc(),
3025-
"module operation must carry a data layout attribute "
3026-
"to generate llvm IR from FIR");
3027-
3028-
mlir::Value sourceFile = genSourceFile(loc, mod, rewriter);
3029-
mlir::Value sourceLine = genSourceLine(loc, rewriter);
3030-
3031-
mlir::MLIRContext *ctx = mod.getContext();
3032-
3033-
mlir::LLVM::LLVMPointerType llvmPointerType =
3034-
mlir::LLVM::LLVMPointerType::get(ctx);
3035-
mlir::Type llvmInt32Type = mlir::IntegerType::get(ctx, 32);
3036-
mlir::Type llvmIntPtrType =
3037-
mlir::IntegerType::get(ctx, typeConverter.getPointerBitwidth(0));
3038-
auto fctTy = mlir::LLVM::LLVMFunctionType::get(
3039-
llvmPointerType, {llvmIntPtrType, llvmPointerType, llvmInt32Type});
3040-
3041-
auto llvmFunc = mod.lookupSymbol<mlir::LLVM::LLVMFuncOp>(
3042-
RTNAME_STRING(CUFAllocDesciptor));
3043-
auto funcFunc =
3044-
mod.lookupSymbol<mlir::func::FuncOp>(RTNAME_STRING(CUFAllocDesciptor));
3045-
if (!llvmFunc && !funcFunc)
3046-
mlir::OpBuilder::atBlockEnd(mod.getBody())
3047-
.create<mlir::LLVM::LLVMFuncOp>(loc, RTNAME_STRING(CUFAllocDesciptor),
3048-
fctTy);
3049-
3050-
mlir::Type structTy = typeConverter.convertBoxTypeAsStruct(boxTy);
3051-
std::size_t boxSize = dl->getTypeSizeInBits(structTy) / 8;
3052-
mlir::Value sizeInBytes =
3053-
genConstantIndex(loc, llvmIntPtrType, rewriter, boxSize);
3054-
llvm::SmallVector args = {sizeInBytes, sourceFile, sourceLine};
3055-
return rewriter
3056-
.create<mlir::LLVM::CallOp>(loc, fctTy, RTNAME_STRING(CUFAllocDesciptor),
3057-
args)
3058-
.getResult();
3059-
}
3060-
30613083
/// `fir.load` --> `llvm.load`
30623084
struct LoadOpConversion : public fir::FIROpConversion<fir::LoadOp> {
30633085
using FIROpConversion::FIROpConversion;

flang/test/Fir/CUDA/cuda-code-gen.mlir

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
// RUN: fir-opt --split-input-file --fir-to-llvm-ir="target=x86_64-unknown-linux-gnu" %s | FileCheck %s
22

33
module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<f80, dense<128> : vector<2xi64>>, #dlti.dl_entry<i128, dense<128> : vector<2xi64>>, #dlti.dl_entry<i64, dense<64> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr<272>, dense<64> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<271>, dense<32> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<270>, dense<32> : vector<4xi64>>, #dlti.dl_entry<f128, dense<128> : vector<2xi64>>, #dlti.dl_entry<f64, dense<64> : vector<2xi64>>, #dlti.dl_entry<f16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i32, dense<32> : vector<2xi64>>, #dlti.dl_entry<i16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i8, dense<8> : vector<2xi64>>, #dlti.dl_entry<i1, dense<8> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr, dense<64> : vector<4xi64>>, #dlti.dl_entry<"dlti.endianness", "little">, #dlti.dl_entry<"dlti.stack_alignment", 128 : i64>>} {
4-
54
func.func @_QQmain() attributes {fir.bindc_name = "cufkernel_global"} {
65
%c0 = arith.constant 0 : index
76
%0 = fir.address_of(@_QQclX3C737464696E3E00) : !fir.ref<!fir.char<1,8>>
@@ -27,3 +26,33 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<f80, dense<128> :
2726
}
2827
func.func private @_FortranACUFAllocDesciptor(i64, !fir.ref<i8>, i32) -> !fir.ref<!fir.box<none>> attributes {fir.runtime}
2928
}
29+
30+
// -----
31+
32+
module attributes {dlti.dl_spec = #dlti.dl_spec<f80 = dense<128> : vector<2xi64>, i128 = dense<128> : vector<2xi64>, i64 = dense<64> : vector<2xi64>, !llvm.ptr<272> = dense<64> : vector<4xi64>, !llvm.ptr<271> = dense<32> : vector<4xi64>, !llvm.ptr<270> = dense<32> : vector<4xi64>, f128 = dense<128> : vector<2xi64>, f64 = dense<64> : vector<2xi64>, f16 = dense<16> : vector<2xi64>, i32 = dense<32> : vector<2xi64>, i16 = dense<16> : vector<2xi64>, i8 = dense<8> : vector<2xi64>, i1 = dense<8> : vector<2xi64>, !llvm.ptr = dense<64> : vector<4xi64>, "dlti.endianness" = "little", "dlti.stack_alignment" = 128 : i64>} {
33+
func.func @_QQmain() attributes {fir.bindc_name = "test"} {
34+
%c10 = arith.constant 10 : index
35+
%c20 = arith.constant 20 : index
36+
%0 = fir.address_of(@_QQclX64756D6D792E6D6C697200) : !fir.ref<!fir.char<1,11>>
37+
%c4 = arith.constant 4 : index
38+
%c200 = arith.constant 200 : index
39+
%1 = arith.muli %c200, %c4 : index
40+
%c6_i32 = arith.constant 6 : i32
41+
%c0_i32 = arith.constant 0 : i32
42+
%2 = fir.convert %1 : (index) -> i64
43+
%3 = fir.convert %0 : (!fir.ref<!fir.char<1,11>>) -> !fir.ref<i8>
44+
%4 = fir.call @_FortranACUFMemAlloc(%2, %c0_i32, %3, %c6_i32) : (i64, i32, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
45+
%5 = fir.convert %4 : (!fir.llvm_ptr<i8>) -> !fir.ref<!fir.array<10x20xi32>>
46+
%6 = fircg.ext_embox %5(%c10, %c20) : (!fir.ref<!fir.array<10x20xi32>>, index, index) -> !fir.box<!fir.array<10x20xi32>>
47+
return
48+
}
49+
fir.global linkonce @_QQclX64756D6D792E6D6C697200 constant : !fir.char<1,11> {
50+
%0 = fir.string_lit "dummy.mlir\00"(11) : !fir.char<1,11>
51+
fir.has_value %0 : !fir.char<1,11>
52+
}
53+
func.func private @_FortranACUFMemAlloc(i64, i32, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8> attributes {fir.runtime}
54+
}
55+
56+
// CHECK-LABEL: llvm.func @_QQmain()
57+
// CHECK: llvm.call @_FortranACUFMemAlloc
58+
// CHECK: llvm.call @_FortranACUFAllocDesciptor

0 commit comments

Comments
 (0)