Skip to content

Commit e5d22be

Browse files
committed
Fix use_device_ptr codegen and test.
1 parent b84885a commit e5d22be

File tree

2 files changed

+82
-36
lines changed

2 files changed

+82
-36
lines changed

clang/lib/CodeGen/CGOpenMPRuntime.cpp

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9039,6 +9039,8 @@ class MappableExprsHandler {
90399039

90409040
// Process each group in order of their attach-pointers increasing
90419041
// complexity.
9042+
std::optional<size_t> MemberOfValueForFirstCombinedEntry = std::nullopt;
9043+
bool IsFirstGroup = true;
90429044
for (const auto &Entry : AttachPtrGroups) {
90439045
const SmallVector<MapInfo, 8> &GroupLists = Entry.second;
90449046
if (GroupLists.empty())
@@ -9110,11 +9112,16 @@ class MappableExprsHandler {
91109112
MapCombinedInfoTy AttachCombinedInfo;
91119113
if (PartialStruct.Base.isValid()) {
91129114
CurInfo.append(PartialStruct.PreliminaryMapData);
9113-
emitCombinedEntry(
9115+
std::optional<size_t> CombinedEntryIndex = emitCombinedEntry(
91149116
CurInfo, AttachCombinedInfo, GroupCurInfo.Types, PartialStruct,
91159117
/*IsMapThis*/ !VD, OMPBuilder, VD,
91169118
/*OffsetForMemberOfFlag=*/CombinedInfo.BasePointers.size(),
91179119
/*NotTargetParam=*/true);
9120+
// Track the final index of the first group's combined entry so that
9121+
// deferred entries can reference it.
9122+
if (IsFirstGroup && CombinedEntryIndex.has_value())
9123+
MemberOfValueForFirstCombinedEntry =
9124+
CombinedInfo.BasePointers.size() + *CombinedEntryIndex;
91189125
}
91199126

91209127
// Append this group's results to the overall CurInfo in the correct
@@ -9123,12 +9130,15 @@ class MappableExprsHandler {
91239130
CurInfo.append(GroupStructBaseCurInfo);
91249131
CurInfo.append(GroupCurInfo);
91259132
CurInfo.append(AttachCombinedInfo);
9133+
9134+
IsFirstGroup = false;
91269135
}
91279136

91289137
// Append any pending zero-length pointers which are struct members and
91299138
// used with use_device_ptr or use_device_addr.
91309139
auto CI = DeferredInfo.find(Data.first);
91319140
if (CI != DeferredInfo.end()) {
9141+
size_t DeferredStartIdx = CurInfo.Types.size();
91329142
for (const DeferredDevicePtrEntryTy &L : CI->second) {
91339143
llvm::Value *BasePtr;
91349144
llvm::Value *Ptr;
@@ -9166,6 +9176,21 @@ class MappableExprsHandler {
91669176
llvm::Constant::getNullValue(this->CGF.Int64Ty));
91679177
CurInfo.Mappers.push_back(nullptr);
91689178
}
9179+
9180+
// Correct the MEMBER_OF flags for the deferred entries we just added.
9181+
if (MemberOfValueForFirstCombinedEntry.has_value() &&
9182+
DeferredStartIdx < CurInfo.Types.size()) {
9183+
// Use the tracked combined entry index from the first group.
9184+
// Note that this assumes that the entries for use_device_ptr/addr
9185+
// should belong to the CombinedEntry emitted when handling the first
9186+
// "group". E.g. even if we have `map(this->sp->a, this->sp->b)`, the
9187+
// CombinedEntry created for those, with `this->sp` as the attach-ptr,
9188+
// would not be the first attach-entry.
9189+
OpenMPOffloadMappingFlags MemberOfFlag =
9190+
OMPBuilder.getMemberOfFlag(*MemberOfValueForFirstCombinedEntry);
9191+
for (size_t I = DeferredStartIdx; I < CurInfo.Types.size(); ++I)
9192+
OMPBuilder.setCorrectMemberOfFlag(CurInfo.Types[I], MemberOfFlag);
9193+
}
91699194
}
91709195

91719196
// We need to append the results of this capture to what we already have.
@@ -9259,13 +9284,13 @@ class MappableExprsHandler {
92599284
/// individual struct members.
92609285
/// AttachCombinedInfo will be populated with ATTACH entries if
92619286
/// \p PartialStruct contains attach base-pointer information.
9262-
void emitCombinedEntry(MapCombinedInfoTy &CombinedInfo,
9263-
MapCombinedInfoTy &AttachCombinedInfo,
9264-
MapFlagsArrayTy &CurTypes,
9265-
const StructRangeInfoTy &PartialStruct, bool IsMapThis,
9266-
llvm::OpenMPIRBuilder &OMPBuilder, const ValueDecl *VD,
9267-
unsigned OffsetForMemberOfFlag,
9268-
bool NotTargetParams) const {
9287+
/// \returns The index of the combined entry if one was added, std::nullopt
9288+
/// otherwise.
9289+
std::optional<size_t> emitCombinedEntry(
9290+
MapCombinedInfoTy &CombinedInfo, MapCombinedInfoTy &AttachCombinedInfo,
9291+
MapFlagsArrayTy &CurTypes, const StructRangeInfoTy &PartialStruct,
9292+
bool IsMapThis, llvm::OpenMPIRBuilder &OMPBuilder, const ValueDecl *VD,
9293+
unsigned OffsetForMemberOfFlag, bool NotTargetParams) const {
92699294
if (CurTypes.size() == 1 &&
92709295
((CurTypes.back() & OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF) !=
92719296
OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF) &&
@@ -9278,14 +9303,16 @@ class MappableExprsHandler {
92789303
PartialStruct.AttachPteeAddr,
92799304
PartialStruct.AttachPtrDecl,
92809305
PartialStruct.AttachMapExpr);
9281-
return;
9306+
return std::nullopt;
92829307
}
92839308
Address LBAddr = PartialStruct.LowestElem.second;
92849309
Address HBAddr = PartialStruct.HighestElem.second;
92859310
if (PartialStruct.HasCompleteRecord) {
92869311
LBAddr = PartialStruct.LB;
92879312
HBAddr = PartialStruct.LB;
92889313
}
9314+
// Capture the index where the combined entry will be inserted.
9315+
size_t CombinedEntryIndex = CombinedInfo.BasePointers.size();
92899316
CombinedInfo.Exprs.push_back(VD);
92909317
// Base is the base of the struct
92919318
CombinedInfo.BasePointers.push_back(PartialStruct.Base.emitRawPointer(CGF));
@@ -9379,6 +9406,8 @@ class MappableExprsHandler {
93799406
addAttachEntry(CGF, AttachCombinedInfo, PartialStruct.AttachPtrAddr,
93809407
LBAddr, PartialStruct.AttachPtrDecl,
93819408
PartialStruct.AttachMapExpr);
9409+
9410+
return CombinedEntryIndex;
93829411
}
93839412

93849413
/// Generate all the base pointers, section pointers, sizes, map types, and
@@ -9636,7 +9665,7 @@ class MappableExprsHandler {
96369665
MapCombinedInfoTy AttachCombinedInfo;
96379666
if (PartialStruct.Base.isValid()) {
96389667
CurCaptureVarInfo.append(PartialStruct.PreliminaryMapData);
9639-
emitCombinedEntry(
9668+
(void)emitCombinedEntry(
96409669
CurCaptureVarInfo, AttachCombinedInfo,
96419670
CurInfoForComponentLists.Types, PartialStruct,
96429671
Cap->capturesThis(), OMPBuilder, nullptr, OffsetForMemberOfFlag,

clang/test/OpenMP/target_data_use_device_ptr_codegen.cpp

Lines changed: 43 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ void foo(float *&lr, T *&tr) {
4141
float *l;
4242
T *t;
4343

44-
// &g[0], &g[/*lb=*/0], 10 * sizeof(g[0]), TO | FROM | RETURN_PARAM
45-
// &g, &g[/*lb=*/0], sizeof(g), ATTACH
44+
// &g[0], &g[0], 10 * sizeof(g[0]), TO | FROM | RETURN_PARAM
45+
// &g, &g[0], sizeof(void*), ATTACH
4646
//
4747
// CK1: [[T:%.+]] = load ptr, ptr [[DECL:@g]],
4848
// CK1: [[BP:%.+]] = getelementptr inbounds [2 x ptr], ptr %{{.+}}, i32 0, i32 0
@@ -62,8 +62,8 @@ void foo(float *&lr, T *&tr) {
6262
// CK1: getelementptr inbounds nuw double, ptr [[TTT]], i32 1
6363
++g;
6464

65-
// &l[0], &l[/*lb=*/0], 10 * sizeof(l[0]), TO | FROM | RETURN_PARAM
66-
// &l, &l[/*lb=*/0], sizeof(l), ATTACH
65+
// &l[0], &l[0], 10 * sizeof(l[0]), TO | FROM | RETURN_PARAM
66+
// &l, &l[0], sizeof(void*), ATTACH
6767
//
6868
// CK1: [[T1:%.+]] = load ptr, ptr [[DECL:%.+]],
6969
// CK1: [[BP:%.+]] = getelementptr inbounds [2 x ptr], ptr %{{.+}}, i32 0, i32 0
@@ -151,8 +151,8 @@ void foo(float *&lr, T *&tr) {
151151
// CK1: getelementptr inbounds nuw float, ptr [[TTT]], i32 1
152152
++l;
153153

154-
// &(ref_ptee(lr)[0]), &(ref_ptee(lr)[/*lb=*/0]), 10 * sizeof(lr[0]), TO | FROM | RETURN_PARAM
155-
// &(ref_ptee(lr)), &(ref_ptee(lr)[/*lb=*/0]), sizeof(ref_ptee(lr)), ATTACH
154+
// &(ptee(lr)[0]), &(ptee(lr)[0]), 10 * sizeof(lr[0]), TO | FROM | RETURN_PARAM
155+
// &(ptee(lr)), &(ptee(lr)[0]), sizeof(void*), ATTACH
156156
//
157157
// CK1: [[T2:%.+]] = load ptr, ptr [[DECL:%.+]],
158158
// CK1: [[T1:%.+]] = load ptr, ptr [[T2]],
@@ -176,8 +176,8 @@ void foo(float *&lr, T *&tr) {
176176
// CK1: getelementptr inbounds nuw float, ptr [[TTTT]], i32 1
177177
++lr;
178178

179-
// &t[0], &t[/*lb=*/0], 10 * sizeof(t[0]), TO | FROM | RETURN_PARAM
180-
// &t, &t[/*lb=*/0], sizeof(t), ATTACH
179+
// &t[0], &t[0], 10 * sizeof(t[0]), TO | FROM | RETURN_PARAM
180+
// &t, &t[0], sizeof(void*), ATTACH
181181
//
182182
// CK1: [[T1:%.+]] = load ptr, ptr [[DECL:%.+]],
183183
// CK1: [[BP:%.+]] = getelementptr inbounds [2 x ptr], ptr %{{.+}}, i32 0, i32 0
@@ -197,8 +197,8 @@ void foo(float *&lr, T *&tr) {
197197
// CK1: getelementptr inbounds nuw i32, ptr [[TTT]], i32 1
198198
++t;
199199

200-
// &(ref_ptee(tr)[0]), &(ref_ptee(tr)[/*lb=*/0]), 10 * sizeof(tr[0]), TO | FROM | RETURN_PARAM
201-
// &(ref_ptee(tr)), &(ref_ptee(tr)[/*lb=*/0]), sizeof(ref_ptee(tr)), ATTACH
200+
// &(ptee(tr)[0]), &(ptee(tr)[0]), 10 * sizeof(tr[0]), TO | FROM | RETURN_PARAM
201+
// &(ptee(tr)), &(ptee(tr)[0]), sizeof(void*), ATTACH
202202
//
203203
// CK1: [[T2:%.+]] = load ptr, ptr [[DECL:%.+]],
204204
// CK1: [[T1:%.+]] = load ptr, ptr [[T2]],
@@ -222,11 +222,11 @@ void foo(float *&lr, T *&tr) {
222222
// CK1: getelementptr inbounds nuw i32, ptr [[TTTT]], i32 1
223223
++tr;
224224

225-
// &l[0], &l[/*lb=*/0], 10 * sizeof(l[0]), TO | FROM [| RETURN_PARAM]
226-
// &l, &l[/*lb=*/0], sizeof(l), ATTACH
227-
// &t[0], &t[/*lb=*/0], 10 * sizeof(t[0]), TO | FROM [| RETURN_PARAM]
228-
// &t, &t[/*lb=*/0], sizeof(t), ATTACH
229-
225+
// &l[0], &l[0], 10 * sizeof(l[0]), TO | FROM [| RETURN_PARAM]
226+
// &l, &l[0], sizeof(void*), ATTACH
227+
// &t[0], &t[0], 10 * sizeof(t[0]), TO | FROM [| RETURN_PARAM]
228+
// &t, &t[0], sizeof(void*), ATTACH
229+
//
230230
// CK1: [[T1:%.+]] = load ptr, ptr [[DECL:%.+]],
231231
// CK1: [[BP:%.+]] = getelementptr inbounds [4 x ptr], ptr %{{.+}}, i32 0, i32 0
232232
// CK1: store ptr [[T1]], ptr [[BP]],
@@ -286,9 +286,10 @@ void foo(float *&lr, T *&tr) {
286286
// CK1: getelementptr inbounds nuw i32, ptr [[TTT]], i32 1
287287
++l; ++t;
288288

289-
// &l[0], &l[/*lb=*/0], 10 * sizeof(l[0]), TO | FROM [| RETURN_PARAM]
290-
// &l, &l[/*lb=*/0], sizeof(l), ATTACH
291-
// &t[0], &t[/*lb=*/0], 0, RETURN_PARAM
289+
// &l[0], &l[0], 10 * sizeof(l[0]), TO | FROM [| RETURN_PARAM]
290+
// &l, &l[0], sizeof(void*), ATTACH
291+
// &t[0], &t[0], 0, RETURN_PARAM
292+
//
292293
// CK1: [[T1:%.+]] = load ptr, ptr [[DECL:%.+]],
293294
// CK1: [[BP:%.+]] = getelementptr inbounds [3 x ptr], ptr %{{.+}}, i32 0, i32 2
294295
// CK1: store ptr [[T1]], ptr [[BP]],
@@ -354,10 +355,10 @@ void bar(float *&a, int *&b) {
354355
#ifdef CK2
355356

356357
// CK2: [[ST:%.+]] = type { ptr, ptr }
357-
// CK2: [[MTYPE00:@.+]] = {{.*}}constant [2 x i64] [i64 0, i64 281474976710739]
358-
// CK2: [[MTYPE01:@.+]] = {{.*}}constant [2 x i64] [i64 0, i64 281474976710739]
359-
// CK2: [[MTYPE02:@.+]] = {{.*}}constant [4 x i64] [i64 3, i64 16384, i64 0, i64 844424930132048]
360-
// CK2: [[MTYPE03:@.+]] = {{.*}}constant [3 x i64] [i64 0, i64 281474976710739, i64 281474976710736]
358+
// CK2: [[MTYPE00:@.+]] = {{.*}}constant [2 x i64] [i64 [[#0x43]], i64 [[#0x4000]]]
359+
// CK2: [[MTYPE01:@.+]] = {{.*}}constant [2 x i64] [i64 [[#0x43]], i64 [[#0x4000]]]
360+
// CK2: [[MTYPE02:@.+]] = {{.*}}constant [4 x i64] [i64 3, i64 [[#0x4000]], i64 0, i64 [[#0x3000000000050]]]
361+
// CK2: [[MTYPE03:@.+]] = {{.*}}constant [4 x i64] [i64 0, i64 [[#0x43]], i64 [[#0x4000]], i64 [[#0x1000000000050]]]
361362

362363
template <typename T>
363364
struct ST {
@@ -369,7 +370,10 @@ struct ST {
369370
void foo(double *&arg) {
370371
int *la = 0;
371372

372-
// CK2: [[BP:%.+]] = getelementptr inbounds [2 x ptr], ptr %{{.+}}, i32 0, i32 1
373+
// &a[0], &a[0], 10 * sizeof(a[0]), TO | FROM | RETURN_PARAM
374+
// &a, &a[0], sizeof(void*), ATTACH
375+
//
376+
// CK2: [[BP:%.+]] = getelementptr inbounds [2 x ptr], ptr %{{.+}}, i32 0, i32 0
373377
// CK2: store ptr [[RVAL:%.+]], ptr [[BP]],
374378
// CK2: call void @__tgt_target_data_begin{{.+}}[[MTYPE00]]
375379
// CK2: [[VAL:%.+]] = load ptr, ptr [[BP]],
@@ -388,7 +392,10 @@ struct ST {
388392
// CK2: getelementptr inbounds nuw double, ptr [[TTT]], i32 1
389393
a++;
390394

391-
// CK2: [[BP:%.+]] = getelementptr inbounds [2 x ptr], ptr %{{.+}}, i32 0, i32 1
395+
// &ptee(b)[0], &ptee(b)[0], 10 * sizeof(ptee(b)[0]), TO | FROM | RETURN_PARAM
396+
// &ptee(b), &ptee(b)[0], sizeof(void*), ATTACH
397+
//
398+
// CK2: [[BP:%.+]] = getelementptr inbounds [2 x ptr], ptr %{{.+}}, i32 0, i32 0
392399
// CK2: store ptr [[RVAL:%.+]], ptr [[BP]],
393400
// CK2: call void @__tgt_target_data_begin{{.+}}[[MTYPE01]]
394401
// CK2: [[VAL:%.+]] = load ptr, ptr [[BP]],
@@ -408,6 +415,11 @@ struct ST {
408415
// CK2: getelementptr inbounds nuw double, ptr [[TTTT]], i32 1
409416
b++;
410417

418+
// &la[0], &la[0], 10 * sizeof(la[0]), TO | FROM
419+
// &la, &la[0], sizeof(void*), ATTACH
420+
// &this[0], &this[0].a, sizeof(this[0].a), ALLOC
421+
// &this[0], &this[0].a[0], 0, MEMBER_OF_3 | RETURN_PARAM
422+
//
411423
// CK2: [[BP:%.+]] = getelementptr inbounds [4 x ptr], ptr %{{.+}}, i32 0, i32 3
412424
// CK2: store ptr [[RVAL:%.+]], ptr [[BP]],
413425
// CK2: call void @__tgt_target_data_begin{{.+}}[[MTYPE02]]
@@ -429,9 +441,14 @@ struct ST {
429441
a++;
430442
la++;
431443

432-
// CK2: [[BP1:%.+]] = getelementptr inbounds [3 x ptr], ptr %{{.+}}, i32 0, i32 1
444+
// &this[0], &this[0].a, sizeof(this[0].a), ALLOC
445+
// &ptee(b)[0], &ptee(b)[0], 10 * sizeof(ptee(b)[0]), TO | FROM | RETURN_PARAM
446+
// &ptee(b), &ptee(b)[0], sizeof(void*), ATTACH
447+
// &this[0], &this[0].a[0], 0, MEMBER_OF_1 | RETURN_PARAM
448+
//
449+
// CK2: [[BP1:%.+]] = getelementptr inbounds [4 x ptr], ptr %{{.+}}, i32 0, i32 1
433450
// CK2: store ptr [[RVAL1:%.+]], ptr [[BP1]],
434-
// CK2: [[BP2:%.+]] = getelementptr inbounds [3 x ptr], ptr %{{.+}}, i32 0, i32 2
451+
// CK2: [[BP2:%.+]] = getelementptr inbounds [4 x ptr], ptr %{{.+}}, i32 0, i32 3
435452
// CK2: store ptr [[RVAL2:%.+]], ptr [[BP2]],
436453
// CK2: call void @__tgt_target_data_begin{{.+}}[[MTYPE03]]
437454
// CK2: [[VAL1:%.+]] = load ptr, ptr [[BP1]],

0 commit comments

Comments
 (0)