Skip to content

Commit 74d7e99

Browse files
authored
Fix for SWDEV-524638, do not map descriptors for enter/exit/update (llvm#2220)
2 parents 502d3da + e400d14 commit 74d7e99

File tree

3 files changed

+98
-5
lines changed

3 files changed

+98
-5
lines changed

flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -544,6 +544,66 @@ class MapInfoFinalizationPass
544544
return nullptr;
545545
}
546546

547+
/// Replace a top-level descriptor map with a map of the data it refers to.
///
/// Nothing happens when \p target is an mlir::omp::TargetOp,
/// mlir::omp::TargetDataOp or mlir::omp::DeclareMapperInfoOp, or when \p op
/// carries no members. Otherwise the first member of \p op (asserted to be
/// produced by an mlir::omp::MapInfoOp) is treated as the base address map:
/// a new mlir::omp::MapInfoOp is built through \p builder from the loaded
/// base address, it inherits the remaining members and the name of \p op,
/// every use of \p op is redirected to it, and finally \p op and the old
/// base address map are erased.
///
/// \param op      descriptor map operation to be replaced.
/// \param builder builder used to create the load and the replacement map;
///                the caller is expected to have set its insertion point.
/// \param target  operation that consumes the map.
void removeTopLevelDescriptor(mlir::omp::MapInfoOp op,
                              fir::FirOpBuilder &builder,
                              mlir::Operation *target) {
  // These map owners keep their descriptor maps, and a map without members
  // is not a top-level descriptor (e.g. it is itself a member of a derived
  // type), so neither case is rewritten.
  const bool ownerKeepsDescriptor =
      llvm::isa<mlir::omp::TargetOp, mlir::omp::TargetDataOp,
                mlir::omp::DeclareMapperInfoOp>(target);
  if (ownerKeepsDescriptor || op.getMembers().empty())
    return;

  // The leading member is the base address map; it gets promoted to the
  // parent position, so it is removed from the member list.
  mlir::SmallVector<mlir::Value> remainingMembers = op.getMembers();
  auto oldBaseAddrMap =
      remainingMembers.front().getDefiningOp<mlir::omp::MapInfoOp>();
  assert(oldBaseAddrMap && "Expected member to be MapInfoOp");
  remainingMembers.erase(remainingMembers.begin());

  llvm::SmallVector<llvm::SmallVector<int64_t>> indexPaths;
  getMemberIndicesAsVectors(op, indexPaths);

  // With a single index path there is only the base address, which becomes
  // the parent, so the attribute is left unset in that case.
  mlir::ArrayAttr remainingIndicesAttr;
  if (indexPaths.size() > 1) {
    indexPaths.erase(indexPaths.begin());
    remainingIndicesAttr = builder.create2DI64ArrayAttr(indexPaths);
  }

  // The later lowering currently uses the presence of varPtrPtr to decide
  // that something is a pointer, which comes with OMP_MAP_PTR_AND_OBJ being
  // applied. That defeats the purpose of this rewrite, as it emits an 8-byte
  // mapping tied to the descriptor address (even when only a single map is
  // emitted). The varPtrPtr is therefore dropped from the new map. A side
  // effect is that the backend no longer performs the extra load associated
  // with it, which is needed to obtain the correct data address, so the load
  // is done here instead.
  // TODO/FIXME: Untangle OMP_MAP_PTR_AND_OBJ and varPtrPtr from the "is this
  // a pointer" check to tidy up the implementation. This is an unfortunate
  // complexity from push-back from upstream. A load could also be emitted
  // here for all base addresses, which would simplify the later lowering as
  // well, but first we need to see how well this alteration works.
  mlir::Location loc = op->getLoc();
  auto loadedBaseAddr = builder.loadIfRef(loc, oldBaseAddrMap.getVarPtrPtr());
  auto replacementMap = builder.create<mlir::omp::MapInfoOp>(
      loc, loadedBaseAddr.getType(), loadedBaseAddr,
      oldBaseAddrMap.getVarTypeAttr(), oldBaseAddrMap.getMapTypeAttr(),
      oldBaseAddrMap.getMapCaptureTypeAttr(), /*varPtrPtr=*/mlir::Value{},
      remainingMembers, remainingIndicesAttr, oldBaseAddrMap.getBounds(),
      /*mapperId*/ mlir::FlatSymbolRefAttr(), op.getNameAttr(),
      /*partial_map=*/builder.getBoolAttr(false));

  // Redirect users first, then erase the descriptor map before the base
  // address map it referenced.
  op.replaceAllUsesWith(replacementMap.getResult());
  op->erase();
  oldBaseAddrMap.erase();
}
606+
547607
// This pass executes on omp::MapInfoOp's containing descriptor based types
548608
// (allocatables, pointers, assumed shape etc.) and expanding them into
549609
// multiple omp::MapInfoOp's for each pointer member contained within the
@@ -791,6 +851,38 @@ class MapInfoFinalizationPass
791851
}
792852
});
793853

854+
// Now that we've expanded all of our boxes into a descriptor and base
855+
// address map where necessary, we check if the map owner is a target
856+
// enter data/exit data/update directive, and if it is we drop the initial
857+
// descriptor (top-level parent) and replace it with the
858+
// base_address/data.
859+
//
860+
// This circumvents issues with stack allocated descriptors bound to
861+
// device colliding which in Flang is rather trivial for a user to do by
862+
// accident due to the rather pervasive local intermediate descriptor
863+
// generation that occurs whenever you pass boxes around different scopes.
864+
// In OpenMP 6+ mapping these would be a user error as the tools required
865+
// to circumvent these issues are provided by the spec (ref_ptr/ptee map
866+
// types), but in prior specifications these tools are not available and
867+
// it becomes an implementation issue for us to solve.
868+
//
869+
// We do this by dropping the top-level descriptor which will be the stack
870+
// descriptor when we perform enter/exit maps, as we don't want these to
871+
// be bound until necessary which is when we utilise the descriptor type
872+
// within a target region. At which point we map the relevant descriptor
873+
// data and the runtime should correctly associate the data with the
874+
// descriptor and bind together and allow clean mapping and execution.
875+
func->walk([&](mlir::omp::MapInfoOp op) {
876+
if (fir::isTypeWithDescriptor(op.getVarType()) ||
877+
mlir::isa_and_present<fir::BoxAddrOp>(
878+
op.getVarPtr().getDefiningOp())) {
879+
mlir::Operation *targetUser = getFirstTargetUser(op);
880+
assert(targetUser && "expected user of map operation was not found");
881+
builder.setInsertionPoint(op);
882+
removeTopLevelDescriptor(op, builder, targetUser);
883+
}
884+
});
885+
794886
// Wait until after we have generated all of our maps to add them onto
795887
// the target's block arguments, simplifying the process as there would be
796888
// no need to avoid accidental duplicate additions.

flang/test/Lower/volatile-openmp.f90

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@
4646
! CHECK: %[[VAL_34:.*]] = arith.subi %[[VAL_33]]#1, %[[VAL_1]] : index
4747
! CHECK: %[[VAL_35:.*]] = omp.map.bounds lower_bound(%[[VAL_0]] : index) upper_bound(%[[VAL_34]] : index) extent(%[[VAL_33]]#1 : index) stride(%[[VAL_33]]#2 : index) start_idx(%[[VAL_32]]#0 : index) {stride_in_bytes = true}
4848
! CHECK: %[[VAL_36:.*]] = fir.box_offset %[[VAL_10]]#1 base_addr : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>, volatile>, volatile>) -> !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>
49-
! CHECK: %[[VAL_37:.*]] = omp.map.info var_ptr(%[[VAL_10]]#1 : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>, volatile>, volatile>, i32) map_clauses(descriptor_base_addr, to) capture(ByRef) var_ptr_ptr(%[[VAL_36]] : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) bounds(%[[VAL_35]]) -> !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>> {name = ""}
50-
! CHECK: %[[VAL_38:.*]] = omp.map.info var_ptr(%[[VAL_10]]#1 : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>, volatile>, volatile>, !fir.box<!fir.ptr<!fir.array<?xi32>>, volatile>) map_clauses(always, descriptor, to) capture(ByRef) members(%[[VAL_37]] : [0] : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) -> !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>, volatile>, volatile> {name = "array1"}
51-
! CHECK: omp.target_enter_data map_entries(%[[VAL_38]], %[[VAL_37]] : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>, volatile>, volatile>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>)
49+
! CHECK: %[[VAL_37:.*]] = fir.load %[[VAL_36]] : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>
50+
! CHECK: %[[VAL_38:.*]] = omp.map.info var_ptr(%[[VAL_37]] : !fir.ref<!fir.array<?xi32>>, i32) map_clauses(descriptor_base_addr, to) capture(ByRef) bounds(%[[VAL_35]]) -> !fir.ref<!fir.array<?xi32>> {name = "array1"}
51+
! CHECK: omp.target_enter_data map_entries(%[[VAL_38]] : !fir.ref<!fir.array<?xi32>>)
5252
! CHECK: return
5353
! CHECK: }

flang/test/Transforms/omp-map-info-finalization.fir

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -325,8 +325,9 @@ func.func @_QPreuse_alloca(%arg0: !fir.box<!fir.array<?xf64>> {fir.bindc_name =
325325
// CHECK: %{{[0-9]+}} = omp.map.info var_ptr(%[[ALLOCA]]
326326
// CHECK: %{{[0-9]+}} = omp.map.info var_ptr(%[[ALLOCA]]
327327
// CHECK: omp.target_data map_entries
328-
// CHECK: %{{[0-9]+}} = omp.map.info var_ptr(%[[ALLOCA]]
329-
// CHECK: %{{[0-9]+}} = omp.map.info var_ptr(%[[ALLOCA]]
328+
// CHECK: %[[BOX_OFFSET:.*]] = fir.box_offset %[[ALLOCA]]
329+
// CHECK: %[[LOAD_OFFSET:.*]] = fir.load %[[BOX_OFFSET]] : !fir.llvm_ptr<!fir.ref<!fir.array<?xf64>>>
330+
// CHECK: %{{[0-9]+}} = omp.map.info var_ptr(%[[LOAD_OFFSET]]
330331
// CHECK: omp.target_update map_entries
331332
// CHECK: omp.terminator
332333
// CHECK: }

0 commit comments

Comments
 (0)