@@ -224,19 +224,10 @@ class MapInfoFinalizationPass
224224 }
225225
226226 // / Adjusts the descriptor's map type. The main alteration that is done
227- // / currently is transforming the map type to `OMP_MAP_TO` where possible,
228- // / plus adding OMP_MAP_ALWAYS flag. Descriptors will always be copied,
229- // / even if the object was listed on the `has_device_addr` clause.
230- // / This is because the descriptor can be rematerialized by the compiler,
231- // / and so the address of the descriptor for a given object at one place in
232- // / the code may differ from that address in another place. The contents
233- // / of the descriptor (the base address in particular) will remain unchanged
234- // / though. Non-descriptor objects listed on the `has_device_addr` clause
235- // / can be passed to the kernel by just passing their address without any
236- // / remapping.
237- // / The adding of the OMP_MAP_TO flag is done because we will always need to
238- // / map the descriptor to device, especially without device address clauses,
239- // / as without the appropriate descriptor
227+ // / currently is transforming the map type to `OMP_MAP_TO` where possible.
228+ // / This is because we will always need to map the descriptor to device
229+ // / (or at the very least it seems to be the case currently with the
230+ // / current lowered kernel IR), as without the appropriate descriptor
240231 // / information on the device there is a risk of the kernel IR
241232 // / requesting for various data that will not have been copied to
242233 // / perform things like indexing. This can cause segfaults and
@@ -256,13 +247,15 @@ class MapInfoFinalizationPass
256247 mlir::omp::TargetUpdateOp>(target))
257248 return mapTypeFlag;
258249
250+ llvm::omp::OpenMPOffloadMappingFlags Always =
251+ llvm::omp::OpenMPOffloadMappingFlags (mapTypeFlag) &
252+ llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
259253 llvm::omp::OpenMPOffloadMappingFlags Implicit =
260254 llvm::omp::OpenMPOffloadMappingFlags (mapTypeFlag) &
261255 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
262256
263257 return llvm::to_underlying (
264- llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS |
265- llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO | Implicit);
258+ llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO | Always | Implicit);
266259 }
267260
268261 // / Check if the mapOp is present in the HasDeviceAddr clause on
@@ -291,6 +284,7 @@ class MapInfoFinalizationPass
291284 mlir::ArrayAttr newMembersAttr;
292285 mlir::SmallVector<mlir::Value> newMembers;
293286 llvm::SmallVector<llvm::SmallVector<int64_t >> memberIndices;
287+ bool IsHasDeviceAddr = isHasDeviceAddr (op, target);
294288
295289 if (!mapMemberUsers.empty () || !op.getMembers ().empty ())
296290 getMemberIndicesAsVectors (
@@ -333,7 +327,7 @@ class MapInfoFinalizationPass
333327 mapUser.parent .getMembersMutable ().assign (newMemberOps);
334328 mapUser.parent .setMembersIndexAttr (
335329 builder.create2DI64ArrayAttr (memberIndices));
336- } else if (!isHasDeviceAddr (op, target) ) {
330+ } else if (!IsHasDeviceAddr ) {
337331 auto baseAddr = genBaseAddrMap (descriptor, op.getBounds (),
338332 op.getMapType ().value_or (0 ), builder);
339333 newMembers.push_back (baseAddr);
@@ -349,6 +343,18 @@ class MapInfoFinalizationPass
349343 }
350344 }
351345
346+ // Descriptors for objects listed on the `has_device_addr` will always
347+ // be copied. This is because the descriptor can be rematerialized by the
348+ // compiler, and so the address of the descriptor for a given object at
349+ // one place in the code may differ from that address in another place.
350+ // The contents of the descriptor (the base address in particular) will
351+ // remain unchanged though.
352+ uint64_t MapType = op.getMapType ().value_or (0 );
353+ if (IsHasDeviceAddr) {
354+ MapType |= llvm::to_underlying (
355+ llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS);
356+ }
357+
352358 mlir::omp::MapInfoOp newDescParentMapOp =
353359 builder.create <mlir::omp::MapInfoOp>(
354360 op->getLoc (), op.getResult ().getType (), descriptor,
@@ -357,7 +363,7 @@ class MapInfoFinalizationPass
357363 /* bounds=*/ mlir::SmallVector<mlir::Value>{},
358364 builder.getIntegerAttr (
359365 builder.getIntegerType (64 , false ),
360- getDescriptorMapType (op. getMapType (). value_or ( 0 ) , target)),
366+ getDescriptorMapType (MapType , target)),
361367 /* mapperId*/ mlir::FlatSymbolRefAttr (), op.getMapCaptureTypeAttr (),
362368 op.getNameAttr (),
363369 /* partial_map=*/ builder.getBoolAttr (false ));
0 commit comments