Skip to content

Commit 5f68ea0

Browse files
committed
Add support for null fallback
1 parent d4bf656 commit 5f68ea0

File tree

6 files changed

+97
-39
lines changed

6 files changed

+97
-39
lines changed

offload/include/Shared/APITypes.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,9 +102,8 @@ struct KernelArgsTy {
102102
struct {
103103
uint64_t NoWait : 1; // Was this kernel spawned with a `nowait` clause.
104104
uint64_t IsCUDA : 1; // Was this kernel spawned via CUDA.
105-
uint64_t AllowDynCGroupMemFallback : 1; // Allow fallback for dynamic cgroup
106-
// mem fallback.
107-
uint64_t Unused : 61;
105+
uint64_t DynCGroupMemFallback : 2; // The fallback for dynamic cgroup mem.
106+
uint64_t Unused : 60;
108107
} Flags = {0, 0, 0, 0};
109108
// The number of teams (for x,y,z dimension).
110109
uint32_t NumTeams[3] = {0, 0, 0};

offload/include/Shared/Environment.h

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,12 +92,25 @@ struct KernelEnvironmentTy {
9292
DynamicEnvironmentTy *DynamicEnv = nullptr;
9393
};
9494

95+
/// The fallback types for the dynamic cgroup memory.
96+
enum class DynCGroupMemFallbackType : unsigned char {
97+
/// None. Used for indicating that no fallback was triggered.
98+
None = 0,
99+
/// Abort the execution.
100+
Abort = None,
101+
/// Return null pointer.
102+
Null = 1,
103+
/// Allocate from a implementation defined memory space.
104+
DefaultMem = 2
105+
};
106+
95107
struct KernelLaunchEnvironmentTy {
96108
void *ReductionBuffer = nullptr;
97-
void *DynCGroupMemFallback = nullptr;
109+
void *DynCGroupMemFbPtr = nullptr;
98110
uint32_t ReductionCnt = 0;
99111
uint32_t ReductionIterCnt = 0;
100112
uint32_t DynCGroupMemSize = 0;
113+
DynCGroupMemFallbackType DynCGroupMemFb = DynCGroupMemFallbackType::None;
101114
};
102115

103116
#endif // OMPTARGET_SHARED_ENVIRONMENT_H

offload/plugins-nextgen/common/include/PluginInterface.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,8 @@ struct GenericKernelTy {
392392
/// Return a device pointer to a new kernel launch environment.
393393
Expected<KernelLaunchEnvironmentTy *> getKernelLaunchEnvironment(
394394
GenericDeviceTy &GenericDevice, const KernelArgsTy &KernelArgs,
395-
void *FallbackBlockMem, AsyncInfoWrapperTy &AsyncInfo) const;
395+
uint32_t BlockMemSize, DynCGroupMemFallbackType DynBlockMemFb,
396+
void *DynBlockMemFbPtr, AsyncInfoWrapperTy &AsyncInfoWrapper) const;
396397

397398
/// Indicate whether an execution mode is valid.
398399
static bool isValidExecutionMode(OMPTgtExecModeFlags ExecutionMode) {

offload/plugins-nextgen/common/src/PluginInterface.cpp

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -437,7 +437,8 @@ Error GenericKernelTy::init(GenericDeviceTy &GenericDevice,
437437
Expected<KernelLaunchEnvironmentTy *>
438438
GenericKernelTy::getKernelLaunchEnvironment(
439439
GenericDeviceTy &GenericDevice, const KernelArgsTy &KernelArgs,
440-
void *FallbackBlockMem, AsyncInfoWrapperTy &AsyncInfoWrapper) const {
440+
uint32_t BlockMemSize, DynCGroupMemFallbackType DynBlockMemFb,
441+
void *DynBlockMemFbPtr, AsyncInfoWrapperTy &AsyncInfoWrapper) const {
441442
// Ctor/Dtor have no arguments, replaying uses the original kernel launch
442443
// environment. Older versions of the compiler do not generate a kernel
443444
// launch environment.
@@ -479,8 +480,9 @@ GenericKernelTy::getKernelLaunchEnvironment(
479480
LocalKLE.ReductionBuffer = nullptr;
480481
}
481482

482-
LocalKLE.DynCGroupMemSize = KernelArgs.DynCGroupMem;
483-
LocalKLE.DynCGroupMemFallback = FallbackBlockMem;
483+
LocalKLE.DynCGroupMemSize = BlockMemSize;
484+
LocalKLE.DynCGroupMemFbPtr = DynBlockMemFbPtr;
485+
LocalKLE.DynCGroupMemFb = DynBlockMemFb;
484486

485487
INFO(OMP_INFOTYPE_DATA_TRANSFER, GenericDevice.getDeviceId(),
486488
"Copying data from host to device, HstPtr=" DPxMOD ", TgtPtr=" DPxMOD
@@ -539,28 +541,43 @@ Error GenericKernelTy::launch(GenericDeviceTy &GenericDevice, void **ArgPtrs,
539541
if (StaticBlockMemSize > MaxBlockMemSize)
540542
return Plugin::error(ErrorCode::INVALID_ARGUMENT,
541543
"Static block memory size exceeds maximum");
542-
else if (!KernelArgs.Flags.AllowDynCGroupMemFallback &&
544+
else if (static_cast<DynCGroupMemFallbackType>(
545+
KernelArgs.Flags.DynCGroupMemFallback) ==
546+
DynCGroupMemFallbackType::Abort &&
543547
TotalBlockMemSize > MaxBlockMemSize)
544548
return Plugin::error(
545549
ErrorCode::INVALID_ARGUMENT,
546550
"Static and dynamic block memory size exceeds maximum");
547551

548-
void *FallbackBlockMem = nullptr;
552+
void *DynBlockMemFbPtr = nullptr;
553+
uint32_t DynBlockMemLaunchSize = DynBlockMemSize;
554+
555+
DynCGroupMemFallbackType DynBlockMemFb = DynCGroupMemFallbackType::None;
549556
if (DynBlockMemSize && (!GenericDevice.hasNativeBlockSharedMem() ||
550557
TotalBlockMemSize > MaxBlockMemSize)) {
551-
auto AllocOrErr = GenericDevice.dataAlloc(
552-
NumBlocks[0] * DynBlockMemSize,
553-
/*HostPtr=*/nullptr, TargetAllocTy::TARGET_ALLOC_DEVICE);
554-
if (!AllocOrErr)
555-
return AllocOrErr.takeError();
558+
// Launch without native dynamic block memory.
559+
DynBlockMemLaunchSize = 0;
560+
DynBlockMemFb = static_cast<DynCGroupMemFallbackType>(
561+
KernelArgs.Flags.DynCGroupMemFallback);
562+
if (DynBlockMemFb == DynCGroupMemFallbackType::DefaultMem) {
563+
// Get global memory as fallback.
564+
auto AllocOrErr = GenericDevice.dataAlloc(
565+
NumBlocks[0] * DynBlockMemSize,
566+
/*HostPtr=*/nullptr, TargetAllocTy::TARGET_ALLOC_DEVICE);
567+
if (!AllocOrErr)
568+
return AllocOrErr.takeError();
556569

557-
FallbackBlockMem = *AllocOrErr;
558-
AsyncInfoWrapper.freeAllocationAfterSynchronization(FallbackBlockMem);
559-
DynBlockMemSize = 0;
570+
DynBlockMemFbPtr = *AllocOrErr;
571+
AsyncInfoWrapper.freeAllocationAfterSynchronization(DynBlockMemFbPtr);
572+
} else {
573+
// Do not provide any memory as fallback.
574+
DynBlockMemSize = 0;
575+
}
560576
}
561577

562578
auto KernelLaunchEnvOrErr = getKernelLaunchEnvironment(
563-
GenericDevice, KernelArgs, FallbackBlockMem, AsyncInfoWrapper);
579+
GenericDevice, KernelArgs, DynBlockMemSize, DynBlockMemFb,
580+
DynBlockMemFbPtr, AsyncInfoWrapper);
564581
if (!KernelLaunchEnvOrErr)
565582
return KernelLaunchEnvOrErr.takeError();
566583

@@ -591,7 +608,7 @@ Error GenericKernelTy::launch(GenericDeviceTy &GenericDevice, void **ArgPtrs,
591608
printLaunchInfo(GenericDevice, KernelArgs, NumThreads, NumBlocks))
592609
return Err;
593610

594-
return launchImpl(GenericDevice, NumThreads, NumBlocks, DynBlockMemSize,
611+
return launchImpl(GenericDevice, NumThreads, NumBlocks, DynBlockMemLaunchSize,
595612
KernelArgs, LaunchParams, AsyncInfoWrapper);
596613
}
597614

offload/test/offloading/dyn_groupprivate_strict.cpp renamed to offload/test/offloading/dyn_groupprivate.cpp

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
// RUN: %libomptarget-compilexx-run-and-check-generic
1+
// RUN: %libomptarget-compilexx-generic -fopenmp-version=61
2+
// RUN: %libomptarget-run-generic | %fcheck-generic
3+
// RUN: %libomptarget-compileoptxx-generic -fopenmp-version=61
4+
// RUN: %libomptarget-run-generic | %fcheck-generic
25
// REQUIRES: gpu
36

47
#include <omp.h>
@@ -9,8 +12,9 @@
912
int main() {
1013
int Result[N], NumThreads;
1114

15+
// Verify the groupprivate buffer works as expected.
1216
#pragma omp target teams num_teams(1) thread_limit(N) \
13-
dyn_groupprivate(strict : N * sizeof(Result[0])) \
17+
dyn_groupprivate(fallback(abort) : N * sizeof(Result[0])) \
1418
map(from : Result, NumThreads)
1519
{
1620
int Buffer[N];
@@ -51,8 +55,8 @@ int main() {
5155
size_t MaxSize = omp_get_groupprivate_limit(0, omp_access_cgroup);
5256
size_t ExceededSize = MaxSize + 10;
5357

54-
// Verify that the fallback modifier works.
55-
#pragma omp target dyn_groupprivate(fallback : ExceededSize) \
58+
// Verify that the fallback(default_mem) modifier works.
59+
#pragma omp target dyn_groupprivate(fallback(default_mem) : ExceededSize) \
5660
map(tofrom : Failed)
5761
{
5862
int IsFallback;
@@ -66,13 +70,35 @@ int main() {
6670
++Failed;
6771
}
6872

69-
// Verify that the default modifier is fallback.
73+
// Verify that the fallback(null) modifier works.
74+
#pragma omp target dyn_groupprivate(fallback(null) : ExceededSize) \
75+
map(tofrom : Failed)
76+
{
77+
int IsFallback;
78+
if ((TmpPtr = omp_get_dyn_groupprivate_ptr(0, &IsFallback)))
79+
++Failed;
80+
if ((TmpSize = omp_get_dyn_groupprivate_size()))
81+
++Failed;
82+
if (!IsFallback)
83+
++Failed;
84+
}
85+
86+
// Verify that the default modifier is fallback(default_mem).
7087
#pragma omp target dyn_groupprivate(ExceededSize)
7188
{
89+
int IsFallback;
90+
if (!omp_get_dyn_groupprivate_ptr(0, &IsFallback))
91+
++Failed;
92+
if (!omp_get_dyn_groupprivate_size())
93+
++Failed;
94+
if (omp_get_dyn_groupprivate_size() != ExceededSize)
95+
++Failed;
96+
if (!IsFallback)
97+
++Failed;
7298
}
7399

74-
// Verify that the strict modifier works.
75-
#pragma omp target dyn_groupprivate(strict : N) map(tofrom : Failed)
100+
// Verify that the fallback(abort) modifier works.
101+
#pragma omp target dyn_groupprivate(fallback(abort) : N) map(tofrom : Failed)
76102
{
77103
int IsFallback;
78104
if (!omp_get_dyn_groupprivate_ptr(0, &IsFallback))
@@ -85,8 +111,9 @@ int main() {
85111
++Failed;
86112
}
87113

88-
// Verify that the fallback does not trigger when not needed.
89-
#pragma omp target dyn_groupprivate(fallback : N) map(tofrom : Failed)
114+
// Verify that the fallback(default_mem) does not trigger when not needed.
115+
#pragma omp target dyn_groupprivate(fallback(default_mem) : N) \
116+
map(tofrom : Failed)
90117
{
91118
int IsFallback;
92119
if (!omp_get_dyn_groupprivate_ptr(0, &IsFallback))
@@ -100,7 +127,7 @@ int main() {
100127
}
101128

102129
// Verify that the clause works when passing a zero size.
103-
#pragma omp target dyn_groupprivate(strict : 0) map(tofrom : Failed)
130+
#pragma omp target dyn_groupprivate(fallback(abort) : 0) map(tofrom : Failed)
104131
{
105132
int IsFallback;
106133
if (omp_get_dyn_groupprivate_ptr(0, &IsFallback))
@@ -112,7 +139,8 @@ int main() {
112139
}
113140

114141
// Verify that the clause works when passing a zero size.
115-
#pragma omp target dyn_groupprivate(fallback : 0) map(tofrom : Failed)
142+
#pragma omp target dyn_groupprivate(fallback(default_mem) : 0) \
143+
map(tofrom : Failed)
116144
{
117145
int IsFallback;
118146
if (omp_get_dyn_groupprivate_ptr(0, &IsFallback))

openmp/device/src/State.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -162,27 +162,27 @@ struct DynCGroupMemTy {
162162
void init(KernelLaunchEnvironmentTy *KLE, void *NativeDynCGroup) {
163163
Size = 0;
164164
Ptr = nullptr;
165-
IsFallback = false;
165+
Fallback = DynCGroupMemFallbackType::None;
166166
if (!KLE)
167167
return;
168168

169169
Size = KLE->DynCGroupMemSize;
170-
if (void *Fallback = KLE->DynCGroupMemFallback) {
171-
Ptr = static_cast<char *>(Fallback) + Size * omp_get_team_num();
172-
IsFallback = true;
173-
} else {
170+
Fallback = KLE->DynCGroupMemFb;
171+
if (Fallback == DynCGroupMemFallbackType::None)
174172
Ptr = static_cast<char *>(NativeDynCGroup);
175-
}
173+
else if (Fallback == DynCGroupMemFallbackType::DefaultMem)
174+
Ptr = static_cast<char *>(KLE->DynCGroupMemFbPtr) +
175+
Size * omp_get_team_num();
176176
}
177177

178178
char *getPtr(size_t Offset) const { return Ptr + Offset; }
179-
bool isFallback() const { return IsFallback; }
179+
bool isFallback() const { return Fallback != DynCGroupMemFallbackType::None; }
180180
size_t getSize() const { return Size; }
181181

182182
private:
183183
char *Ptr;
184184
size_t Size;
185-
bool IsFallback;
185+
DynCGroupMemFallbackType Fallback;
186186
};
187187

188188
[[clang::loader_uninitialized]] static Local<DynCGroupMemTy> DynCGroupMem;

0 commit comments

Comments
 (0)