Skip to content
51 changes: 38 additions & 13 deletions clang/lib/CodeGen/CGOpenMPRuntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10000,19 +10000,44 @@ static llvm::Value *emitDeviceID(
return DeviceID;
}

static llvm::Value *emitDynCGGroupMem(const OMPExecutableDirective &D,
CodeGenFunction &CGF) {
llvm::Value *DynCGroupMem = CGF.Builder.getInt32(0);

if (auto *DynMemClause = D.getSingleClause<OMPXDynCGroupMemClause>()) {
CodeGenFunction::RunCleanupsScope DynCGroupMemScope(CGF);
llvm::Value *DynCGroupMemVal = CGF.EmitScalarExpr(
DynMemClause->getSize(), /*IgnoreResultAssign=*/true);
DynCGroupMem = CGF.Builder.CreateIntCast(DynCGroupMemVal, CGF.Int32Ty,
/*isSigned=*/false);
static std::pair<llvm::Value *, OMPDynGroupprivateFallbackType>
emitDynCGroupMem(const OMPExecutableDirective &D, CodeGenFunction &CGF) {
llvm::Value *DynGP = CGF.Builder.getInt32(0);
auto DynGPFallback = OMPDynGroupprivateFallbackType::Abort;

if (auto *DynGPClause = D.getSingleClause<OMPDynGroupprivateClause>()) {
CodeGenFunction::RunCleanupsScope DynGPScope(CGF);
llvm::Value *DynGPVal =
CGF.EmitScalarExpr(DynGPClause->getSize(), /*IgnoreResultAssign=*/true);
DynGP = CGF.Builder.CreateIntCast(DynGPVal, CGF.Int32Ty,
/*isSigned=*/false);
auto FallbackModifier = DynGPClause->getDynGroupprivateFallbackModifier();
switch (FallbackModifier) {
case OMPC_DYN_GROUPPRIVATE_FALLBACK_abort:
DynGPFallback = OMPDynGroupprivateFallbackType::Abort;
break;
case OMPC_DYN_GROUPPRIVATE_FALLBACK_null:
DynGPFallback = OMPDynGroupprivateFallbackType::Null;
break;
case OMPC_DYN_GROUPPRIVATE_FALLBACK_default_mem:
case OMPC_DYN_GROUPPRIVATE_FALLBACK_unknown:
// This is the default for dyn_groupprivate.
DynGPFallback = OMPDynGroupprivateFallbackType::DefaultMem;
break;
default:
llvm_unreachable("Unknown fallback modifier for OpenMP dyn_groupprivate");
}
} else if (auto *OMPXDynCGClause =
D.getSingleClause<OMPXDynCGroupMemClause>()) {
CodeGenFunction::RunCleanupsScope DynCGMemScope(CGF);
llvm::Value *DynCGMemVal = CGF.EmitScalarExpr(OMPXDynCGClause->getSize(),
/*IgnoreResultAssign=*/true);
DynGP = CGF.Builder.CreateIntCast(DynCGMemVal, CGF.Int32Ty,
/*isSigned=*/false);
}
return DynCGroupMem;
return {DynGP, DynGPFallback};
}

static void genMapInfoForCaptures(
MappableExprsHandler &MEHandler, CodeGenFunction &CGF,
const CapturedStmt &CS, llvm::SmallVectorImpl<llvm::Value *> &CapturedVars,
Expand Down Expand Up @@ -10221,7 +10246,7 @@ static void emitTargetCallKernelLaunch(
llvm::Value *RTLoc = OMPRuntime->emitUpdateLocation(CGF, D.getBeginLoc());
llvm::Value *NumIterations =
OMPRuntime->emitTargetNumIterationsCall(CGF, D, SizeEmitter);
llvm::Value *DynCGGroupMem = emitDynCGGroupMem(D, CGF);
auto [DynCGroupMem, DynCGroupMemFallback] = emitDynCGroupMem(D, CGF);
llvm::OpenMPIRBuilder::InsertPointTy AllocaIP(
CGF.AllocaInsertPt->getParent(), CGF.AllocaInsertPt->getIterator());

Expand All @@ -10231,7 +10256,7 @@ static void emitTargetCallKernelLaunch(

llvm::OpenMPIRBuilder::TargetKernelArgs Args(
NumTargetItems, RTArgs, NumIterations, NumTeams, NumThreads,
DynCGGroupMem, HasNoWait);
DynCGroupMem, HasNoWait, DynCGroupMemFallback);

llvm::OpenMPIRBuilder::InsertPointTy AfterIP =
cantFail(OMPRuntime->getOMPBuilder().emitKernelLaunch(
Expand Down
Loading
Loading