Skip to content

Commit f66e5fa

Browse files
committed
[OpenMP] Add codegen support for dyn_groupprivate clause
1 parent 099c502 commit f66e5fa

File tree

3 files changed

+39
-18
lines changed

3 files changed

+39
-18
lines changed

clang/lib/CodeGen/CGOpenMPRuntime.cpp

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9489,18 +9489,30 @@ static llvm::Value *emitDeviceID(
94899489
return DeviceID;
94909490
}
94919491

9492-
static llvm::Value *emitDynCGGroupMem(const OMPExecutableDirective &D,
9493-
CodeGenFunction &CGF) {
9494-
llvm::Value *DynCGroupMem = CGF.Builder.getInt32(0);
9495-
9496-
if (auto *DynMemClause = D.getSingleClause<OMPXDynCGroupMemClause>()) {
9497-
CodeGenFunction::RunCleanupsScope DynCGroupMemScope(CGF);
9498-
llvm::Value *DynCGroupMemVal = CGF.EmitScalarExpr(
9499-
DynMemClause->getSize(), /*IgnoreResultAssign=*/true);
9500-
DynCGroupMem = CGF.Builder.CreateIntCast(DynCGroupMemVal, CGF.Int32Ty,
9501-
/*isSigned=*/false);
9502-
}
9503-
return DynCGroupMem;
9492+
static std::pair<llvm::Value *, bool>
9493+
emitDynCGroupMem(const OMPExecutableDirective &D, CodeGenFunction &CGF) {
9494+
llvm::Value *DynGP = CGF.Builder.getInt32(0);
9495+
bool DynGPFallback = false;
9496+
9497+
if (auto *DynGPClause = D.getSingleClause<OMPDynGroupprivateClause>()) {
9498+
CodeGenFunction::RunCleanupsScope DynGPScope(CGF);
9499+
llvm::Value *DynGPVal =
9500+
CGF.EmitScalarExpr(DynGPClause->getSize(), /*IgnoreResultAssign=*/true);
9501+
DynGP = CGF.Builder.CreateIntCast(DynGPVal, CGF.Int32Ty,
9502+
/*isSigned=*/false);
9503+
DynGPFallback = (DynGPClause->getFirstDynGroupprivateModifier() !=
9504+
OMPC_DYN_GROUPPRIVATE_strict &&
9505+
DynGPClause->getSecondDynGroupprivateModifier() !=
9506+
OMPC_DYN_GROUPPRIVATE_strict);
9507+
} else if (auto *OMPXDynCGClause =
9508+
D.getSingleClause<OMPXDynCGroupMemClause>()) {
9509+
CodeGenFunction::RunCleanupsScope DynCGMemScope(CGF);
9510+
llvm::Value *DynCGMemVal = CGF.EmitScalarExpr(OMPXDynCGClause->getSize(),
9511+
/*IgnoreResultAssign=*/true);
9512+
DynGP = CGF.Builder.CreateIntCast(DynCGMemVal, CGF.Int32Ty,
9513+
/*isSigned=*/false);
9514+
}
9515+
return {DynGP, DynGPFallback};
95049516
}
95059517
static void genMapInfoForCaptures(
95069518
MappableExprsHandler &MEHandler, CodeGenFunction &CGF,
@@ -9710,7 +9722,7 @@ static void emitTargetCallKernelLaunch(
97109722
llvm::Value *RTLoc = OMPRuntime->emitUpdateLocation(CGF, D.getBeginLoc());
97119723
llvm::Value *NumIterations =
97129724
OMPRuntime->emitTargetNumIterationsCall(CGF, D, SizeEmitter);
9713-
llvm::Value *DynCGGroupMem = emitDynCGGroupMem(D, CGF);
9725+
auto [DynCGroupMem, DynCGroupMemFallback] = emitDynCGroupMem(D, CGF);
97149726
llvm::OpenMPIRBuilder::InsertPointTy AllocaIP(
97159727
CGF.AllocaInsertPt->getParent(), CGF.AllocaInsertPt->getIterator());
97169728

@@ -9720,7 +9732,7 @@ static void emitTargetCallKernelLaunch(
97209732

97219733
llvm::OpenMPIRBuilder::TargetKernelArgs Args(
97229734
NumTargetItems, RTArgs, NumIterations, NumTeams, NumThreads,
9723-
DynCGGroupMem, HasNoWait);
9735+
DynCGroupMem, HasNoWait, DynCGroupMemFallback);
97249736

97259737
llvm::OpenMPIRBuilder::InsertPointTy AfterIP =
97269738
cantFail(OMPRuntime->getOMPBuilder().emitKernelLaunch(

llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2341,17 +2341,20 @@ class OpenMPIRBuilder {
23412341
Value *DynCGGroupMem = nullptr;
23422342
/// True if the kernel has 'no wait' clause.
23432343
bool HasNoWait = false;
2344+
/// True if the dynamic shared memory request is allowed to fall back.
2345+
bool MayFallbackDynCGroupMem = false;
23442346

23452347
// Constructors for TargetKernelArgs.
23462348
TargetKernelArgs() {}
23472349
TargetKernelArgs(unsigned NumTargetItems, TargetDataRTArgs RTArgs,
23482350
Value *NumIterations, ArrayRef<Value *> NumTeams,
23492351
ArrayRef<Value *> NumThreads, Value *DynCGGroupMem,
2350-
bool HasNoWait)
2352+
bool HasNoWait, bool MayFallbackDynCGroupMem)
23512353
: NumTargetItems(NumTargetItems), RTArgs(RTArgs),
23522354
NumIterations(NumIterations), NumTeams(NumTeams),
23532355
NumThreads(NumThreads), DynCGGroupMem(DynCGGroupMem),
2354-
HasNoWait(HasNoWait) {}
2356+
HasNoWait(HasNoWait),
2357+
MayFallbackDynCGroupMem(MayFallbackDynCGroupMem) {}
23552358
};
23562359

23572360
/// Create the kernel args vector used by emitTargetKernel. This function

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -506,7 +506,13 @@ void OpenMPIRBuilder::getKernelArgsVector(TargetKernelArgs &KernelArgs,
506506
auto Int32Ty = Type::getInt32Ty(Builder.getContext());
507507
constexpr const size_t MaxDim = 3;
508508
Value *ZeroArray = Constant::getNullValue(ArrayType::get(Int32Ty, MaxDim));
509-
Value *Flags = Builder.getInt64(KernelArgs.HasNoWait);
509+
510+
Value *HasNoWaitFlag = Builder.getInt64(KernelArgs.HasNoWait);
511+
Value *MayFallbackDynCGroupMemFlag =
512+
Builder.getInt64(KernelArgs.MayFallbackDynCGroupMem);
513+
MayFallbackDynCGroupMemFlag =
514+
Builder.CreateShl(MayFallbackDynCGroupMemFlag, 2);
515+
Value *Flags = Builder.CreateOr(HasNoWaitFlag, MayFallbackDynCGroupMemFlag);
510516

511517
assert(!KernelArgs.NumTeams.empty() && !KernelArgs.NumThreads.empty());
512518

@@ -7891,7 +7897,7 @@ static void emitTargetCall(
78917897

78927898
KArgs = OpenMPIRBuilder::TargetKernelArgs(NumTargetItems, RTArgs, TripCount,
78937899
NumTeamsC, NumThreadsC,
7894-
DynCGGroupMem, HasNoWait);
7900+
DynCGGroupMem, HasNoWait, false);
78957901

78967902
// Assume no error was returned because TaskBodyCB and
78977903
// EmitTargetCallFallbackCB don't produce any.

0 commit comments

Comments
 (0)