@@ -471,12 +471,13 @@ struct FunctionOptions {
471471 const StringRef FunctionName;
472472 /// Location of the non-debug version of the outlined function.
473473 SourceLocation Loc;
474+ const bool IsDeviceKernel = false ;
474475 explicit FunctionOptions (const CapturedStmt *S, bool UIntPtrCastRequired,
475476 bool RegisterCastedArgsOnly, StringRef FunctionName,
476- SourceLocation Loc)
477+ SourceLocation Loc, bool IsDeviceKernel )
477478 : S(S), UIntPtrCastRequired(UIntPtrCastRequired),
478479 RegisterCastedArgsOnly(UIntPtrCastRequired && RegisterCastedArgsOnly),
479- FunctionName(FunctionName), Loc(Loc) {}
480+ FunctionName(FunctionName), Loc(Loc), IsDeviceKernel(IsDeviceKernel) {}
480481};
481482} // namespace
482483
@@ -570,7 +571,11 @@ static llvm::Function *emitOutlinedFunctionPrologue(
570571
571572 // Create the function declaration.
572573 const CGFunctionInfo &FuncInfo =
573- CGM.getTypes ().arrangeBuiltinFunctionDeclaration (Ctx.VoidTy , TargetArgs);
574+ FO.IsDeviceKernel
575+ ? CGM.getTypes ().arrangeDeviceKernelCallerDeclaration (Ctx.VoidTy ,
576+ TargetArgs)
577+ : CGM.getTypes ().arrangeBuiltinFunctionDeclaration (Ctx.VoidTy ,
578+ TargetArgs);
574579 llvm::FunctionType *FuncLLVMTy = CGM.getTypes ().GetFunctionType (FuncInfo);
575580
576581 auto *F =
@@ -664,9 +669,9 @@ static llvm::Function *emitOutlinedFunctionPrologue(
664669 return F;
665670}
666671
667- llvm::Function *
668- CodeGenFunction::GenerateOpenMPCapturedStmtFunction ( const CapturedStmt &S,
669- SourceLocation Loc) {
672+ llvm::Function *CodeGenFunction::GenerateOpenMPCapturedStmtFunction (
673+ const CapturedStmt &S, const OMPExecutableDirective &D) {
674+ SourceLocation Loc = D. getBeginLoc ();
670675 assert (
671676 CapturedStmtInfo &&
672677 " CapturedStmtInfo should be set when generating the captured function" );
@@ -682,23 +687,27 @@ CodeGenFunction::GenerateOpenMPCapturedStmtFunction(const CapturedStmt &S,
682687 SmallString<256 > Buffer;
683688 llvm::raw_svector_ostream Out (Buffer);
684689 Out << CapturedStmtInfo->getHelperName ();
685-
690+ OpenMPDirectiveKind EKind = getEffectiveDirectiveKind (D);
691+ bool IsDeviceKernel = CGM.getOpenMPRuntime ().isGPU () &&
692+ isOpenMPTargetExecutionDirective (EKind) &&
693+ D.getCapturedStmt (OMPD_target) == &S;
686694 CodeGenFunction WrapperCGF (CGM, /* suppressNewContext=*/ true );
687695 llvm::Function *WrapperF = nullptr ;
688696 if (NeedWrapperFunction) {
689697 // Emit the final kernel early to allow attributes to be added by the
690698 // OpenMP-IR-Builder.
691699 FunctionOptions WrapperFO (&S, /* UIntPtrCastRequired=*/ true ,
692700 /* RegisterCastedArgsOnly=*/ true ,
693- CapturedStmtInfo->getHelperName (), Loc);
701+ CapturedStmtInfo->getHelperName (), Loc,
702+ IsDeviceKernel);
694703 WrapperCGF.CapturedStmtInfo = CapturedStmtInfo;
695704 WrapperF =
696705 emitOutlinedFunctionPrologue (WrapperCGF, Args, LocalAddrs, VLASizes,
697706 WrapperCGF.CXXThisValue , WrapperFO);
698707 Out << " _debug__" ;
699708 }
700709 FunctionOptions FO (&S, !NeedWrapperFunction, /* RegisterCastedArgsOnly=*/ false ,
701- Out.str (), Loc);
710+ Out.str (), Loc, !NeedWrapperFunction && IsDeviceKernel );
702711 llvm::Function *F = emitOutlinedFunctionPrologue (
703712 *this , WrapperArgs, WrapperLocalAddrs, WrapperVLASizes, CXXThisValue, FO);
704713 CodeGenFunction::OMPPrivateScope LocalScope (*this );
@@ -6119,13 +6128,13 @@ void CodeGenFunction::EmitOMPDistributeDirective(
61196128 emitOMPDistributeDirective (S, *this , CGM);
61206129}
61216130
/// Outlines the captured statement \p S into a separate llvm::Function; used
/// by EmitOMPOrderedDirective. The work is done in a standalone
/// CodeGenFunction (constructed with suppressNewContext=true) whose
/// CapturedStmtInfo points at a local, default-constructed CGCapturedStmtInfo,
/// and is delegated to GenerateOpenMPCapturedStmtFunction. The removed lines
/// forward only a SourceLocation; the added lines forward the whole directive
/// \p D instead. The resulting function is marked as non-recursing before it
/// is returned.
6122- static llvm::Function *emitOutlinedOrderedFunction (CodeGenModule &CGM,
6123- const CapturedStmt *S,
6124- SourceLocation Loc ) {
6131+ static llvm::Function *
6132+ emitOutlinedOrderedFunction (CodeGenModule &CGM, const CapturedStmt *S,
6133+ const OMPExecutableDirective &D ) {
61256134 CodeGenFunction CGF (CGM, /* suppressNewContext=*/ true );
61266135 CodeGenFunction::CGCapturedStmtInfo CapStmtInfo;
61276136 CGF.CapturedStmtInfo = &CapStmtInfo;
6128- llvm::Function *Fn = CGF.GenerateOpenMPCapturedStmtFunction (*S, Loc );
6137+ llvm::Function *Fn = CGF.GenerateOpenMPCapturedStmtFunction (*S, D );
61296138 Fn->setDoesNotRecurse ();
61306139 return Fn;
61316140}
@@ -6190,8 +6199,7 @@ void CodeGenFunction::EmitOMPOrderedDirective(const OMPOrderedDirective &S) {
61906199 Builder, /* CreateBranch=*/ false , " .ordered.after" );
61916200 llvm::SmallVector<llvm::Value *, 16 > CapturedVars;
61926201 GenerateOpenMPCapturedVars (*CS, CapturedVars);
6193- llvm::Function *OutlinedFn =
6194- emitOutlinedOrderedFunction (CGM, CS, S.getBeginLoc ());
6202+ llvm::Function *OutlinedFn = emitOutlinedOrderedFunction (CGM, CS, S);
61956203 assert (S.getBeginLoc ().isValid () &&
61966204 " Outlined function call location must be valid." );
61976205 ApplyDebugLocation::CreateDefaultArtificial (*this , S.getBeginLoc ());
@@ -6233,8 +6241,7 @@ void CodeGenFunction::EmitOMPOrderedDirective(const OMPOrderedDirective &S) {
62336241 if (C) {
62346242 llvm::SmallVector<llvm::Value *, 16 > CapturedVars;
62356243 CGF.GenerateOpenMPCapturedVars (*CS, CapturedVars);
6236- llvm::Function *OutlinedFn =
6237- emitOutlinedOrderedFunction (CGM, CS, S.getBeginLoc ());
6244+ llvm::Function *OutlinedFn = emitOutlinedOrderedFunction (CGM, CS, S);
62386245 CGM.getOpenMPRuntime ().emitOutlinedFunctionCall (CGF, S.getBeginLoc (),
62396246 OutlinedFn, CapturedVars);
62406247 } else {
0 commit comments