@@ -457,45 +457,41 @@ void createRegisterFatbinFunction(Module &M, GlobalVariable *FatbinDesc,
457457 IsHIP ? " .hip.fatbin_unreg" : " .cuda.fatbin_unreg" , &M);
458458 DtorFunc->setSection (" .text.startup" );
459459
460+ auto *PtrTy = PointerType::getUnqual (C);
461+
460462 // Get the __cudaRegisterFatBinary function declaration.
461- auto *RegFatTy = FunctionType::get (PointerType::getUnqual (C)->getPointerTo (),
462- PointerType::getUnqual (C),
463- /* isVarArg*/ false );
463+ auto *RegFatTy = FunctionType::get (PtrTy, PtrTy, /* isVarArg=*/ false );
464464 FunctionCallee RegFatbin = M.getOrInsertFunction (
465465 IsHIP ? " __hipRegisterFatBinary" : " __cudaRegisterFatBinary" , RegFatTy);
466466 // Get the __cudaRegisterFatBinaryEnd function declaration.
467- auto *RegFatEndTy = FunctionType::get (
468- Type::getVoidTy (C), PointerType::getUnqual (C)->getPointerTo (),
469- /* isVarArg*/ false );
467+ auto *RegFatEndTy =
468+ FunctionType::get (Type::getVoidTy (C), PtrTy, /* isVarArg=*/ false );
470469 FunctionCallee RegFatbinEnd =
471470 M.getOrInsertFunction (" __cudaRegisterFatBinaryEnd" , RegFatEndTy);
472471 // Get the __cudaUnregisterFatBinary function declaration.
473- auto *UnregFatTy = FunctionType::get (
474- Type::getVoidTy (C), PointerType::getUnqual (C)->getPointerTo (),
475- /* isVarArg*/ false );
472+ auto *UnregFatTy =
473+ FunctionType::get (Type::getVoidTy (C), PtrTy, /* isVarArg=*/ false );
476474 FunctionCallee UnregFatbin = M.getOrInsertFunction (
477475 IsHIP ? " __hipUnregisterFatBinary" : " __cudaUnregisterFatBinary" ,
478476 UnregFatTy);
479477
480478 auto *AtExitTy =
481- FunctionType::get (Type::getInt32Ty (C), DtorFuncTy->getPointerTo (),
482- /* isVarArg*/ false );
479+ FunctionType::get (Type::getInt32Ty (C), PtrTy, /* isVarArg=*/ false );
483480 FunctionCallee AtExit = M.getOrInsertFunction (" atexit" , AtExitTy);
484481
485482 auto *BinaryHandleGlobal = new llvm::GlobalVariable (
486- M, PointerType::getUnqual (C)->getPointerTo (), false ,
487- llvm::GlobalValue::InternalLinkage,
488- llvm::ConstantPointerNull::get (PointerType::getUnqual (C)->getPointerTo ()),
483+ M, PtrTy, false , llvm::GlobalValue::InternalLinkage,
484+ llvm::ConstantPointerNull::get (PtrTy),
489485 IsHIP ? " .hip.binary_handle" : " .cuda.binary_handle" );
490486
491487 // Create the constructor to register this image with the runtime.
492488 IRBuilder<> CtorBuilder (BasicBlock::Create (C, " entry" , CtorFunc));
493489 CallInst *Handle = CtorBuilder.CreateCall (
494- RegFatbin, ConstantExpr::getPointerBitCastOrAddrSpaceCast (
495- FatbinDesc, PointerType::getUnqual (C) ));
490+ RegFatbin,
491+ ConstantExpr::getPointerBitCastOrAddrSpaceCast (FatbinDesc, PtrTy ));
496492 CtorBuilder.CreateAlignedStore (
497493 Handle, BinaryHandleGlobal,
498- Align (M.getDataLayout ().getPointerTypeSize (PointerType::getUnqual (C) )));
494+ Align (M.getDataLayout ().getPointerTypeSize (PtrTy )));
499495 CtorBuilder.CreateCall (createRegisterGlobalsFunction (M, IsHIP), Handle);
500496 if (!IsHIP)
501497 CtorBuilder.CreateCall (RegFatbinEnd, Handle);
@@ -507,8 +503,8 @@ void createRegisterFatbinFunction(Module &M, GlobalVariable *FatbinDesc,
507503 // `atexit()` intead.
508504 IRBuilder<> DtorBuilder (BasicBlock::Create (C, " entry" , DtorFunc));
509505 LoadInst *BinaryHandle = DtorBuilder.CreateAlignedLoad (
510- PointerType::getUnqual (C)-> getPointerTo () , BinaryHandleGlobal,
511- Align (M.getDataLayout ().getPointerTypeSize (PointerType::getUnqual (C) )));
506+ PtrTy , BinaryHandleGlobal,
507+ Align (M.getDataLayout ().getPointerTypeSize (PtrTy )));
512508 DtorBuilder.CreateCall (UnregFatbin, BinaryHandle);
513509 DtorBuilder.CreateRetVoid ();
514510
0 commit comments