@@ -123,13 +123,16 @@ struct LoweringPreparePass : public LoweringPrepareBase<LoweringPreparePass> {
123
123
// / CUDA related
124
124
// / ------------
125
125
126
- // Maps CUDA device stub name to kernel name .
127
- llvm::DenseMap<llvm::StringRef, std::string > cudaKernelMap;
126
+ // Maps CUDA kernel name to device stub function .
127
+ llvm::StringMap<FuncOp > cudaKernelMap;
128
128
129
129
void buildCUDAModuleCtor ();
130
130
void buildCUDAModuleDtor ();
131
131
std::optional<FuncOp> buildCUDARegisterGlobals ();
132
132
133
+ void buildCUDARegisterGlobalFunctions (cir::CIRBaseBuilderTy &builder,
134
+ FuncOp regGlobalFunc);
135
+
133
136
// /
134
137
// / AST related
135
138
// / -----------
@@ -185,6 +188,18 @@ struct LoweringPreparePass : public LoweringPrepareBase<LoweringPreparePass> {
185
188
// / List of annotations in the module
186
189
llvm::SmallVector<mlir::Attribute, 4 > globalAnnotations;
187
190
};
191
+
192
+ std::string getCUDAPrefix (clang::ASTContext *astCtx) {
193
+ if (astCtx->getLangOpts ().HIP )
194
+ return " hip" ;
195
+ return " cuda" ;
196
+ }
197
+
198
+ std::string addUnderscoredPrefix (llvm::StringRef cudaPrefix,
199
+ llvm::StringRef cudaFunctionName) {
200
+ return (" __" + cudaPrefix + cudaFunctionName).str ();
201
+ }
202
+
188
203
} // namespace
189
204
190
205
GlobalOp LoweringPreparePass::buildRuntimeVariable (
@@ -983,6 +998,11 @@ void LoweringPreparePass::buildCUDAModuleCtor() {
983
998
if (astCtx->getLangOpts ().GPURelocatableDeviceCode )
984
999
llvm_unreachable (" NYI" );
985
1000
1001
+ // For CUDA without -fgpu-rdc, it's safe to stop generating ctor
1002
+ // if there's nothing to register.
1003
+ if (cudaKernelMap.empty ())
1004
+ return ;
1005
+
986
1006
// There's no device-side binary, so no need to proceed for CUDA.
987
1007
// HIP has to create an external symbol in this case, which is NYI.
988
1008
auto cudaBinaryHandleAttr =
@@ -995,18 +1015,14 @@ void LoweringPreparePass::buildCUDAModuleCtor() {
995
1015
std::string cudaGPUBinaryName =
996
1016
cast<CUDABinaryHandleAttr>(cudaBinaryHandleAttr).getName ();
997
1017
998
- llvm::StringRef prefix = " cuda" ;
999
-
1000
1018
constexpr unsigned cudaFatMagic = 0x466243b1 ;
1001
1019
constexpr unsigned hipFatMagic = 0x48495046 ; // "HIPF"
1002
1020
1021
+ auto cudaPrefix = getCUDAPrefix (astCtx);
1022
+
1003
1023
const unsigned fatMagic =
1004
1024
astCtx->getLangOpts ().HIP ? hipFatMagic : cudaFatMagic;
1005
1025
1006
- auto addUnderscoredPrefix = [&](llvm::StringRef name) -> std::string {
1007
- return (" __" + prefix + name).str ();
1008
- };
1009
-
1010
1026
// MAC OS X needs special care, but we haven't supported that in CIR yet.
1011
1027
assert (!cir::MissingFeatures::checkMacOSXTriple ());
1012
1028
@@ -1015,15 +1031,11 @@ void LoweringPreparePass::buildCUDAModuleCtor() {
1015
1031
1016
1032
mlir::Location loc = theModule.getLoc ();
1017
1033
1018
- // Extract types from the module.
1019
- auto typeSizesAttr = cast<TypeSizeInfoAttr>(
1020
- theModule->getAttr (CIRDialect::getTypeSizeInfoAttrName ()));
1021
-
1022
1034
auto voidTy = VoidType::get (&getContext ());
1023
1035
auto voidPtrTy = PointerType::get (voidTy);
1024
1036
auto voidPtrPtrTy = PointerType::get (voidPtrTy);
1025
- auto intTy = typeSizesAttr. getIntType (&getContext ());
1026
- auto charTy = typeSizesAttr. getCharType (&getContext ());
1037
+ auto intTy = datalayout-> getIntType (&getContext ());
1038
+ auto charTy = datalayout-> getCharType (&getContext ());
1027
1039
1028
1040
// Read the GPU binary and create a constant array for it.
1029
1041
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> cudaGPUBinaryOrErr =
@@ -1046,7 +1058,7 @@ void LoweringPreparePass::buildCUDAModuleCtor() {
1046
1058
1047
1059
// OG gives an empty name to this global constant,
1048
1060
// which is not allowed in CIR.
1049
- std::string fatbinStrName = addUnderscoredPrefix (" _fatbin_str" );
1061
+ std::string fatbinStrName = addUnderscoredPrefix (cudaPrefix, " _fatbin_str" );
1050
1062
GlobalOp fatbinStr = builder.create <GlobalOp>(
1051
1063
loc, fatbinStrName, fatbinType, /* isConstant=*/ true ,
1052
1064
/* linkage=*/ cir::GlobalLinkageKind::PrivateLinkage);
@@ -1064,59 +1076,186 @@ void LoweringPreparePass::buildCUDAModuleCtor() {
1064
1076
&getContext (), {intTy, intTy, voidPtrTy, voidPtrTy}, /* packed=*/ false ,
1065
1077
/* padded=*/ false , StructType::RecordKind::Struct);
1066
1078
1067
- std::string fatbinWrapperName = addUnderscoredPrefix (" _fatbin_wrapper" );
1079
+ std::string fatbinWrapperName =
1080
+ addUnderscoredPrefix (cudaPrefix, " _fatbin_wrapper" );
1068
1081
GlobalOp fatbinWrapper = builder.create <GlobalOp>(
1069
- loc, fatbinWrapperName, fatbinWrapperType, /* isConstant=*/ false ,
1082
+ loc, fatbinWrapperName, fatbinWrapperType, /* isConstant=*/ true ,
1070
1083
/* linkage=*/ cir::GlobalLinkageKind::InternalLinkage);
1071
1084
fatbinWrapper.setPrivate ();
1072
1085
fatbinWrapper.setSection (fatbinSectionName);
1073
1086
1074
1087
auto magicInit = IntAttr::get (intTy, fatMagic);
1075
1088
auto versionInit = IntAttr::get (intTy, 1 );
1076
- // `fatbinInit` is only a placeholder. The value will be initialized at the
1077
- // beginning of module ctor.
1078
- auto fatbinInit = builder. getConstNullPtrAttr (voidPtrTy);
1089
+ auto fatbinStrSymbol =
1090
+ mlir::FlatSymbolRefAttr::get (fatbinStr. getSymNameAttr ());
1091
+ auto fatbinInit = GlobalViewAttr::get (voidPtrTy, fatbinStrSymbol );
1079
1092
auto unusedInit = builder.getConstNullPtrAttr (voidPtrTy);
1080
1093
fatbinWrapper.setInitialValueAttr (cir::ConstStructAttr::get (
1081
1094
fatbinWrapperType,
1082
1095
ArrayAttr::get (&getContext (),
1083
1096
{magicInit, versionInit, fatbinInit, unusedInit})));
1084
1097
1098
+ // GPU fat binary handle is also a global variable in OG.
1099
+ std::string gpubinHandleName =
1100
+ addUnderscoredPrefix (cudaPrefix, " _gpubin_handle" );
1101
+ auto gpubinHandle = builder.create <GlobalOp>(
1102
+ loc, gpubinHandleName, voidPtrPtrTy,
1103
+ /* isConstant=*/ false , /* linkage=*/ GlobalLinkageKind::InternalLinkage);
1104
+ gpubinHandle.setInitialValueAttr (builder.getConstNullPtrAttr (voidPtrPtrTy));
1105
+ gpubinHandle.setPrivate ();
1106
+
1085
1107
// Declare this function:
1086
1108
// void **__{cuda|hip}RegisterFatBinary(void *);
1087
1109
1088
- std::string regFuncName = addUnderscoredPrefix (" RegisterFatBinary" );
1110
+ std::string regFuncName =
1111
+ addUnderscoredPrefix (cudaPrefix, " RegisterFatBinary" );
1089
1112
auto regFuncType = FuncType::get ({voidPtrTy}, voidPtrPtrTy);
1090
1113
auto regFunc = buildRuntimeFunction (builder, regFuncName, loc, regFuncType);
1091
1114
1092
1115
// Create the module constructor.
1093
1116
1094
- std::string moduleCtorName = addUnderscoredPrefix (" _module_ctor" );
1117
+ std::string moduleCtorName = addUnderscoredPrefix (cudaPrefix, " _module_ctor" );
1095
1118
auto moduleCtor = buildRuntimeFunction (builder, moduleCtorName, loc,
1096
1119
FuncType::get ({}, voidTy),
1097
1120
GlobalLinkageKind::InternalLinkage);
1098
1121
globalCtorList.push_back (GlobalCtorAttr::get (&getContext (), moduleCtorName));
1099
1122
builder.setInsertionPointToStart (moduleCtor.addEntryBlock ());
1100
1123
1101
- auto wrapper = builder.createGetGlobal (fatbinWrapper);
1102
- // Put fatbinStr inside fatbinWrapper.
1103
- mlir::Value fatbinStrValue = builder.createGetGlobal (fatbinStr);
1104
- mlir::Value fatbinField = builder.createGetMemberOp (loc, wrapper, " " , 2 );
1105
- builder.createStore (loc, fatbinStrValue, fatbinField);
1106
-
1107
1124
// Register binary with CUDA runtime. This is substantially different in
1108
1125
// default mode vs. separate compilation.
1109
1126
// Corresponding code:
1110
1127
// gpuBinaryHandle = __cudaRegisterFatBinary(&fatbinWrapper);
1128
+ auto wrapper = builder.createGetGlobal (fatbinWrapper);
1111
1129
auto fatbinVoidPtr = builder.createBitcast (wrapper, voidPtrTy);
1112
- builder.createCallOp (loc, regFunc, fatbinVoidPtr);
1130
+ auto gpuBinaryHandleCall = builder.createCallOp (loc, regFunc, fatbinVoidPtr);
1131
+ auto gpuBinaryHandle = gpuBinaryHandleCall.getResult ();
1132
+ // Store the value back to the global `__cuda_gpubin_handle`.
1133
+ auto gpuBinaryHandleGlobal = builder.createGetGlobal (gpubinHandle);
1134
+ builder.createStore (loc, gpuBinaryHandle, gpuBinaryHandleGlobal);
1135
+
1136
+ // Generate __cuda_register_globals and call it.
1137
+ std::optional<FuncOp> regGlobal = buildCUDARegisterGlobals ();
1138
+ if (regGlobal) {
1139
+ builder.createCallOp (loc, *regGlobal, gpuBinaryHandle);
1140
+ }
1113
1141
1114
- // This is currently incomplete.
1115
- // TODO(cir): create __cuda_register_globals(), and call it here.
1142
+ // From CUDA 10.1 onwards, we must call this function to end registration:
1143
+ // void __cudaRegisterFatBinaryEnd(void **fatbinHandle);
1144
+ // This is CUDA-specific, so no need to use `addUnderscoredPrefix`.
1145
+ if (clang::CudaFeatureEnabled (
1146
+ astCtx->getTargetInfo ().getSDKVersion (),
1147
+ clang::CudaFeature::CUDA_USES_FATBIN_REGISTER_END)) {
1148
+ cir::CIRBaseBuilderTy globalBuilder (getContext ());
1149
+ globalBuilder.setInsertionPointToStart (theModule.getBody ());
1150
+ FuncOp endFunc =
1151
+ buildRuntimeFunction (globalBuilder, " __cudaRegisterFatBinaryEnd" , loc,
1152
+ FuncType::get ({voidPtrPtrTy}, voidTy));
1153
+ builder.createCallOp (loc, endFunc, gpuBinaryHandle);
1154
+ }
1116
1155
1117
1156
builder.create <cir::ReturnOp>(loc);
1118
1157
}
1119
1158
1159
+ std::optional<FuncOp> LoweringPreparePass::buildCUDARegisterGlobals () {
1160
+ // There is nothing to register.
1161
+ if (cudaKernelMap.empty ())
1162
+ return {};
1163
+
1164
+ cir::CIRBaseBuilderTy builder (getContext ());
1165
+ builder.setInsertionPointToStart (theModule.getBody ());
1166
+
1167
+ auto loc = theModule.getLoc ();
1168
+ auto cudaPrefix = getCUDAPrefix (astCtx);
1169
+
1170
+ auto voidTy = VoidType::get (&getContext ());
1171
+ auto voidPtrPtrTy = PointerType::get (PointerType::get (voidTy));
1172
+
1173
+ // Create the function:
1174
+ // void __cuda_register_globals(void **fatbinHandle)
1175
+ std::string regGlobalFuncName =
1176
+ addUnderscoredPrefix (cudaPrefix, " _register_globals" );
1177
+ auto regGlobalFuncTy = FuncType::get ({voidPtrPtrTy}, voidTy);
1178
+ FuncOp regGlobalFunc =
1179
+ buildRuntimeFunction (builder, regGlobalFuncName, loc, regGlobalFuncTy,
1180
+ /* linkage=*/ GlobalLinkageKind::InternalLinkage);
1181
+ builder.setInsertionPointToStart (regGlobalFunc.addEntryBlock ());
1182
+
1183
+ buildCUDARegisterGlobalFunctions (builder, regGlobalFunc);
1184
+
1185
+ // TODO(cir): registration for global variables.
1186
+
1187
+ builder.create <ReturnOp>(loc);
1188
+ return regGlobalFunc;
1189
+ }
1190
+
1191
+ void LoweringPreparePass::buildCUDARegisterGlobalFunctions (
1192
+ cir::CIRBaseBuilderTy &builder, FuncOp regGlobalFunc) {
1193
+ auto loc = theModule.getLoc ();
1194
+ auto cudaPrefix = getCUDAPrefix (astCtx);
1195
+
1196
+ auto voidTy = VoidType::get (&getContext ());
1197
+ auto voidPtrTy = PointerType::get (voidTy);
1198
+ auto voidPtrPtrTy = PointerType::get (voidPtrTy);
1199
+ auto intTy = datalayout->getIntType (&getContext ());
1200
+ auto charTy = datalayout->getCharType (&getContext ());
1201
+
1202
+ // Extract the GPU binary handle argument.
1203
+ mlir::Value fatbinHandle = *regGlobalFunc.args_begin ();
1204
+
1205
+ cir::CIRBaseBuilderTy globalBuilder (getContext ());
1206
+ globalBuilder.setInsertionPointToStart (theModule.getBody ());
1207
+
1208
+ // Declare CUDA internal functions:
1209
+ // int __cudaRegisterFunction(
1210
+ // void **fatbinHandle,
1211
+ // const char *hostFunc,
1212
+ // char *deviceFunc,
1213
+ // const char *deviceName,
1214
+ // int threadLimit,
1215
+ // uint3 *tid, uint3 *bid, dim3 *bDim, dim3 *gDim,
1216
+ // int *wsize
1217
+ // )
1218
+ // OG doesn't care about the types at all. They're treated as void*.
1219
+
1220
+ FuncOp cudaRegisterFunction = buildRuntimeFunction (
1221
+ globalBuilder, addUnderscoredPrefix (cudaPrefix, " RegisterFunction" ), loc,
1222
+ FuncType::get ({voidPtrPtrTy, voidPtrTy, voidPtrTy, voidPtrTy, intTy,
1223
+ voidPtrTy, voidPtrTy, voidPtrTy, voidPtrTy, voidPtrTy},
1224
+ intTy));
1225
+
1226
+ auto makeConstantString = [&](llvm::StringRef str) -> GlobalOp {
1227
+ auto strType = ArrayType::get (&getContext (), charTy, 1 + str.size ());
1228
+
1229
+ auto tmpString = globalBuilder.create <GlobalOp>(
1230
+ loc, (" .str" + str).str (), strType, /* isConstant=*/ true ,
1231
+ /* linkage=*/ cir::GlobalLinkageKind::PrivateLinkage);
1232
+
1233
+ // We must make the string zero-terminated.
1234
+ tmpString.setInitialValueAttr (ConstArrayAttr::get (
1235
+ strType, StringAttr::get (&getContext (), str + " \0 " )));
1236
+ tmpString.setPrivate ();
1237
+ return tmpString;
1238
+ };
1239
+
1240
+ auto cirNullPtr = builder.getNullPtr (voidPtrTy, loc);
1241
+ for (auto kernelName : cudaKernelMap.keys ()) {
1242
+ FuncOp deviceStub = cudaKernelMap[kernelName];
1243
+ GlobalOp deviceFuncStr = makeConstantString (kernelName);
1244
+ mlir::Value deviceFunc = builder.createBitcast (
1245
+ builder.createGetGlobal (deviceFuncStr), voidPtrTy);
1246
+ mlir::Value hostFunc = builder.createBitcast (
1247
+ builder.create <GetGlobalOp>(
1248
+ loc, PointerType::get (deviceStub.getFunctionType ()),
1249
+ mlir::FlatSymbolRefAttr::get (deviceStub.getSymNameAttr ())),
1250
+ voidPtrTy);
1251
+ builder.createCallOp (
1252
+ loc, cudaRegisterFunction,
1253
+ {fatbinHandle, hostFunc, deviceFunc, deviceFunc,
1254
+ builder.create <ConstantOp>(loc, IntAttr::get (intTy, -1 )), cirNullPtr,
1255
+ cirNullPtr, cirNullPtr, cirNullPtr, cirNullPtr});
1256
+ }
1257
+ }
1258
+
1120
1259
void LoweringPreparePass::lowerDynamicCastOp (DynamicCastOp op) {
1121
1260
CIRBaseBuilderTy builder (getContext ());
1122
1261
builder.setInsertionPointAfter (op);
@@ -1378,11 +1517,10 @@ void LoweringPreparePass::runOnOp(Operation *op) {
1378
1517
globalDtorList.push_back (globalDtor);
1379
1518
}
1380
1519
if (auto attr = fnOp.getExtraAttrs ().getElements ().get (
1381
- CIRDialect::getCUDABinaryHandleAttrName ())) {
1382
- auto cudaBinaryAttr = dyn_cast<CUDABinaryHandleAttr>(attr);
1383
- std::string kernelName = cudaBinaryAttr.getName ();
1384
- llvm::StringRef stubName = fnOp.getSymName ();
1385
- cudaKernelMap[stubName] = kernelName;
1520
+ CUDAKernelNameAttr::getMnemonic ())) {
1521
+ auto cudaBinaryAttr = dyn_cast<CUDAKernelNameAttr>(attr);
1522
+ std::string kernelName = cudaBinaryAttr.getKernelName ();
1523
+ cudaKernelMap[kernelName] = fnOp;
1386
1524
}
1387
1525
if (std::optional<mlir::ArrayAttr> annotations = fnOp.getAnnotations ())
1388
1526
addGlobalAnnotations (fnOp, annotations.value ());
@@ -1399,6 +1537,9 @@ void LoweringPreparePass::runOnOperation() {
1399
1537
datalayout.emplace (theModule);
1400
1538
}
1401
1539
1540
+ auto typeSizeInfo = cast<TypeSizeInfoAttr>(
1541
+ theModule->getAttr (CIRDialect::getTypeSizeInfoAttrName ()));
1542
+
1402
1543
llvm::SmallVector<Operation *> opsToTransform;
1403
1544
1404
1545
op->walk ([&](Operation *op) {
0 commit comments