Skip to content

Commit 9dae79d

Browse files
committed
[Flang][OpenMP] Fix USM close semantics and use_device_ptr
- Add CLOSE map flag when USM is required.
- use_device_ptr: prevent implicitly expanding member operands.
- Fixes test offload/test/offloading/fortran/usm_map_close.f90.
1 parent c0fb07c commit 9dae79d

File tree

1 file changed

+124
-25
lines changed

1 file changed

+124
-25
lines changed

flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp

Lines changed: 124 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
#include "mlir/IR/SymbolTable.h"
4141
#include "mlir/Pass/Pass.h"
4242
#include "mlir/Support/LLVM.h"
43+
#include "llvm/ADT/BitmaskEnum.h"
4344
#include "llvm/ADT/SmallPtrSet.h"
4445
#include "llvm/ADT/StringSet.h"
4546
#include "llvm/Frontend/OpenMP/OMPConstants.h"
@@ -398,24 +399,18 @@ class MapInfoFinalizationPass
398399
baseAddrIndex);
399400
}
400401

401-
/// Adjusts the descriptor's map type. The main alteration that is done
402-
/// currently is transforming the map type to `OMP_MAP_TO` where possible.
403-
/// This is because we will always need to map the descriptor to device
404-
/// (or at the very least it seems to be the case currently with the
405-
/// current lowered kernel IR), as without the appropriate descriptor
406-
/// information on the device there is a risk of the kernel IR
407-
/// requesting for various data that will not have been copied to
408-
/// perform things like indexing. This can cause segfaults and
409-
/// memory access errors. However, we do not need this data mapped
410-
/// back to the host from the device, as per the OpenMP spec we cannot alter
411-
/// the data via resizing or deletion on the device. Discarding any
412-
/// descriptor alterations via no map back is reasonable (and required
413-
/// for certain segments of descriptor data like the type descriptor that are
414-
/// global constants). This alteration is only inapplicable to `target exit`
415-
/// and `target update` currently, and that's due to `target exit` not
416-
/// allowing `to` mappings, and `target update` not allowing both `to` and
417-
/// `from` simultaneously. We currently try to maintain the `implicit` flag
418-
/// where necessary, although it does not seem strictly required.
402+
/// Adjust the descriptor's map type such that we ensure the descriptor
403+
/// itself is present on device when needed, without changing the user's
404+
/// requested data mapping semantics for the underlying data.
405+
///
406+
/// We conservatively transform descriptor mappings to `OMP_MAP_TO` (and
407+
/// preserve `IMPLICIT`/`ALWAYS` when present) for structured regions. The
408+
/// descriptor should live on device for indexing, bounds, etc., but we do
409+
/// not require, nor want, additional mapping semantics like `CLOSE` for the
410+
/// descriptor entry itself. `CLOSE` (and other user-provided flags) should
411+
/// apply to the base data entry that actually carries the pointee, which is
412+
/// generated separately as a member map. For `target exit`/`target update`
413+
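/// we keep the original map type unchanged.
///
/// Exception to the `CLOSE` rule above: when the enclosing module's
/// `omp.requires` attribute includes `unified_shared_memory`, `CLOSE` is
/// additionally set on the descriptor entry.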
419414
unsigned long getDescriptorMapType(unsigned long mapTypeFlag,
420415
mlir::Operation *target) {
421416
using mapFlags = llvm::omp::OpenMPOffloadMappingFlags;
@@ -425,8 +420,24 @@ class MapInfoFinalizationPass
425420

426421
mapFlags flags = mapFlags::OMP_MAP_TO |
427422
(mapFlags(mapTypeFlag) &
428-
(mapFlags::OMP_MAP_IMPLICIT | mapFlags::OMP_MAP_CLOSE |
429-
mapFlags::OMP_MAP_ALWAYS));
423+
(mapFlags::OMP_MAP_IMPLICIT | mapFlags::OMP_MAP_ALWAYS));
424+
425+
// For unified_shared_memory, we additionally add `CLOSE` on the descriptor
426+
// to ensure device-local placement where required by tests relying on USM +
427+
// close semantics.
428+
if (target) {
429+
if (auto mod = target->getParentOfType<mlir::ModuleOp>()) {
430+
if (mlir::Attribute reqAttr = mod->getAttr("omp.requires")) {
431+
if (auto req =
432+
mlir::dyn_cast<mlir::omp::ClauseRequiresAttr>(reqAttr)) {
433+
if (mlir::omp::bitEnumContainsAll(
434+
req.getValue(),
435+
mlir::omp::ClauseRequires::unified_shared_memory))
436+
flags |= mapFlags::OMP_MAP_CLOSE;
437+
}
438+
}
439+
}
440+
}
430441
return llvm::to_underlying(flags);
431442
}
432443

@@ -518,6 +529,75 @@ class MapInfoFinalizationPass
518529
return newMapInfoOp;
519530
}
520531

532+
// Expand a mapping of a builtin type(C_PTR) variable so that its `__address`
533+
// field is mapped explicitly as a single member of a rebuilt parent map. The
534+
// caller gates this on unified_shared_memory (USM) being required.
//
// `op` is the map to expand; `builder` is used to create the new ops. `op` is
// returned untouched when it already has members, when its unwrapped variable
// type is not a fir::RecordType whose name ends with "__builtin_c_ptr", or
// when the `__address` lookup yields a negative index. Otherwise every use of
// `op` is redirected to the new parent map, `op` is erased, and the new
// parent map is returned.
535+
mlir::omp::MapInfoOp genCptrMemberMap(mlir::omp::MapInfoOp op,
536+
fir::FirOpBuilder &builder) {
537+
if (!op.getMembers().empty())
538+
return op;
539+
540+
mlir::Type varTy = fir::unwrapRefType(op.getVarPtr().getType());
541+
if (!mlir::isa<fir::RecordType>(varTy))
542+
return op;
543+
auto recTy = mlir::cast<fir::RecordType>(varTy);
544+
// Only the builtin C_PTR record is handled; it is recognised by name suffix.
545+
if (!recTy.getName().ends_with("__builtin_c_ptr"))
546+
return op;
547+
548+
// Look up the field index of the "__address" component and bail out if the
// result is negative. The index is used below both to fetch the member type
// and to build the coordinate into the record.
549+
int32_t fieldIdx = recTy.getFieldIndex("__address");
550+
if (fieldIdx < 0)
551+
return op;
552+
553+
mlir::Location loc = op.getVarPtr().getLoc();
554+
mlir::Type memTy = recTy.getType(fieldIdx);
555+
fir::IntOrValue idxConst =
556+
mlir::IntegerAttr::get(builder.getI32Type(), fieldIdx);
557+
mlir::Value coord = fir::CoordinateOp::create(
558+
builder, loc, builder.getRefType(memTy), op.getVarPtr(),
559+
llvm::SmallVector<fir::IntOrValue, 1>{idxConst});
560+
561+
// Member index path recorded on the rebuilt parent for the `__address` child.
// NOTE(review): the path is hard-coded to {0} instead of being derived from
// `fieldIdx`; the two agree only if `__address` is field 0 of the record.
// TODO confirm, or build the path from `fieldIdx`.
562+
llvm::SmallVector<llvm::SmallVector<int64_t>> memberIdx = {{0}};
563+
mlir::ArrayAttr newMembersAttr = builder.create2DI64ArrayAttr(memberIdx);
564+
// Or OMP_MAP_CLOSE into the original map type. Per the original author this
565+
// is meant to give the pointer device-local placement in USM + close cases.
// The resulting attribute is shared by the member map and the new parent map.
566+
uint64_t mapTypeVal =
567+
op.getMapType() |
568+
llvm::to_underlying(
569+
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE);
570+
mlir::IntegerAttr mapTypeAttr = builder.getIntegerAttr(
571+
builder.getIntegerType(64, /*isSigned=*/false), mapTypeVal);
// Member map for `__address`: it maps `coord`, is captured ByRef, has no
// members of its own, and takes over the bounds and name of `op`.
572+
573+
mlir::omp::MapInfoOp memberMap = mlir::omp::MapInfoOp::create(
574+
builder, loc, coord.getType(), coord,
575+
mlir::TypeAttr::get(fir::unwrapRefType(coord.getType())), mapTypeAttr,
576+
builder.getAttr<mlir::omp::VariableCaptureKindAttr>(
577+
mlir::omp::VariableCaptureKind::ByRef),
578+
/*varPtrPtr=*/mlir::Value{},
579+
/*members=*/llvm::SmallVector<mlir::Value>{},
580+
/*member_index=*/mlir::ArrayAttr{},
581+
/*bounds=*/op.getBounds(),
582+
/*mapperId=*/mlir::FlatSymbolRefAttr(),
583+
/*name=*/op.getNameAttr(),
584+
/*partial_map=*/builder.getBoolAttr(false));
585+
586+
// Rebuild the parent as a container whose only member is the `__address`
// map. It reuses the result type, variable, variable type, capture kind and
// name of `op`, but is created with empty bounds, no varPtrPtr and an empty
// mapper id, and it carries the CLOSE-augmented map type.
587+
mlir::omp::MapInfoOp newParent = mlir::omp::MapInfoOp::create(
588+
builder, op.getLoc(), op.getResult().getType(), op.getVarPtr(),
589+
op.getVarTypeAttr(), mapTypeAttr, op.getMapCaptureTypeAttr(),
590+
/*varPtrPtr=*/mlir::Value{},
591+
/*members=*/llvm::SmallVector<mlir::Value>{memberMap},
592+
/*member_index=*/newMembersAttr,
593+
/*bounds=*/llvm::SmallVector<mlir::Value>{},
594+
/*mapperId=*/mlir::FlatSymbolRefAttr(), op.getNameAttr(),
595+
/*partial_map=*/builder.getBoolAttr(false));
596+
op.replaceAllUsesWith(newParent.getResult());
597+
op->erase();
598+
return newParent;
599+
}
600+
521601
mlir::omp::MapInfoOp genDescriptorMemberMaps(mlir::omp::MapInfoOp op,
522602
fir::FirOpBuilder &builder,
523603
mlir::Operation *target) {
@@ -727,11 +807,6 @@ class MapInfoFinalizationPass
727807
argIface.getUseDeviceAddrBlockArgsStart() +
728808
argIface.numUseDeviceAddrBlockArgs());
729809

730-
mlir::MutableOperandRange useDevPtrMutableOpRange =
731-
targetDataOp.getUseDevicePtrVarsMutable();
732-
addOperands(useDevPtrMutableOpRange, target,
733-
argIface.getUseDevicePtrBlockArgsStart() +
734-
argIface.numUseDevicePtrBlockArgs());
735810
} else if (auto targetOp = llvm::dyn_cast<mlir::omp::TargetOp>(target)) {
736811
mlir::MutableOperandRange hasDevAddrMutableOpRange =
737812
targetOp.getHasDeviceAddrVarsMutable();
@@ -1169,6 +1244,30 @@ class MapInfoFinalizationPass
11691244
genBoxcharMemberMap(op, builder);
11701245
});
11711246

1247+
// Expand type(C_PTR) only when unified_shared_memory is required,
1248+
// to ensure device-visible pointer size/behavior in USM scenarios
1249+
// without changing default expectations elsewhere.
1250+
func->walk([&](mlir::omp::MapInfoOp op) {
1251+
// Check module requires USM; otherwise, leave mappings untouched.
1252+
auto mod = func->getParentOfType<mlir::ModuleOp>();
1253+
bool hasUSM = false;
1254+
if (mod) {
1255+
if (mlir::Attribute reqAttr = mod->getAttr("omp.requires")) {
1256+
if (auto req =
1257+
mlir::dyn_cast<mlir::omp::ClauseRequiresAttr>(reqAttr)) {
1258+
hasUSM = mlir::omp::bitEnumContainsAll(
1259+
req.getValue(),
1260+
mlir::omp::ClauseRequires::unified_shared_memory);
1261+
}
1262+
}
1263+
}
1264+
if (!hasUSM)
1265+
return;
1266+
1267+
builder.setInsertionPoint(op);
1268+
genCptrMemberMap(op, builder);
1269+
});
1270+
11721271
func->walk([&](mlir::omp::MapInfoOp op) {
11731272
// TODO: Currently only supports a single user for the MapInfoOp. This
11741273
// is fine for the moment, as the Fortran frontend will generate a

0 commit comments

Comments
 (0)