@@ -3532,13 +3532,36 @@ static TargetDirective getTargetDirectiveFromOp(Operation *op) {
3532
3532
3533
3533
} // namespace
3534
3534
3535
- uint64_t getArrayElementSizeInBits (LLVM::LLVMArrayType arrTy, DataLayout &dl) {
3535
+ // In certain cases, we can be provided less bounds than there are nested array
3536
+ // types, but still be provided bounds, in these cases we try to compute the
3537
+ // size up to the point of the bounds provided and then let the bounds x size
3538
+ // computation do the rest of the work. This is most common in Flang where
3539
+ // character arrays provided character lengths (C/C++ string esque), represent
3540
+ // the internal string as a byte array with the length of this string
3541
+ // unrepresented by bounds.
3542
+ uint64_t getArrayElementSizeInBits (LLVM::LLVMArrayType arrTy, DataLayout &dl,
3543
+ int boundsCount) {
3544
+ if (boundsCount == 0 )
3545
+ return dl.getTypeSizeInBits (arrTy);
3536
3546
if (auto nestedArrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(
3537
3547
arrTy.getElementType ()))
3538
- return getArrayElementSizeInBits (nestedArrTy, dl);
3548
+ return getArrayElementSizeInBits (nestedArrTy, dl, --boundsCount );
3539
3549
return dl.getTypeSizeInBits (arrTy.getElementType ());
3540
3550
}
3541
3551
3552
+ // It is possible for a 1-D array type to provide N-D bounds to index
3553
+ // with instead of 1-D Bounds. this is common to do with byte arrays
3554
+ // that are representing other data types, e.g. an N-D char array, we
3555
+ // support this use case.
3556
+ // TODO: Extend to just check if we have more bounds than array
3557
+ // dimensions
3558
+ static bool is1DArrayWithNDBounds (llvm::Type *type, size_t numBounds) {
3559
+ if (type->isArrayTy () && !type->getArrayElementType ()->isArrayTy () &&
3560
+ numBounds > 1 )
3561
+ return true ;
3562
+ return false ;
3563
+ }
3564
+
3542
3565
// This function calculates the size to be offloaded for a specified type, given
3543
3566
// its associated map clause (which can contain bounds information which affects
3544
3567
// the total size), this size is calculated based on the underlying element type
@@ -3553,6 +3576,8 @@ llvm::Value *getSizeInBytes(DataLayout &dl, const mlir::Type &type,
3553
3576
Operation *clauseOp, llvm::Value *basePointer,
3554
3577
llvm::Type *baseType, llvm::IRBuilderBase &builder,
3555
3578
LLVM::ModuleTranslation &moduleTranslation) {
3579
+ // TODO: If the array is provably constant sized (e.g. from the type), work
3580
+ // out constant size from the type and skip the calculation from the bounds.
3556
3581
if (auto memberClause =
3557
3582
mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(clauseOp)) {
3558
3583
// This calculates the size to transfer based on bounds and the underlying
@@ -3582,7 +3607,10 @@ llvm::Value *getSizeInBytes(DataLayout &dl, const mlir::Type &type,
3582
3607
// the size in inconsistent byte or bit format.
3583
3608
uint64_t underlyingTypeSzInBits = dl.getTypeSizeInBits (type);
3584
3609
if (auto arrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(type))
3585
- underlyingTypeSzInBits = getArrayElementSizeInBits (arrTy, dl);
3610
+ if (!is1DArrayWithNDBounds (moduleTranslation.convertType (type),
3611
+ memberClause.getBounds ().size ()))
3612
+ underlyingTypeSzInBits = getArrayElementSizeInBits (
3613
+ arrTy, dl, memberClause.getBounds ().size ());
3586
3614
3587
3615
// The size in bytes x number of elements, the sizeInBytes stored is
3588
3616
// the underyling types size, e.g. if ptr<i32>, it'll be the i32's
@@ -4415,7 +4443,10 @@ createAlteredByCaptureMap(MapInfoData &mapData,
4415
4443
case omp::VariableCaptureKind::ByRef: {
4416
4444
llvm::Value *newV = mapData.Pointers [i];
4417
4445
std::vector<llvm::Value *> offsetIdx = calculateBoundsOffset (
4418
- moduleTranslation, builder, mapData.BaseType [i]->isArrayTy (),
4446
+ moduleTranslation, builder,
4447
+ mapData.BaseType [i]->isArrayTy () &&
4448
+ !is1DArrayWithNDBounds (mapData.BaseType [i],
4449
+ mapOp.getBounds ().size ()),
4419
4450
mapOp.getBounds ());
4420
4451
if (isPtrTy)
4421
4452
newV = builder.CreateLoad (builder.getPtrTy (), newV);
0 commit comments