Skip to content

Commit bd2ba04

Browse files
authored
[Flang][OpenMP] Fix USM close semantics and use_device_ptr (#163258)
- Add CLOSE map flag when USM is required. - use_device_ptr: prevent implicitly expanding member operands. - Fixes test offload/test/offloading/fortran/usm_map_close.f90.
1 parent acf4e17 commit bd2ba04

File tree

2 files changed

+119
-2
lines changed

2 files changed

+119
-2
lines changed

flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp

Lines changed: 98 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
#include "mlir/IR/SymbolTable.h"
4141
#include "mlir/Pass/Pass.h"
4242
#include "mlir/Support/LLVM.h"
43+
#include "llvm/ADT/BitmaskEnum.h"
4344
#include "llvm/ADT/SmallPtrSet.h"
4445
#include "llvm/ADT/StringSet.h"
4546
#include "llvm/Frontend/OpenMP/OMPConstants.h"
@@ -128,6 +129,17 @@ class MapInfoFinalizationPass
128129
}
129130
}
130131

132+
/// Return true if the module has an OpenMP requires clause that includes
133+
/// unified_shared_memory.
134+
static bool moduleRequiresUSM(mlir::ModuleOp module) {
135+
assert(module && "invalid module");
136+
if (auto req = module->getAttrOfType<mlir::omp::ClauseRequiresAttr>(
137+
"omp.requires"))
138+
return mlir::omp::bitEnumContainsAll(
139+
req.getValue(), mlir::omp::ClauseRequires::unified_shared_memory);
140+
return false;
141+
}
142+
131143
/// Create the member map for coordRef and append it (and its index
132144
/// path) to the provided new* vectors, if it is not already present.
133145
void appendMemberMapIfNew(
@@ -425,8 +437,12 @@ class MapInfoFinalizationPass
425437

426438
mapFlags flags = mapFlags::OMP_MAP_TO |
427439
(mapFlags(mapTypeFlag) &
428-
(mapFlags::OMP_MAP_IMPLICIT | mapFlags::OMP_MAP_CLOSE |
429-
mapFlags::OMP_MAP_ALWAYS));
440+
(mapFlags::OMP_MAP_IMPLICIT | mapFlags::OMP_MAP_ALWAYS));
441+
// For unified_shared_memory, we additionally add `CLOSE` on the descriptor
442+
// to ensure device-local placement where required by tests relying on USM +
443+
// close semantics.
444+
if (moduleRequiresUSM(target->getParentOfType<mlir::ModuleOp>()))
445+
flags |= mapFlags::OMP_MAP_CLOSE;
430446
return llvm::to_underlying(flags);
431447
}
432448

@@ -518,6 +534,75 @@ class MapInfoFinalizationPass
518534
return newMapInfoOp;
519535
}
520536

537+
// Expand mappings of type(C_PTR) to map their `__address` field explicitly
538+
// as a single pointer-sized member (USM-gated at callsite). This helps in
539+
// USM scenarios to ensure the pointer-sized mapping is used.
540+
mlir::omp::MapInfoOp genCptrMemberMap(mlir::omp::MapInfoOp op,
541+
fir::FirOpBuilder &builder) {
542+
if (!op.getMembers().empty())
543+
return op;
544+
545+
mlir::Type varTy = fir::unwrapRefType(op.getVarPtr().getType());
546+
if (!mlir::isa<fir::RecordType>(varTy))
547+
return op;
548+
auto recTy = mlir::cast<fir::RecordType>(varTy);
549+
// If not a builtin C_PTR record, skip.
550+
if (!recTy.getName().ends_with("__builtin_c_ptr"))
551+
return op;
552+
553+
// Find the index of the c_ptr address component named "__address".
554+
int32_t fieldIdx = recTy.getFieldIndex("__address");
555+
if (fieldIdx < 0)
556+
return op;
557+
558+
mlir::Location loc = op.getVarPtr().getLoc();
559+
mlir::Type memTy = recTy.getType(fieldIdx);
560+
fir::IntOrValue idxConst =
561+
mlir::IntegerAttr::get(builder.getI32Type(), fieldIdx);
562+
mlir::Value coord = fir::CoordinateOp::create(
563+
builder, loc, builder.getRefType(memTy), op.getVarPtr(),
564+
llvm::SmallVector<fir::IntOrValue, 1>{idxConst});
565+
566+
// Child for the `__address` member.
567+
llvm::SmallVector<llvm::SmallVector<int64_t>> memberIdx = {{0}};
568+
mlir::ArrayAttr newMembersAttr = builder.create2DI64ArrayAttr(memberIdx);
569+
// Force CLOSE in USM paths so the pointer gets device-local placement
570+
// when required by tests relying on USM + close semantics.
571+
uint64_t mapTypeVal =
572+
op.getMapType() |
573+
llvm::to_underlying(
574+
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE);
575+
mlir::IntegerAttr mapTypeAttr = builder.getIntegerAttr(
576+
builder.getIntegerType(64, /*isSigned=*/false), mapTypeVal);
577+
578+
mlir::omp::MapInfoOp memberMap = mlir::omp::MapInfoOp::create(
579+
builder, loc, coord.getType(), coord,
580+
mlir::TypeAttr::get(fir::unwrapRefType(coord.getType())), mapTypeAttr,
581+
builder.getAttr<mlir::omp::VariableCaptureKindAttr>(
582+
mlir::omp::VariableCaptureKind::ByRef),
583+
/*varPtrPtr=*/mlir::Value{},
584+
/*members=*/llvm::SmallVector<mlir::Value>{},
585+
/*member_index=*/mlir::ArrayAttr{},
586+
/*bounds=*/op.getBounds(),
587+
/*mapperId=*/mlir::FlatSymbolRefAttr(),
588+
/*name=*/op.getNameAttr(),
589+
/*partial_map=*/builder.getBoolAttr(false));
590+
591+
// Rebuild the parent as a container with the `__address` member.
592+
mlir::omp::MapInfoOp newParent = mlir::omp::MapInfoOp::create(
593+
builder, op.getLoc(), op.getResult().getType(), op.getVarPtr(),
594+
op.getVarTypeAttr(), mapTypeAttr, op.getMapCaptureTypeAttr(),
595+
/*varPtrPtr=*/mlir::Value{},
596+
/*members=*/llvm::SmallVector<mlir::Value>{memberMap},
597+
/*member_index=*/newMembersAttr,
598+
/*bounds=*/llvm::SmallVector<mlir::Value>{},
599+
/*mapperId=*/mlir::FlatSymbolRefAttr(), op.getNameAttr(),
600+
/*partial_map=*/builder.getBoolAttr(false));
601+
op.replaceAllUsesWith(newParent.getResult());
602+
op->erase();
603+
return newParent;
604+
}
605+
521606
mlir::omp::MapInfoOp genDescriptorMemberMaps(mlir::omp::MapInfoOp op,
522607
fir::FirOpBuilder &builder,
523608
mlir::Operation *target) {
@@ -1169,6 +1254,17 @@ class MapInfoFinalizationPass
11691254
genBoxcharMemberMap(op, builder);
11701255
});
11711256

1257+
// Expand type(C_PTR) only when unified_shared_memory is required,
1258+
// to ensure device-visible pointer size/behavior in USM scenarios
1259+
// without changing default expectations elsewhere.
1260+
func->walk([&](mlir::omp::MapInfoOp op) {
1261+
// Only expand C_PTR members when unified_shared_memory is required.
1262+
if (!moduleRequiresUSM(func->getParentOfType<mlir::ModuleOp>()))
1263+
return;
1264+
builder.setInsertionPoint(op);
1265+
genCptrMemberMap(op, builder);
1266+
});
1267+
11721268
func->walk([&](mlir::omp::MapInfoOp op) {
11731269
// TODO: Currently only supports a single user for the MapInfoOp. This
11741270
// is fine for the moment, as the Fortran frontend will generate a
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 %s -o - | FileCheck %s
2+
!
3+
! Checks:
4+
! - C_PTR mappings expand to `__address` member with CLOSE under USM paths.
5+
! - use_device_ptr does not implicitly expand member operands in the clause.
6+
7+
subroutine only_cptr_use_device_ptr
8+
use iso_c_binding
9+
type(c_ptr) :: cptr
10+
integer :: i
11+
12+
!$omp target data use_device_ptr(cptr) map(tofrom: i)
13+
!$omp end target data
14+
end subroutine
15+
16+
! CHECK-LABEL: func.func @_QPonly_cptr_use_device_ptr()
17+
! CHECK: %[[I_MAP:.*]] = omp.map.info var_ptr(%{{.*}} : !fir.ref<i32>, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref<i32> {name = "i"}
18+
! CHECK: %[[CP_MAP:.*]] = omp.map.info var_ptr(%{{.*}} : !fir.ref<!fir.type<{{.*}}__builtin_c_ptr{{.*}}>>, !fir.type<{{.*}}__builtin_c_ptr{{.*}}>) map_clauses(return_param) capture(ByRef) -> !fir.ref<!fir.type<{{.*}}__builtin_c_ptr{{.*}}>>
19+
! CHECK: omp.target_data map_entries(%[[I_MAP]] : !fir.ref<i32>) use_device_ptr(%[[CP_MAP]] -> %{{.*}} : !fir.ref<!fir.type<{{.*}}__builtin_c_ptr{{.*}}>>) {
20+
! CHECK: omp.terminator
21+
! CHECK: }

0 commit comments

Comments
 (0)