@@ -3532,13 +3532,36 @@ static TargetDirective getTargetDirectiveFromOp(Operation *op) {
35323532
35333533} // namespace
35343534
3535- uint64_t getArrayElementSizeInBits (LLVM::LLVMArrayType arrTy, DataLayout &dl) {
3535+ // In certain cases, we can be provided less bounds than there are nested array
3536+ // types, but still be provided bounds, in these cases we try to compute the
3537+ // size up to the point of the bounds provided and then let the bounds x size
3538+ // computation do the rest of the work. This is most common in Flang where
3539+ // character arrays provided character lengths (C/C++ string esque), represent
3540+ // the internal string as a byte array with the length of this string
3541+ // unrepresented by bounds.
3542+ uint64_t getArrayElementSizeInBits (LLVM::LLVMArrayType arrTy, DataLayout &dl,
3543+ int boundsCount) {
3544+ if (boundsCount == 0 )
3545+ return dl.getTypeSizeInBits (arrTy);
35363546 if (auto nestedArrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(
35373547 arrTy.getElementType ()))
3538- return getArrayElementSizeInBits (nestedArrTy, dl);
3548+ return getArrayElementSizeInBits (nestedArrTy, dl, --boundsCount );
35393549 return dl.getTypeSizeInBits (arrTy.getElementType ());
35403550}
35413551
3552+ // It is possible for a 1-D array type to provide N-D bounds to index
3553+ // with instead of 1-D Bounds. this is common to do with byte arrays
3554+ // that are representing other data types, e.g. an N-D char array, we
3555+ // support this use case.
3556+ // TODO: Extend to just check if we have more bounds than array
3557+ // dimensions
3558+ static bool is1DArrayWithNDBounds (llvm::Type *type, size_t numBounds) {
3559+ if (type->isArrayTy () && !type->getArrayElementType ()->isArrayTy () &&
3560+ numBounds > 1 )
3561+ return true ;
3562+ return false ;
3563+ }
3564+
35423565// This function calculates the size to be offloaded for a specified type, given
35433566// its associated map clause (which can contain bounds information which affects
35443567// the total size), this size is calculated based on the underlying element type
@@ -3553,6 +3576,8 @@ llvm::Value *getSizeInBytes(DataLayout &dl, const mlir::Type &type,
35533576 Operation *clauseOp, llvm::Value *basePointer,
35543577 llvm::Type *baseType, llvm::IRBuilderBase &builder,
35553578 LLVM::ModuleTranslation &moduleTranslation) {
3579+ // TODO: If the array is provably constant sized (e.g. from the type), work
3580+ // out constant size from the type and skip the calculation from the bounds.
35563581 if (auto memberClause =
35573582 mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(clauseOp)) {
35583583 // This calculates the size to transfer based on bounds and the underlying
@@ -3582,7 +3607,10 @@ llvm::Value *getSizeInBytes(DataLayout &dl, const mlir::Type &type,
35823607 // the size in inconsistent byte or bit format.
35833608 uint64_t underlyingTypeSzInBits = dl.getTypeSizeInBits (type);
35843609 if (auto arrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(type))
3585- underlyingTypeSzInBits = getArrayElementSizeInBits (arrTy, dl);
3610+ if (!is1DArrayWithNDBounds (moduleTranslation.convertType (type),
3611+ memberClause.getBounds ().size ()))
3612+ underlyingTypeSzInBits = getArrayElementSizeInBits (
3613+ arrTy, dl, memberClause.getBounds ().size ());
35863614
35873615 // The size in bytes x number of elements, the sizeInBytes stored is
35883616 // the underyling types size, e.g. if ptr<i32>, it'll be the i32's
@@ -4415,7 +4443,10 @@ createAlteredByCaptureMap(MapInfoData &mapData,
44154443 case omp::VariableCaptureKind::ByRef: {
44164444 llvm::Value *newV = mapData.Pointers [i];
44174445 std::vector<llvm::Value *> offsetIdx = calculateBoundsOffset (
4418- moduleTranslation, builder, mapData.BaseType [i]->isArrayTy (),
4446+ moduleTranslation, builder,
4447+ mapData.BaseType [i]->isArrayTy () &&
4448+ !is1DArrayWithNDBounds (mapData.BaseType [i],
4449+ mapOp.getBounds ().size ()),
44194450 mapOp.getBounds ());
44204451 if (isPtrTy)
44214452 newV = builder.CreateLoad (builder.getPtrTy (), newV);
0 commit comments