@@ -471,12 +471,13 @@ struct FunctionOptions {
471
471
const StringRef FunctionName;
472
472
/// Location of the non-debug version of the outlined function.
473
473
SourceLocation Loc;
474
+ const bool IsDeviceKernel = false ;
474
475
explicit FunctionOptions (const CapturedStmt *S, bool UIntPtrCastRequired,
475
476
bool RegisterCastedArgsOnly, StringRef FunctionName,
476
- SourceLocation Loc)
477
+ SourceLocation Loc, bool IsDeviceKernel )
477
478
: S(S), UIntPtrCastRequired(UIntPtrCastRequired),
478
479
RegisterCastedArgsOnly(UIntPtrCastRequired && RegisterCastedArgsOnly),
479
- FunctionName(FunctionName), Loc(Loc) {}
480
+ FunctionName(FunctionName), Loc(Loc), IsDeviceKernel(IsDeviceKernel) {}
480
481
};
481
482
} // namespace
482
483
@@ -570,7 +571,11 @@ static llvm::Function *emitOutlinedFunctionPrologue(
570
571
571
572
// Create the function declaration.
572
573
const CGFunctionInfo &FuncInfo =
573
- CGM.getTypes ().arrangeBuiltinFunctionDeclaration (Ctx.VoidTy , TargetArgs);
574
+ FO.IsDeviceKernel
575
+ ? CGM.getTypes ().arrangeDeviceKernelCallerDeclaration (Ctx.VoidTy ,
576
+ TargetArgs)
577
+ : CGM.getTypes ().arrangeBuiltinFunctionDeclaration (Ctx.VoidTy ,
578
+ TargetArgs);
574
579
llvm::FunctionType *FuncLLVMTy = CGM.getTypes ().GetFunctionType (FuncInfo);
575
580
576
581
auto *F =
@@ -664,9 +669,9 @@ static llvm::Function *emitOutlinedFunctionPrologue(
664
669
return F;
665
670
}
666
671
667
- llvm::Function *
668
- CodeGenFunction::GenerateOpenMPCapturedStmtFunction ( const CapturedStmt &S,
669
- SourceLocation Loc) {
672
+ llvm::Function *CodeGenFunction::GenerateOpenMPCapturedStmtFunction (
673
+ const CapturedStmt &S, const OMPExecutableDirective &D) {
674
+ SourceLocation Loc = D. getBeginLoc ();
670
675
assert (
671
676
CapturedStmtInfo &&
672
677
" CapturedStmtInfo should be set when generating the captured function" );
@@ -682,23 +687,27 @@ CodeGenFunction::GenerateOpenMPCapturedStmtFunction(const CapturedStmt &S,
682
687
SmallString<256 > Buffer;
683
688
llvm::raw_svector_ostream Out (Buffer);
684
689
Out << CapturedStmtInfo->getHelperName ();
685
-
690
+ OpenMPDirectiveKind EKind = getEffectiveDirectiveKind (D);
691
+ bool IsDeviceKernel = CGM.getOpenMPRuntime ().isGPU () &&
692
+ isOpenMPTargetExecutionDirective (EKind) &&
693
+ D.getCapturedStmt (OMPD_target) == &S;
686
694
CodeGenFunction WrapperCGF (CGM, /* suppressNewContext=*/ true );
687
695
llvm::Function *WrapperF = nullptr ;
688
696
if (NeedWrapperFunction) {
689
697
// Emit the final kernel early to allow attributes to be added by the
690
698
// OpenMP-IR-Builder.
691
699
FunctionOptions WrapperFO (&S, /* UIntPtrCastRequired=*/ true ,
692
700
/* RegisterCastedArgsOnly=*/ true ,
693
- CapturedStmtInfo->getHelperName (), Loc);
701
+ CapturedStmtInfo->getHelperName (), Loc,
702
+ IsDeviceKernel);
694
703
WrapperCGF.CapturedStmtInfo = CapturedStmtInfo;
695
704
WrapperF =
696
705
emitOutlinedFunctionPrologue (WrapperCGF, Args, LocalAddrs, VLASizes,
697
706
WrapperCGF.CXXThisValue , WrapperFO);
698
707
Out << " _debug__" ;
699
708
}
700
709
FunctionOptions FO (&S, !NeedWrapperFunction, /* RegisterCastedArgsOnly=*/ false ,
701
- Out.str (), Loc);
710
+ Out.str (), Loc, !NeedWrapperFunction && IsDeviceKernel );
702
711
llvm::Function *F = emitOutlinedFunctionPrologue (
703
712
*this , WrapperArgs, WrapperLocalAddrs, WrapperVLASizes, CXXThisValue, FO);
704
713
CodeGenFunction::OMPPrivateScope LocalScope (*this );
@@ -6119,13 +6128,13 @@ void CodeGenFunction::EmitOMPDistributeDirective(
6119
6128
emitOMPDistributeDirective (S, *this , CGM);
6120
6129
}
6121
6130
6122
// emitOutlinedOrderedFunction: builds the outlined function for captured
// statement S. It creates a separate CodeGenFunction (constructed with
// suppressNewContext=true), points its CapturedStmtInfo at a
// default-constructed CGCapturedStmtInfo, and delegates to
// GenerateOpenMPCapturedStmtFunction. The resulting function is marked as
// non-recursing before being returned.
// In this diff the third parameter changes from a SourceLocation (Loc) to the
// directive itself (D), which is forwarded unchanged to
// GenerateOpenMPCapturedStmtFunction.
- static llvm::Function *emitOutlinedOrderedFunction (CodeGenModule &CGM,
6123
- const CapturedStmt *S,
6124
- SourceLocation Loc ) {
6131
+ static llvm::Function *
6132
+ emitOutlinedOrderedFunction (CodeGenModule &CGM, const CapturedStmt *S,
6133
+ const OMPExecutableDirective &D ) {
6125
6134
CodeGenFunction CGF (CGM, /* suppressNewContext=*/ true );
6126
6135
CodeGenFunction::CGCapturedStmtInfo CapStmtInfo;
6127
6136
CGF.CapturedStmtInfo = &CapStmtInfo;
6128
- llvm::Function *Fn = CGF.GenerateOpenMPCapturedStmtFunction (*S, Loc );
6137
+ llvm::Function *Fn = CGF.GenerateOpenMPCapturedStmtFunction (*S, D );
6129
6138
Fn->setDoesNotRecurse ();
6130
6139
return Fn;
6131
6140
}
@@ -6190,8 +6199,7 @@ void CodeGenFunction::EmitOMPOrderedDirective(const OMPOrderedDirective &S) {
6190
6199
Builder, /* CreateBranch=*/ false , " .ordered.after" );
6191
6200
llvm::SmallVector<llvm::Value *, 16 > CapturedVars;
6192
6201
GenerateOpenMPCapturedVars (*CS, CapturedVars);
6193
- llvm::Function *OutlinedFn =
6194
- emitOutlinedOrderedFunction (CGM, CS, S.getBeginLoc ());
6202
+ llvm::Function *OutlinedFn = emitOutlinedOrderedFunction (CGM, CS, S);
6195
6203
assert (S.getBeginLoc ().isValid () &&
6196
6204
" Outlined function call location must be valid." );
6197
6205
ApplyDebugLocation::CreateDefaultArtificial (*this , S.getBeginLoc ());
@@ -6233,8 +6241,7 @@ void CodeGenFunction::EmitOMPOrderedDirective(const OMPOrderedDirective &S) {
6233
6241
if (C) {
6234
6242
llvm::SmallVector<llvm::Value *, 16 > CapturedVars;
6235
6243
CGF.GenerateOpenMPCapturedVars (*CS, CapturedVars);
6236
- llvm::Function *OutlinedFn =
6237
- emitOutlinedOrderedFunction (CGM, CS, S.getBeginLoc ());
6244
+ llvm::Function *OutlinedFn = emitOutlinedOrderedFunction (CGM, CS, S);
6238
6245
CGM.getOpenMPRuntime ().emitOutlinedFunctionCall (CGF, S.getBeginLoc (),
6239
6246
OutlinedFn, CapturedVars);
6240
6247
} else {
0 commit comments