4040#include " mlir/IR/SymbolTable.h"
4141#include " mlir/Pass/Pass.h"
4242#include " mlir/Support/LLVM.h"
43+ #include " llvm/ADT/BitmaskEnum.h"
4344#include " llvm/ADT/SmallPtrSet.h"
4445#include " llvm/ADT/StringSet.h"
4546#include " llvm/Frontend/OpenMP/OMPConstants.h"
@@ -398,24 +399,18 @@ class MapInfoFinalizationPass
398399 baseAddrIndex);
399400 }
400401
401- // / Adjusts the descriptor's map type. The main alteration that is done
402- // / currently is transforming the map type to `OMP_MAP_TO` where possible.
403- // / This is because we will always need to map the descriptor to device
404- // / (or at the very least it seems to be the case currently with the
405- // / current lowered kernel IR), as without the appropriate descriptor
406- // / information on the device there is a risk of the kernel IR
407- // / requesting for various data that will not have been copied to
408- // / perform things like indexing. This can cause segfaults and
409- // / memory access errors. However, we do not need this data mapped
410- // / back to the host from the device, as per the OpenMP spec we cannot alter
411- // / the data via resizing or deletion on the device. Discarding any
412- // / descriptor alterations via no map back is reasonable (and required
413- // / for certain segments of descriptor data like the type descriptor that are
414- // / global constants). This alteration is only inapplicable to `target exit`
415- // / and `target update` currently, and that's due to `target exit` not
416- // / allowing `to` mappings, and `target update` not allowing both `to` and
417- // / `from` simultaneously. We currently try to maintain the `implicit` flag
418- // / where necessary, although it does not seem strictly required.
402+ // / Adjust the descriptor's map type such that we ensure the descriptor
403+ // / itself is present on device when needed, without changing the user's
404+ // / requested data mapping semantics for the underlying data.
405+ // /
406+ // / We conservatively transform descriptor mappings to `OMP_MAP_TO` (and
407+ // / preserve `IMPLICIT`/`ALWAYS` when present) for structured regions. The
408+ // / descriptor should live on device for indexing, bounds, etc., but we do
409+ // / not require, nor want, additional mapping semantics like `CLOSE` for the
410+ // / descriptor entry itself. `CLOSE` (and other user-provided flags) should
411+ // / apply to the base data entry that actually carries the pointee, which is
412+ // / generated separately as a member map. For `target exit`/`target update`
413+ // / we keep the original map type unchanged.
419414 unsigned long getDescriptorMapType (unsigned long mapTypeFlag,
420415 mlir::Operation *target) {
421416 using mapFlags = llvm::omp::OpenMPOffloadMappingFlags;
@@ -425,8 +420,24 @@ class MapInfoFinalizationPass
425420
426421 mapFlags flags = mapFlags::OMP_MAP_TO |
427422 (mapFlags (mapTypeFlag) &
428- (mapFlags::OMP_MAP_IMPLICIT | mapFlags::OMP_MAP_CLOSE |
429- mapFlags::OMP_MAP_ALWAYS));
423+ (mapFlags::OMP_MAP_IMPLICIT | mapFlags::OMP_MAP_ALWAYS));
424+
425+ // For unified_shared_memory, we additionally add `CLOSE` on the descriptor
426+ // to ensure device-local placement where required by tests relying on USM +
427+ // close semantics.
428+ if (target) {
429+ if (auto mod = target->getParentOfType <mlir::ModuleOp>()) {
430+ if (mlir::Attribute reqAttr = mod->getAttr (" omp.requires" )) {
431+ if (auto req =
432+ mlir::dyn_cast<mlir::omp::ClauseRequiresAttr>(reqAttr)) {
433+ if (mlir::omp::bitEnumContainsAll (
434+ req.getValue (),
435+ mlir::omp::ClauseRequires::unified_shared_memory))
436+ flags |= mapFlags::OMP_MAP_CLOSE;
437+ }
438+ }
439+ }
440+ }
430441 return llvm::to_underlying (flags);
431442 }
432443
@@ -518,6 +529,75 @@ class MapInfoFinalizationPass
518529 return newMapInfoOp;
519530 }
520531
532+ // Expand mappings of type(C_PTR) to map their `__address` field explicitly
533+ // as a single pointer-sized member (USM-gated at callsite). This helps in
534+ // USM scenarios to ensure the pointer-sized mapping is used.
535+ mlir::omp::MapInfoOp genCptrMemberMap (mlir::omp::MapInfoOp op,
536+ fir::FirOpBuilder &builder) {
537+ if (!op.getMembers ().empty ())
538+ return op;
539+
540+ mlir::Type varTy = fir::unwrapRefType (op.getVarPtr ().getType ());
541+ if (!mlir::isa<fir::RecordType>(varTy))
542+ return op;
543+ auto recTy = mlir::cast<fir::RecordType>(varTy);
544+ // If not a builtin C_PTR record, skip.
545+ if (!recTy.getName ().ends_with (" __builtin_c_ptr" ))
546+ return op;
547+
548+ // Find the index of the c_ptr address component named "__address".
549+ int32_t fieldIdx = recTy.getFieldIndex (" __address" );
550+ if (fieldIdx < 0 )
551+ return op;
552+
553+ mlir::Location loc = op.getVarPtr ().getLoc ();
554+ mlir::Type memTy = recTy.getType (fieldIdx);
555+ fir::IntOrValue idxConst =
556+ mlir::IntegerAttr::get (builder.getI32Type (), fieldIdx);
557+ mlir::Value coord = fir::CoordinateOp::create (
558+ builder, loc, builder.getRefType (memTy), op.getVarPtr (),
559+ llvm::SmallVector<fir::IntOrValue, 1 >{idxConst});
560+
561+ // Child for the `__address` member.
562+ llvm::SmallVector<llvm::SmallVector<int64_t >> memberIdx = {{0 }};
563+ mlir::ArrayAttr newMembersAttr = builder.create2DI64ArrayAttr (memberIdx);
564+ // Force CLOSE in USM paths so the pointer gets device-local placement
565+ // when required by tests relying on USM + close semantics.
566+ uint64_t mapTypeVal =
567+ op.getMapType () |
568+ llvm::to_underlying (
569+ llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE);
570+ mlir::IntegerAttr mapTypeAttr = builder.getIntegerAttr (
571+ builder.getIntegerType (64 , /* isSigned=*/ false ), mapTypeVal);
572+
573+ mlir::omp::MapInfoOp memberMap = mlir::omp::MapInfoOp::create (
574+ builder, loc, coord.getType (), coord,
575+ mlir::TypeAttr::get (fir::unwrapRefType (coord.getType ())), mapTypeAttr,
576+ builder.getAttr <mlir::omp::VariableCaptureKindAttr>(
577+ mlir::omp::VariableCaptureKind::ByRef),
578+ /* varPtrPtr=*/ mlir::Value{},
579+ /* members=*/ llvm::SmallVector<mlir::Value>{},
580+ /* member_index=*/ mlir::ArrayAttr{},
581+ /* bounds=*/ op.getBounds (),
582+ /* mapperId=*/ mlir::FlatSymbolRefAttr (),
583+ /* name=*/ op.getNameAttr (),
584+ /* partial_map=*/ builder.getBoolAttr (false ));
585+
586+ // Rebuild the parent as a container with the `__address` member.
587+ mlir::omp::MapInfoOp newParent = mlir::omp::MapInfoOp::create (
588+ builder, op.getLoc (), op.getResult ().getType (), op.getVarPtr (),
589+ op.getVarTypeAttr (), mapTypeAttr, op.getMapCaptureTypeAttr (),
590+ /* varPtrPtr=*/ mlir::Value{},
591+ /* members=*/ llvm::SmallVector<mlir::Value>{memberMap},
592+ /* member_index=*/ newMembersAttr,
593+ /* bounds=*/ llvm::SmallVector<mlir::Value>{},
594+ /* mapperId=*/ mlir::FlatSymbolRefAttr (), op.getNameAttr (),
595+ /* partial_map=*/ builder.getBoolAttr (false ));
596+ op.replaceAllUsesWith (newParent.getResult ());
597+ op->erase ();
598+ return newParent;
599+ }
600+
521601 mlir::omp::MapInfoOp genDescriptorMemberMaps (mlir::omp::MapInfoOp op,
522602 fir::FirOpBuilder &builder,
523603 mlir::Operation *target) {
@@ -727,11 +807,6 @@ class MapInfoFinalizationPass
727807 argIface.getUseDeviceAddrBlockArgsStart () +
728808 argIface.numUseDeviceAddrBlockArgs ());
729809
730- mlir::MutableOperandRange useDevPtrMutableOpRange =
731- targetDataOp.getUseDevicePtrVarsMutable ();
732- addOperands (useDevPtrMutableOpRange, target,
733- argIface.getUseDevicePtrBlockArgsStart () +
734- argIface.numUseDevicePtrBlockArgs ());
735810 } else if (auto targetOp = llvm::dyn_cast<mlir::omp::TargetOp>(target)) {
736811 mlir::MutableOperandRange hasDevAddrMutableOpRange =
737812 targetOp.getHasDeviceAddrVarsMutable ();
@@ -1169,6 +1244,30 @@ class MapInfoFinalizationPass
11691244 genBoxcharMemberMap (op, builder);
11701245 });
11711246
1247+ // Expand type(C_PTR) only when unified_shared_memory is required,
1248+ // to ensure device-visible pointer size/behavior in USM scenarios
1249+ // without changing default expectations elsewhere.
1250+ func->walk ([&](mlir::omp::MapInfoOp op) {
1251+ // Check module requires USM; otherwise, leave mappings untouched.
1252+ auto mod = func->getParentOfType <mlir::ModuleOp>();
1253+ bool hasUSM = false ;
1254+ if (mod) {
1255+ if (mlir::Attribute reqAttr = mod->getAttr (" omp.requires" )) {
1256+ if (auto req =
1257+ mlir::dyn_cast<mlir::omp::ClauseRequiresAttr>(reqAttr)) {
1258+ hasUSM = mlir::omp::bitEnumContainsAll (
1259+ req.getValue (),
1260+ mlir::omp::ClauseRequires::unified_shared_memory);
1261+ }
1262+ }
1263+ }
1264+ if (!hasUSM)
1265+ return ;
1266+
1267+ builder.setInsertionPoint (op);
1268+ genCptrMemberMap (op, builder);
1269+ });
1270+
11721271 func->walk ([&](mlir::omp::MapInfoOp op) {
11731272 // TODO: Currently only supports a single user for the MapInfoOp. This
11741273 // is fine for the moment, as the Fortran frontend will generate a
0 commit comments