Skip to content

Commit 706ef7f

Browse files
committed
[Flang][OpenMP] Move handling from MLIR - LLVM-IR layer to FIR -> LLVM dialect layer
Likely a more palatable way to handle this for upstreaming.
1 parent 719ad71 commit 706ef7f

File tree

2 files changed

+71
-35
lines changed

2 files changed

+71
-35
lines changed

flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,21 @@ struct MapInfoOpConversion
5959
: public OpenMPFIROpConversion<mlir::omp::MapInfoOp> {
6060
using OpenMPFIROpConversion::OpenMPFIROpConversion;
6161

62+
mlir::omp::MapBoundsOp
63+
createBoundsForCharString(mlir::ConversionPatternRewriter &rewriter,
64+
unsigned int len, mlir::Location loc) const {
65+
mlir::Type i64Ty = rewriter.getIntegerType(64);
66+
auto lBound = mlir::LLVM::ConstantOp::create(rewriter, loc, i64Ty, 0);
67+
auto uBoundAndExt =
68+
mlir::LLVM::ConstantOp::create(rewriter, loc, i64Ty, len - 1);
69+
auto stride = mlir::LLVM::ConstantOp::create(rewriter, loc, i64Ty, 1);
70+
auto baseLb = mlir::LLVM::ConstantOp::create(rewriter, loc, i64Ty, 1);
71+
auto mapBoundType = rewriter.getType<mlir::omp::MapBoundsType>();
72+
return mlir::omp::MapBoundsOp::create(rewriter, loc, mapBoundType, lBound,
73+
uBoundAndExt, uBoundAndExt, stride,
74+
/*strideInBytes*/ false, baseLb);
75+
}
76+
6277
llvm::LogicalResult
6378
matchAndRewrite(mlir::omp::MapInfoOp curOp, OpAdaptor adaptor,
6479
mlir::ConversionPatternRewriter &rewriter) const override {
@@ -68,13 +83,58 @@ struct MapInfoOpConversion
6883
return mlir::failure();
6984

7085
llvm::SmallVector<mlir::NamedAttribute> newAttrs;
71-
mlir::omp::MapInfoOp newOp;
86+
mlir::omp::MapBoundsOp mapBoundsOp;
7287
for (mlir::NamedAttribute attr : curOp->getAttrs()) {
7388
if (auto typeAttr = mlir::dyn_cast<mlir::TypeAttr>(attr.getValue())) {
7489
mlir::Type newAttr;
7590
if (fir::isTypeWithDescriptor(typeAttr.getValue())) {
7691
newAttr = lowerTy().convertBoxTypeAsStruct(
7792
mlir::cast<fir::BaseBoxType>(typeAttr.getValue()));
93+
} else if (fir::isa_char_string(fir::unwrapSequenceType(
94+
fir::unwrapPassByRefType(typeAttr.getValue()))) &&
95+
!characterWithDynamicLen(
96+
fir::unwrapPassByRefType(typeAttr.getValue()))) {
97+
// Characters with a LEN param are represented as char
98+
// arrays/strings, the initial lowering doesn't generate
99+
// bounds for these, however, we require them to map the
100+
// data appropriately in the later lowering stages. This
101+
// is to prevent the need for unecessary caveats
102+
// specific to Flang. We also strip the array from the
103+
// type so that all variations of strings are treated
104+
// identically and there's no caveats or specialisations
105+
// required in the later stages. As an example, Boxed
106+
// char strings will emit a single char array no matter
107+
// the number of dimensions caused by additional array
108+
// dimensions which needs specialised for, as it differs
109+
// from the non-box variation which will emit each array
110+
// wrapping the character array, e.g. given a type of
111+
// the same dimensions, if one is boxed, the types would
112+
// end up:
113+
//
114+
// array<i8 x 16>
115+
// vs
116+
// array<10 x array< 10 x array<i8 x 16>>>
117+
//
118+
// This means we have to treat one specially in the
119+
// lowering. So we try to "canonicalize" it here.
120+
// TODO: Handle dynamic LEN characters.
121+
if (auto ct = mlir::dyn_cast_or_null<fir::CharacterType>(
122+
fir::unwrapSequenceType(typeAttr.getValue()))) {
123+
newAttr = converter->convertType(
124+
fir::unwrapSequenceType(typeAttr.getValue()));
125+
if (auto type = mlir::dyn_cast<mlir::LLVM::LLVMArrayType>(newAttr))
126+
newAttr = type.getElementType();
127+
// We do not generate for device, as MapBoundsOps are
128+
// unsupported, as they're currently unused.
129+
auto offloadMod =
130+
llvm::dyn_cast_or_null<mlir::omp::OffloadModuleInterface>(
131+
*curOp->getParentOfType<mlir::ModuleOp>());
132+
if (!offloadMod.getIsTargetDevice())
133+
mapBoundsOp = createBoundsForCharString(rewriter, ct.getLen(),
134+
curOp.getLoc());
135+
} else {
136+
newAttr = converter->convertType(typeAttr.getValue());
137+
}
78138
} else {
79139
newAttr = converter->convertType(typeAttr.getValue());
80140
}
@@ -84,8 +144,13 @@ struct MapInfoOpConversion
84144
}
85145
}
86146

87-
rewriter.replaceOpWithNewOp<mlir::omp::MapInfoOp>(
147+
auto newOp = rewriter.replaceOpWithNewOp<mlir::omp::MapInfoOp>(
88148
curOp, resTypes, adaptor.getOperands(), newAttrs);
149+
if (mapBoundsOp) {
150+
rewriter.startOpModification(newOp);
151+
newOp.getBoundsMutable().append(mlir::ValueRange{mapBoundsOp});
152+
rewriter.finalizeOpModification(newOp);
153+
}
89154

90155
return mlir::success();
91156
}

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 4 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -3532,36 +3532,13 @@ static TargetDirective getTargetDirectiveFromOp(Operation *op) {
35323532

35333533
} // namespace
35343534

3535-
// In certain cases, we can be provided less bounds than there are nested array
3536-
// types, but still be provided bounds, in these cases we try to compute the
3537-
// size up to the point of the bounds provided and then let the bounds x size
3538-
// computation do the rest of the work. This is most common in Flang where
3539-
// character arrays provided character lengths (C/C++ string esque), represent
3540-
// the internal string as a byte array with the length of this string
3541-
// unrepresented by bounds.
3542-
uint64_t getArrayElementSizeInBits(LLVM::LLVMArrayType arrTy, DataLayout &dl,
3543-
int boundsCount) {
3544-
if (boundsCount == 0)
3545-
return dl.getTypeSizeInBits(arrTy);
3535+
uint64_t getArrayElementSizeInBits(LLVM::LLVMArrayType arrTy, DataLayout &dl) {
35463536
if (auto nestedArrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(
35473537
arrTy.getElementType()))
3548-
return getArrayElementSizeInBits(nestedArrTy, dl, --boundsCount);
3538+
return getArrayElementSizeInBits(nestedArrTy, dl);
35493539
return dl.getTypeSizeInBits(arrTy.getElementType());
35503540
}
35513541

3552-
// It is possible for a 1-D array type to provide N-D bounds to index
3553-
// with instead of 1-D Bounds. this is common to do with byte arrays
3554-
// that are representing other data types, e.g. an N-D char array, we
3555-
// support this use case.
3556-
// TODO: Extend to just check if we have more bounds than array
3557-
// dimensions
3558-
static bool is1DArrayWithNDBounds(llvm::Type *type, size_t numBounds) {
3559-
if (type->isArrayTy() && !type->getArrayElementType()->isArrayTy() &&
3560-
numBounds > 1)
3561-
return true;
3562-
return false;
3563-
}
3564-
35653542
// This function calculates the size to be offloaded for a specified type, given
35663543
// its associated map clause (which can contain bounds information which affects
35673544
// the total size), this size is calculated based on the underlying element type
@@ -3607,10 +3584,7 @@ llvm::Value *getSizeInBytes(DataLayout &dl, const mlir::Type &type,
36073584
// the size in inconsistent byte or bit format.
36083585
uint64_t underlyingTypeSzInBits = dl.getTypeSizeInBits(type);
36093586
if (auto arrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(type))
3610-
if (!is1DArrayWithNDBounds(moduleTranslation.convertType(type),
3611-
memberClause.getBounds().size()))
3612-
underlyingTypeSzInBits = getArrayElementSizeInBits(
3613-
arrTy, dl, memberClause.getBounds().size());
3587+
underlyingTypeSzInBits = getArrayElementSizeInBits(arrTy, dl);
36143588

36153589
// The size in bytes x number of elements, the sizeInBytes stored is
36163590
// the underyling types size, e.g. if ptr<i32>, it'll be the i32's
@@ -4443,10 +4417,7 @@ createAlteredByCaptureMap(MapInfoData &mapData,
44434417
case omp::VariableCaptureKind::ByRef: {
44444418
llvm::Value *newV = mapData.Pointers[i];
44454419
std::vector<llvm::Value *> offsetIdx = calculateBoundsOffset(
4446-
moduleTranslation, builder,
4447-
mapData.BaseType[i]->isArrayTy() &&
4448-
!is1DArrayWithNDBounds(mapData.BaseType[i],
4449-
mapOp.getBounds().size()),
4420+
moduleTranslation, builder, mapData.BaseType[i]->isArrayTy(),
44504421
mapOp.getBounds());
44514422
if (isPtrTy)
44524423
newV = builder.CreateLoad(builder.getPtrTy(), newV);

0 commit comments

Comments
 (0)