Skip to content

Commit b776881

Browse files
authored
[CIR][CUDA] Register __global__ functions (#1441)
This is part 2 of CUDA lowering. Still more to come!

This PR generates `__cuda_register_globals` for functions only, without touching variables. It also fixes two discrepancies mentioned in Part 1, namely:
- CIR no longer generates registration code if there's nothing to register;
- `__cuda_fatbin_wrapper` is now a constant.
1 parent 75914ec commit b776881

File tree

4 files changed

+254
-64
lines changed

4 files changed

+254
-64
lines changed

clang/include/clang/CIR/Dialect/IR/CIRDataLayout.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include "mlir/Dialect/DLTI/DLTI.h"
1616
#include "mlir/IR/BuiltinOps.h"
17+
#include "clang/CIR/Dialect/IR/CIRAttrs.h"
1718
#include "clang/CIR/Dialect/IR/CIRTypes.h"
1819
#include "llvm/IR/DataLayout.h"
1920
#include "llvm/Support/Alignment.h"
@@ -35,6 +36,8 @@ class CIRDataLayout {
3536
// The StructType -> StructLayout map.
3637
mutable void *LayoutMap = nullptr;
3738

39+
TypeSizeInfoAttr typeSizeInfo;
40+
3841
public:
3942
mlir::DataLayout layout;
4043

@@ -106,6 +109,14 @@ class CIRDataLayout {
106109
cir::IntType::get(Ty.getContext(), getPointerTypeSizeInBits(Ty), false);
107110
return IntTy;
108111
}
112+
113+
/// Returns the type produced by `typeSizeInfo.getIntType(ctx)`, i.e. the
/// `int` type described by this layout's type-size information.
mlir::Type getIntType(mlir::MLIRContext *ctx) const {
  mlir::Type intTy = typeSizeInfo.getIntType(ctx);
  return intTy;
}
116+
117+
/// Returns the type produced by `typeSizeInfo.getCharType(ctx)`, i.e. the
/// `char` type described by this layout's type-size information.
mlir::Type getCharType(mlir::MLIRContext *ctx) const {
  mlir::Type charTy = typeSizeInfo.getCharType(ctx);
  return charTy;
}
109120
};
110121

111122
/// Used to lazily calculate structure layout information for a target machine,

clang/lib/CIR/Dialect/IR/CIRDataLayout.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "clang/CIR/Dialect/IR/CIRDataLayout.h"
2+
#include "clang/CIR/Dialect/IR/CIRDialect.h"
23
#include "clang/CIR/Dialect/IR/CIRTypes.h"
34
#include "clang/CIR/MissingFeatures.h"
45
#include "llvm/IR/DataLayout.h"
@@ -112,6 +113,17 @@ class StructLayoutMap {
112113

113114
/// Builds the data layout for `modOp`.
///
/// After resetting from the module's data layout spec, the type-size
/// information is taken from the module attribute named by
/// `CIRDialect::getTypeSizeInfoAttrName()` when present. Otherwise a default
/// is synthesized: 8-bit char, 32-bit int, and a size_t as wide as a
/// `void *` in this layout.
CIRDataLayout::CIRDataLayout(mlir::ModuleOp modOp) : layout{modOp} {
  reset(modOp.getDataLayoutSpec());

  mlir::MLIRContext *ctx = modOp->getContext();

  // Prefer the size information recorded on the module itself.
  auto recordedSizes =
      modOp->getAttr(cir::CIRDialect::getTypeSizeInfoAttrName());
  if (recordedSizes) {
    typeSizeInfo = mlir::cast<TypeSizeInfoAttr>(recordedSizes);
    return;
  }

  // No attribute on the module: generate default size information, using
  // the pointer width for size_t.
  auto voidPtrTy = PointerType::get(VoidType::get(ctx));
  llvm::TypeSize pointerBits = getTypeSizeInBits(voidPtrTy);
  typeSizeInfo = TypeSizeInfoAttr::get(
      ctx, /*char_size=*/8, /*int_size=*/32,
      /*size_t_size=*/pointerBits.getFixedValue());
}
116128

117129
void CIRDataLayout::reset(mlir::DataLayoutSpecInterface spec) {

clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp

Lines changed: 177 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -123,13 +123,16 @@ struct LoweringPreparePass : public LoweringPrepareBase<LoweringPreparePass> {
123123
/// CUDA related
124124
/// ------------
125125

126-
// Maps CUDA device stub name to kernel name.
127-
llvm::DenseMap<llvm::StringRef, std::string> cudaKernelMap;
126+
// Maps CUDA kernel name to device stub function.
127+
llvm::StringMap<FuncOp> cudaKernelMap;
128128

129129
void buildCUDAModuleCtor();
130130
void buildCUDAModuleDtor();
131131
std::optional<FuncOp> buildCUDARegisterGlobals();
132132

133+
void buildCUDARegisterGlobalFunctions(cir::CIRBaseBuilderTy &builder,
134+
FuncOp regGlobalFunc);
135+
133136
///
134137
/// AST related
135138
/// -----------
@@ -185,6 +188,18 @@ struct LoweringPreparePass : public LoweringPrepareBase<LoweringPreparePass> {
185188
/// List of annotations in the module
186189
llvm::SmallVector<mlir::Attribute, 4> globalAnnotations;
187190
};
191+
192+
/// Picks the prefix used when naming runtime symbols: "hip" when the HIP
/// language option is enabled on `astCtx`, "cuda" otherwise.
std::string getCUDAPrefix(clang::ASTContext *astCtx) {
  const bool isHIP = astCtx->getLangOpts().HIP;
  return isHIP ? "hip" : "cuda";
}
197+
198+
/// Forms "__" + `cudaPrefix` + `cudaFunctionName`, e.g. prefix "cuda" and
/// name "RegisterFatBinary" give "__cudaRegisterFatBinary".
std::string addUnderscoredPrefix(llvm::StringRef cudaPrefix,
                                 llvm::StringRef cudaFunctionName) {
  std::string symbol = "__";
  symbol.reserve(symbol.size() + cudaPrefix.size() + cudaFunctionName.size());
  symbol += cudaPrefix.str();
  symbol += cudaFunctionName.str();
  return symbol;
}
202+
188203
} // namespace
189204

190205
GlobalOp LoweringPreparePass::buildRuntimeVariable(
@@ -983,6 +998,11 @@ void LoweringPreparePass::buildCUDAModuleCtor() {
983998
if (astCtx->getLangOpts().GPURelocatableDeviceCode)
984999
llvm_unreachable("NYI");
9851000

1001+
// For CUDA without -fgpu-rdc, it's safe to stop generating ctor
1002+
// if there's nothing to register.
1003+
if (cudaKernelMap.empty())
1004+
return;
1005+
9861006
// There's no device-side binary, so no need to proceed for CUDA.
9871007
// HIP has to create an external symbol in this case, which is NYI.
9881008
auto cudaBinaryHandleAttr =
@@ -995,18 +1015,14 @@ void LoweringPreparePass::buildCUDAModuleCtor() {
9951015
std::string cudaGPUBinaryName =
9961016
cast<CUDABinaryHandleAttr>(cudaBinaryHandleAttr).getName();
9971017

998-
llvm::StringRef prefix = "cuda";
999-
10001018
constexpr unsigned cudaFatMagic = 0x466243b1;
10011019
constexpr unsigned hipFatMagic = 0x48495046; // "HIPF"
10021020

1021+
auto cudaPrefix = getCUDAPrefix(astCtx);
1022+
10031023
const unsigned fatMagic =
10041024
astCtx->getLangOpts().HIP ? hipFatMagic : cudaFatMagic;
10051025

1006-
auto addUnderscoredPrefix = [&](llvm::StringRef name) -> std::string {
1007-
return ("__" + prefix + name).str();
1008-
};
1009-
10101026
// MAC OS X needs special care, but we haven't supported that in CIR yet.
10111027
assert(!cir::MissingFeatures::checkMacOSXTriple());
10121028

@@ -1015,15 +1031,11 @@ void LoweringPreparePass::buildCUDAModuleCtor() {
10151031

10161032
mlir::Location loc = theModule.getLoc();
10171033

1018-
// Extract types from the module.
1019-
auto typeSizesAttr = cast<TypeSizeInfoAttr>(
1020-
theModule->getAttr(CIRDialect::getTypeSizeInfoAttrName()));
1021-
10221034
auto voidTy = VoidType::get(&getContext());
10231035
auto voidPtrTy = PointerType::get(voidTy);
10241036
auto voidPtrPtrTy = PointerType::get(voidPtrTy);
1025-
auto intTy = typeSizesAttr.getIntType(&getContext());
1026-
auto charTy = typeSizesAttr.getCharType(&getContext());
1037+
auto intTy = datalayout->getIntType(&getContext());
1038+
auto charTy = datalayout->getCharType(&getContext());
10271039

10281040
// Read the GPU binary and create a constant array for it.
10291041
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> cudaGPUBinaryOrErr =
@@ -1046,7 +1058,7 @@ void LoweringPreparePass::buildCUDAModuleCtor() {
10461058

10471059
// OG gives an empty name to this global constant,
10481060
// which is not allowed in CIR.
1049-
std::string fatbinStrName = addUnderscoredPrefix("_fatbin_str");
1061+
std::string fatbinStrName = addUnderscoredPrefix(cudaPrefix, "_fatbin_str");
10501062
GlobalOp fatbinStr = builder.create<GlobalOp>(
10511063
loc, fatbinStrName, fatbinType, /*isConstant=*/true,
10521064
/*linkage=*/cir::GlobalLinkageKind::PrivateLinkage);
@@ -1064,59 +1076,186 @@ void LoweringPreparePass::buildCUDAModuleCtor() {
10641076
&getContext(), {intTy, intTy, voidPtrTy, voidPtrTy}, /*packed=*/false,
10651077
/*padded=*/false, StructType::RecordKind::Struct);
10661078

1067-
std::string fatbinWrapperName = addUnderscoredPrefix("_fatbin_wrapper");
1079+
std::string fatbinWrapperName =
1080+
addUnderscoredPrefix(cudaPrefix, "_fatbin_wrapper");
10681081
GlobalOp fatbinWrapper = builder.create<GlobalOp>(
1069-
loc, fatbinWrapperName, fatbinWrapperType, /*isConstant=*/false,
1082+
loc, fatbinWrapperName, fatbinWrapperType, /*isConstant=*/true,
10701083
/*linkage=*/cir::GlobalLinkageKind::InternalLinkage);
10711084
fatbinWrapper.setPrivate();
10721085
fatbinWrapper.setSection(fatbinSectionName);
10731086

10741087
auto magicInit = IntAttr::get(intTy, fatMagic);
10751088
auto versionInit = IntAttr::get(intTy, 1);
1076-
// `fatbinInit` is only a placeholder. The value will be initialized at the
1077-
// beginning of module ctor.
1078-
auto fatbinInit = builder.getConstNullPtrAttr(voidPtrTy);
1089+
auto fatbinStrSymbol =
1090+
mlir::FlatSymbolRefAttr::get(fatbinStr.getSymNameAttr());
1091+
auto fatbinInit = GlobalViewAttr::get(voidPtrTy, fatbinStrSymbol);
10791092
auto unusedInit = builder.getConstNullPtrAttr(voidPtrTy);
10801093
fatbinWrapper.setInitialValueAttr(cir::ConstStructAttr::get(
10811094
fatbinWrapperType,
10821095
ArrayAttr::get(&getContext(),
10831096
{magicInit, versionInit, fatbinInit, unusedInit})));
10841097

1098+
// GPU fat binary handle is also a global variable in OG.
1099+
std::string gpubinHandleName =
1100+
addUnderscoredPrefix(cudaPrefix, "_gpubin_handle");
1101+
auto gpubinHandle = builder.create<GlobalOp>(
1102+
loc, gpubinHandleName, voidPtrPtrTy,
1103+
/*isConstant=*/false, /*linkage=*/GlobalLinkageKind::InternalLinkage);
1104+
gpubinHandle.setInitialValueAttr(builder.getConstNullPtrAttr(voidPtrPtrTy));
1105+
gpubinHandle.setPrivate();
1106+
10851107
// Declare this function:
10861108
// void **__{cuda|hip}RegisterFatBinary(void *);
10871109

1088-
std::string regFuncName = addUnderscoredPrefix("RegisterFatBinary");
1110+
std::string regFuncName =
1111+
addUnderscoredPrefix(cudaPrefix, "RegisterFatBinary");
10891112
auto regFuncType = FuncType::get({voidPtrTy}, voidPtrPtrTy);
10901113
auto regFunc = buildRuntimeFunction(builder, regFuncName, loc, regFuncType);
10911114

10921115
// Create the module constructor.
10931116

1094-
std::string moduleCtorName = addUnderscoredPrefix("_module_ctor");
1117+
std::string moduleCtorName = addUnderscoredPrefix(cudaPrefix, "_module_ctor");
10951118
auto moduleCtor = buildRuntimeFunction(builder, moduleCtorName, loc,
10961119
FuncType::get({}, voidTy),
10971120
GlobalLinkageKind::InternalLinkage);
10981121
globalCtorList.push_back(GlobalCtorAttr::get(&getContext(), moduleCtorName));
10991122
builder.setInsertionPointToStart(moduleCtor.addEntryBlock());
11001123

1101-
auto wrapper = builder.createGetGlobal(fatbinWrapper);
1102-
// Put fatbinStr inside fatbinWrapper.
1103-
mlir::Value fatbinStrValue = builder.createGetGlobal(fatbinStr);
1104-
mlir::Value fatbinField = builder.createGetMemberOp(loc, wrapper, "", 2);
1105-
builder.createStore(loc, fatbinStrValue, fatbinField);
1106-
11071124
// Register binary with CUDA runtime. This is substantially different in
11081125
// default mode vs. separate compilation.
11091126
// Corresponding code:
11101127
// gpuBinaryHandle = __cudaRegisterFatBinary(&fatbinWrapper);
1128+
auto wrapper = builder.createGetGlobal(fatbinWrapper);
11111129
auto fatbinVoidPtr = builder.createBitcast(wrapper, voidPtrTy);
1112-
builder.createCallOp(loc, regFunc, fatbinVoidPtr);
1130+
auto gpuBinaryHandleCall = builder.createCallOp(loc, regFunc, fatbinVoidPtr);
1131+
auto gpuBinaryHandle = gpuBinaryHandleCall.getResult();
1132+
// Store the value back to the global `__cuda_gpubin_handle`.
1133+
auto gpuBinaryHandleGlobal = builder.createGetGlobal(gpubinHandle);
1134+
builder.createStore(loc, gpuBinaryHandle, gpuBinaryHandleGlobal);
1135+
1136+
// Generate __cuda_register_globals and call it.
1137+
std::optional<FuncOp> regGlobal = buildCUDARegisterGlobals();
1138+
if (regGlobal) {
1139+
builder.createCallOp(loc, *regGlobal, gpuBinaryHandle);
1140+
}
11131141

1114-
// This is currently incomplete.
1115-
// TODO(cir): create __cuda_register_globals(), and call it here.
1142+
// From CUDA 10.1 onwards, we must call this function to end registration:
1143+
// void __cudaRegisterFatBinaryEnd(void **fatbinHandle);
1144+
// This is CUDA-specific, so no need to use `addUnderscoredPrefix`.
1145+
if (clang::CudaFeatureEnabled(
1146+
astCtx->getTargetInfo().getSDKVersion(),
1147+
clang::CudaFeature::CUDA_USES_FATBIN_REGISTER_END)) {
1148+
cir::CIRBaseBuilderTy globalBuilder(getContext());
1149+
globalBuilder.setInsertionPointToStart(theModule.getBody());
1150+
FuncOp endFunc =
1151+
buildRuntimeFunction(globalBuilder, "__cudaRegisterFatBinaryEnd", loc,
1152+
FuncType::get({voidPtrPtrTy}, voidTy));
1153+
builder.createCallOp(loc, endFunc, gpuBinaryHandle);
1154+
}
11161155

11171156
builder.create<cir::ReturnOp>(loc);
11181157
}
11191158

1159+
std::optional<FuncOp> LoweringPreparePass::buildCUDARegisterGlobals() {
1160+
// There is nothing to register.
1161+
if (cudaKernelMap.empty())
1162+
return {};
1163+
1164+
cir::CIRBaseBuilderTy builder(getContext());
1165+
builder.setInsertionPointToStart(theModule.getBody());
1166+
1167+
auto loc = theModule.getLoc();
1168+
auto cudaPrefix = getCUDAPrefix(astCtx);
1169+
1170+
auto voidTy = VoidType::get(&getContext());
1171+
auto voidPtrPtrTy = PointerType::get(PointerType::get(voidTy));
1172+
1173+
// Create the function:
1174+
// void __cuda_register_globals(void **fatbinHandle)
1175+
std::string regGlobalFuncName =
1176+
addUnderscoredPrefix(cudaPrefix, "_register_globals");
1177+
auto regGlobalFuncTy = FuncType::get({voidPtrPtrTy}, voidTy);
1178+
FuncOp regGlobalFunc =
1179+
buildRuntimeFunction(builder, regGlobalFuncName, loc, regGlobalFuncTy,
1180+
/*linkage=*/GlobalLinkageKind::InternalLinkage);
1181+
builder.setInsertionPointToStart(regGlobalFunc.addEntryBlock());
1182+
1183+
buildCUDARegisterGlobalFunctions(builder, regGlobalFunc);
1184+
1185+
// TODO(cir): registration for global variables.
1186+
1187+
builder.create<ReturnOp>(loc);
1188+
return regGlobalFunc;
1189+
}
1190+
1191+
/// Emits one `__{cuda|hip}RegisterFunction` call per entry of
/// `cudaKernelMap` at the current insertion point of `builder`.
///
/// \param builder       Builder positioned inside `regGlobalFunc`; the calls
///                      and their operands are created through it.
/// \param regGlobalFunc The register-globals function being filled in. Its
///                      first argument is used as the fat binary handle.
///
/// The runtime function declaration and the kernel-name string globals are
/// created at the start of the module through a separate builder.
void LoweringPreparePass::buildCUDARegisterGlobalFunctions(
    cir::CIRBaseBuilderTy &builder, FuncOp regGlobalFunc) {
  auto loc = theModule.getLoc();
  auto cudaPrefix = getCUDAPrefix(astCtx);

  auto voidTy = VoidType::get(&getContext());
  auto voidPtrTy = PointerType::get(voidTy);
  auto voidPtrPtrTy = PointerType::get(voidPtrTy);
  auto intTy = datalayout->getIntType(&getContext());
  auto charTy = datalayout->getCharType(&getContext());

  // Extract the GPU binary handle argument.
  mlir::Value fatbinHandle = *regGlobalFunc.args_begin();

  // Module-level declarations and globals go to the start of the module,
  // independently of where `builder` is currently inserting.
  cir::CIRBaseBuilderTy globalBuilder(getContext());
  globalBuilder.setInsertionPointToStart(theModule.getBody());

  // Declare CUDA internal functions:
  // int __cudaRegisterFunction(
  //   void **fatbinHandle,
  //   const char *hostFunc,
  //   char *deviceFunc,
  //   const char *deviceName,
  //   int threadLimit,
  //   uint3 *tid, uint3 *bid, dim3 *bDim, dim3 *gDim,
  //   int *wsize
  // )
  // OG doesn't care about the types at all. They're treated as void*.
  FuncOp cudaRegisterFunction = buildRuntimeFunction(
      globalBuilder, addUnderscoredPrefix(cudaPrefix, "RegisterFunction"), loc,
      FuncType::get({voidPtrPtrTy, voidPtrTy, voidPtrTy, voidPtrTy, intTy,
                     voidPtrTy, voidPtrTy, voidPtrTy, voidPtrTy, voidPtrTy},
                    intTy));

  // Creates a private constant global named ".str<str>" whose array type has
  // room for `str` plus one extra element.
  auto makeConstantString = [&](llvm::StringRef str) -> GlobalOp {
    auto strType = ArrayType::get(&getContext(), charTy, 1 + str.size());

    auto tmpString = globalBuilder.create<GlobalOp>(
        loc, (".str" + str).str(), strType, /*isConstant=*/true,
        /*linkage=*/cir::GlobalLinkageKind::PrivateLinkage);

    // The initializer holds exactly the characters of `str`. Appending the
    // literal "\0" through a Twine would add nothing (it is an empty C
    // string), so it is not done here.
    // NOTE(review): the terminating zero presumably comes from the array
    // type being one element longer than the initializer — confirm that
    // ConstArrayAttr zero-fills the remainder.
    tmpString.setInitialValueAttr(ConstArrayAttr::get(
        strType, StringAttr::get(&getContext(), str)));
    tmpString.setPrivate();
    return tmpString;
  };

  // Shared null pointer used for all trailing pointer arguments.
  auto cirNullPtr = builder.getNullPtr(voidPtrTy, loc);

  // Walk the map entries directly; this avoids a second lookup per kernel.
  for (auto &kernelEntry : cudaKernelMap) {
    llvm::StringRef kernelName = kernelEntry.getKey();
    FuncOp deviceStub = kernelEntry.getValue();

    // The kernel name string serves as both deviceFunc and deviceName.
    GlobalOp deviceFuncStr = makeConstantString(kernelName);
    mlir::Value deviceFunc = builder.createBitcast(
        builder.createGetGlobal(deviceFuncStr), voidPtrTy);

    // The host-side key is the address of the device stub function.
    mlir::Value hostFunc = builder.createBitcast(
        builder.create<GetGlobalOp>(
            loc, PointerType::get(deviceStub.getFunctionType()),
            mlir::FlatSymbolRefAttr::get(deviceStub.getSymNameAttr())),
        voidPtrTy);

    // threadLimit is -1 and the remaining pointer arguments are null.
    builder.createCallOp(
        loc, cudaRegisterFunction,
        {fatbinHandle, hostFunc, deviceFunc, deviceFunc,
         builder.create<ConstantOp>(loc, IntAttr::get(intTy, -1)), cirNullPtr,
         cirNullPtr, cirNullPtr, cirNullPtr, cirNullPtr});
  }
}
1258+
11201259
void LoweringPreparePass::lowerDynamicCastOp(DynamicCastOp op) {
11211260
CIRBaseBuilderTy builder(getContext());
11221261
builder.setInsertionPointAfter(op);
@@ -1378,11 +1517,10 @@ void LoweringPreparePass::runOnOp(Operation *op) {
13781517
globalDtorList.push_back(globalDtor);
13791518
}
13801519
if (auto attr = fnOp.getExtraAttrs().getElements().get(
1381-
CIRDialect::getCUDABinaryHandleAttrName())) {
1382-
auto cudaBinaryAttr = dyn_cast<CUDABinaryHandleAttr>(attr);
1383-
std::string kernelName = cudaBinaryAttr.getName();
1384-
llvm::StringRef stubName = fnOp.getSymName();
1385-
cudaKernelMap[stubName] = kernelName;
1520+
CUDAKernelNameAttr::getMnemonic())) {
1521+
auto cudaBinaryAttr = dyn_cast<CUDAKernelNameAttr>(attr);
1522+
std::string kernelName = cudaBinaryAttr.getKernelName();
1523+
cudaKernelMap[kernelName] = fnOp;
13861524
}
13871525
if (std::optional<mlir::ArrayAttr> annotations = fnOp.getAnnotations())
13881526
addGlobalAnnotations(fnOp, annotations.value());
@@ -1399,6 +1537,9 @@ void LoweringPreparePass::runOnOperation() {
13991537
datalayout.emplace(theModule);
14001538
}
14011539

1540+
auto typeSizeInfo = cast<TypeSizeInfoAttr>(
1541+
theModule->getAttr(CIRDialect::getTypeSizeInfoAttrName()));
1542+
14021543
llvm::SmallVector<Operation *> opsToTransform;
14031544

14041545
op->walk([&](Operation *op) {

0 commit comments

Comments
 (0)