@@ -3108,10 +3108,10 @@ class ArrayExprLowering {
3108
3108
void lowerArrayAssignment (const TL &lhs, const TR &rhs) {
3109
3109
auto loc = getLoc ();
3110
3110
// / Here the target subspace is not necessarily contiguous. The ArrayUpdate
3111
- // / continuation is implicitly returned in `ccDest ` and the ArrayLoad in
3112
- // / `destination`.
3111
+ // / continuation is implicitly returned in `ccStoreToDest ` and the ArrayLoad
3112
+ // / in `destination`.
3113
3113
PushSemantics (ConstituentSemantics::ProjectedCopyInCopyOut);
3114
- ccDest = genarr (lhs);
3114
+ ccStoreToDest = genarr (lhs);
3115
3115
determineShapeOfDest (lhs);
3116
3116
semant = ConstituentSemantics::RefTransparent;
3117
3117
auto exv = lowerArrayExpression (rhs);
@@ -3143,7 +3143,7 @@ class ArrayExprLowering {
3143
3143
newIters.prependIndexValue (i);
3144
3144
return newIters;
3145
3145
};
3146
- ccDest = [=](IterSpace iters) { return lambda (pc (iters)); };
3146
+ ccStoreToDest = [=](IterSpace iters) { return lambda (pc (iters)); };
3147
3147
destShape.assign (extents.begin (), extents.end ());
3148
3148
semant = ConstituentSemantics::RefTransparent;
3149
3149
auto exv = lowerArrayExpression (rhs);
@@ -3246,7 +3246,8 @@ class ArrayExprLowering {
3246
3246
destShape, lengthParams);
3247
3247
// Create ArrayLoad for the mutable box and save it into `destination`.
3248
3248
PushSemantics (ConstituentSemantics::ProjectedCopyInCopyOut);
3249
- ccDest = genarr (fir::factory::genMutableBoxRead (builder, loc, mutableBox));
3249
+ ccStoreToDest =
3250
+ genarr (fir::factory::genMutableBoxRead (builder, loc, mutableBox));
3250
3251
// If the rhs is scalar, get shape from the allocatable ArrayLoad.
3251
3252
if (destShape.empty ())
3252
3253
destShape = getShape (destination);
@@ -3310,7 +3311,7 @@ class ArrayExprLowering {
3310
3311
3311
3312
/// Entry point into lowering an expression with rank. This entry point is for
3312
3313
/// lowering a rhs expression, for example. (RefTransparent semantics.)
3313
- static ExtValue lowerSomeNewArrayExpression (
3314
+ static ExtValue lowerNewArrayExpression (
3314
3315
Fortran::lower::AbstractConverter &converter,
3315
3316
Fortran::lower::SymMap &symMap, Fortran::lower::StatementContext &stmtCtx,
3316
3317
const std::optional<Fortran::evaluate::Shape> &shape,
@@ -3331,7 +3332,7 @@ class ArrayExprLowering {
3331
3332
fir::dyn_cast_ptrEleTy (tempRes.getType ()).cast <fir::SequenceType>();
3332
3333
if (auto charTy =
3333
3334
arrTy.getEleTy ().template dyn_cast <fir::CharacterType>()) {
3334
- if (charTy. getLen () <= 0 )
3335
+ if (fir::characterWithDynamicLen (charTy) )
3335
3336
TODO (loc, " CHARACTER does not have constant LEN" );
3336
3337
auto len = builder.createIntegerConstant (
3337
3338
loc, builder.getCharacterLengthType (), charTy.getLen ());
@@ -3340,6 +3341,98 @@ class ArrayExprLowering {
3340
3341
return fir::ArrayBoxValue (tempRes, dest.getExtents ());
3341
3342
}
3342
3343
3344
+ static ExtValue lowerLazyArrayExpression (
3345
+ Fortran::lower::AbstractConverter &converter,
3346
+ Fortran::lower::SymMap &symMap, Fortran::lower::StatementContext &stmtCtx,
3347
+ const Fortran::evaluate::Expr<Fortran::evaluate::SomeType> &expr,
3348
+ mlir::Value var) {
3349
+ ArrayExprLowering ael{converter, stmtCtx, symMap};
3350
+ return ael.lowerLazyArrayExpression (expr, var);
3351
+ }
3352
+
3353
+ ExtValue lowerLazyArrayExpression (
3354
+ const Fortran::evaluate::Expr<Fortran::evaluate::SomeType> &expr,
3355
+ mlir::Value var) {
3356
+ auto loc = getLoc ();
3357
+ // Once the loop extents have been computed, which may require being inside
3358
+ // some explicit loops, lazily allocate the expression on the heap.
3359
+ ccPrelude = [=](llvm::ArrayRef<mlir::Value> shape) {
3360
+ auto load = builder.create <fir::LoadOp>(loc, var);
3361
+ auto eleTy = fir::unwrapRefType (load.getType ());
3362
+ auto unknown = fir::SequenceType::getUnknownExtent ();
3363
+ fir::SequenceType::Shape extents (shape.size (), unknown);
3364
+ auto seqTy = fir::SequenceType::get (extents, eleTy);
3365
+ auto toTy = fir::HeapType::get (seqTy);
3366
+ auto castTo = builder.createConvert (loc, toTy, load);
3367
+ auto cmp = builder.genIsNull (loc, castTo);
3368
+ builder.genIfThen (loc, cmp)
3369
+ .genThen ([&]() {
3370
+ auto mem = builder.create <fir::AllocMemOp>(loc, seqTy, " .lazy.mask" ,
3371
+ llvm::None, shape);
3372
+ auto uncast = builder.createConvert (loc, load.getType (), mem);
3373
+ builder.create <fir::StoreOp>(loc, uncast, var);
3374
+ })
3375
+ .end ();
3376
+ };
3377
+ // Create a dummy array_load before the loop. We're storing to a lazy
3378
+ // temporary, so there will be no conflict and no copy-in.
3379
+ ccLoadDest = [=](llvm::ArrayRef<mlir::Value> shape) -> fir::ArrayLoadOp {
3380
+ auto load = builder.create <fir::LoadOp>(loc, var);
3381
+ auto eleTy = fir::unwrapRefType (load.getType ());
3382
+ auto unknown = fir::SequenceType::getUnknownExtent ();
3383
+ fir::SequenceType::Shape extents (shape.size (), unknown);
3384
+ auto seqTy = fir::SequenceType::get (extents, eleTy);
3385
+ auto toTy = fir::HeapType::get (seqTy);
3386
+ auto castTo = builder.createConvert (loc, toTy, load);
3387
+ auto shapeOp = builder.consShape (loc, shape);
3388
+ return builder.create <fir::ArrayLoadOp>(
3389
+ loc, seqTy, castTo, shapeOp, /* slice=*/ mlir::Value{}, llvm::None);
3390
+ };
3391
+ // Custom lowering of the element store to deal with the extra indirection
3392
+ // to the lazy allocated buffer.
3393
+ ccStoreToDest = [=](IterSpace iters) {
3394
+ auto load = builder.create <fir::LoadOp>(loc, var);
3395
+ auto eleTy = fir::unwrapRefType (load.getType ());
3396
+ auto unknown = fir::SequenceType::getUnknownExtent ();
3397
+ fir::SequenceType::Shape extents (iters.iterVec ().size (), unknown);
3398
+ auto seqTy = fir::SequenceType::get (extents, eleTy);
3399
+ auto toTy = fir::HeapType::get (seqTy);
3400
+ auto castTo = builder.createConvert (loc, toTy, load);
3401
+ auto shape = builder.consShape (loc, genIterationShape ());
3402
+ auto indices = fir::factory::originateIndices (
3403
+ loc, builder, castTo.getType (), shape, iters.iterVec ());
3404
+ auto eleAddr = builder.create <fir::ArrayCoorOp>(
3405
+ loc, builder.getRefType (eleTy), castTo, shape,
3406
+ /* slice=*/ mlir::Value{}, indices, destination.typeparams ());
3407
+ auto eleVal = builder.createConvert (loc, eleTy, iters.getElement ());
3408
+ builder.create <fir::StoreOp>(loc, eleVal, eleAddr);
3409
+ return iters.innerArgument ();
3410
+ };
3411
+ auto loopRes = lowerArrayExpression (expr);
3412
+ auto load = builder.create <fir::LoadOp>(loc, var);
3413
+ auto eleTy = fir::unwrapRefType (load.getType ());
3414
+ auto unknown = fir::SequenceType::getUnknownExtent ();
3415
+ fir::SequenceType::Shape extents (genIterationShape ().size (), unknown);
3416
+ auto seqTy = fir::SequenceType::get (extents, eleTy);
3417
+ auto toTy = fir::HeapType::get (seqTy);
3418
+ auto tempRes = builder.createConvert (loc, toTy, load);
3419
+ builder.create <fir::ArrayMergeStoreOp>(
3420
+ loc, destination, fir::getBase (loopRes), tempRes, destination.slice (),
3421
+ destination.typeparams ());
3422
+ auto tempTy = fir::dyn_cast_ptrEleTy (tempRes.getType ());
3423
+ assert (tempTy && tempTy.isa <fir::SequenceType>() &&
3424
+ " must be a reference to an array" );
3425
+ auto ety = fir::unwrapSequenceType (tempTy);
3426
+ if (auto charTy = ety.dyn_cast <fir::CharacterType>()) {
3427
+ if (fir::characterWithDynamicLen (charTy))
3428
+ TODO (loc, " CHARACTER does not have constant LEN" );
3429
+ auto len = builder.createIntegerConstant (
3430
+ loc, builder.getCharacterLengthType (), charTy.getLen ());
3431
+ return fir::CharArrayBoxValue (tempRes, len, destination.getExtents ());
3432
+ }
3433
+ return fir::ArrayBoxValue (tempRes, destination.getExtents ());
3434
+ }
3435
+
3343
3436
void determineShapeOfDest (const fir::ExtendedValue &lhs) {
3344
3437
destShape = fir::factory::getExtents (builder, getLoc (), lhs);
3345
3438
}
@@ -3416,9 +3509,9 @@ class ArrayExprLowering {
3416
3509
auto innerArg = iterSpace.innerArgument ();
3417
3510
auto exv = f (iterSpace);
3418
3511
mlir::Value upd;
3419
- if (ccDest .hasValue ()) {
3512
+ if (ccStoreToDest .hasValue ()) {
3420
3513
iterSpace.setElement (std::move (exv));
3421
- upd = fir::getBase (ccDest .getValue ()(iterSpace));
3514
+ upd = fir::getBase (ccStoreToDest .getValue ()(iterSpace));
3422
3515
} else {
3423
3516
auto resTy = adjustedArrayElementType (innerArg.getType ());
3424
3517
auto element = adjustedArrayElement (loc, builder, fir::getBase (exv),
@@ -3509,6 +3602,14 @@ class ArrayExprLowering {
3509
3602
// Mask expressions are array expressions too.
3510
3603
for (const auto *e : implicitSpace->getExprs ())
3511
3604
if (e && !implicitSpace->isLowered (e)) {
3605
+ if (auto var = implicitSpace->lookupMaskVariable (e)) {
3606
+ // Allocate the mask buffer lazily.
3607
+ auto tmp = Fortran::lower::createLazyArrayTempValue (
3608
+ converter, *e, var, symMap, stmtCtx);
3609
+ auto shape = builder.createShape (loc, tmp);
3610
+ implicitSpace->bind (e, fir::getBase (tmp), shape);
3611
+ continue ;
3612
+ }
3512
3613
auto optShape =
3513
3614
Fortran::evaluate::GetShape (converter.getFoldingContext (), *e);
3514
3615
auto tmp = Fortran::lower::createSomeArrayTempValue (
@@ -3557,6 +3658,10 @@ class ArrayExprLowering {
3557
3658
const auto loopDepth = loopUppers.size ();
3558
3659
llvm::SmallVector<mlir::Value> ivars;
3559
3660
if (loopDepth > 0 ) {
3661
+ // Generate the lazy mask allocation, if one was given.
3662
+ if (ccPrelude.hasValue ())
3663
+ ccPrelude.getValue ()(shape);
3664
+
3560
3665
auto *startBlock = builder.getBlock ();
3561
3666
for (auto i : llvm::enumerate (llvm::reverse (loopUppers))) {
3562
3667
if (i.index () > 0 ) {
@@ -3600,7 +3705,7 @@ class ArrayExprLowering {
3600
3705
// explicit masks, which are interleaved, these mask expression appear in
3601
3706
// the innermost loop.
3602
3707
if (implicitSpaceHasMasks ()) {
3603
- auto prependAsNeeded = [&](auto &&indices) {
3708
+ auto appendAsNeeded = [&](auto &&indices) {
3604
3709
llvm::SmallVector<mlir::Value> result;
3605
3710
result.append (indices.begin (), indices.end ());
3606
3711
return result;
@@ -3614,7 +3719,7 @@ class ArrayExprLowering {
3614
3719
auto eleRefTy = builder.getRefType (eleTy);
3615
3720
auto i1Ty = builder.getI1Type ();
3616
3721
// Adjust indices for any shift of the origin of the array.
3617
- auto indexes = prependAsNeeded (fir::factory::originateIndices (
3722
+ auto indexes = appendAsNeeded (fir::factory::originateIndices (
3618
3723
loc, builder, tmp.getType (), shape, iters.iterVec ()));
3619
3724
auto addr = builder.create <fir::ArrayCoorOp>(
3620
3725
loc, eleRefTy, tmp, shape, /* slice=*/ mlir::Value{}, indexes,
@@ -3664,6 +3769,8 @@ class ArrayExprLowering {
3664
3769
fir::ArrayLoadOp
3665
3770
createAndLoadSomeArrayTemp (mlir::Type type,
3666
3771
llvm::ArrayRef<mlir::Value> shape) {
3772
+ if (ccLoadDest.hasValue ())
3773
+ return ccLoadDest.getValue ()(shape);
3667
3774
auto seqTy = type.dyn_cast <fir::SequenceType>();
3668
3775
assert (seqTy && " must be an array" );
3669
3776
auto loc = getLoc ();
@@ -4613,7 +4720,7 @@ class ArrayExprLowering {
4613
4720
auto loc = getLoc ();
4614
4721
auto memref = fir::getBase (extMemref);
4615
4722
auto arrTy = fir::dyn_cast_ptrOrBoxEleTy (memref.getType ());
4616
- assert (arrTy.isa <fir::SequenceType>());
4723
+ assert (arrTy.isa <fir::SequenceType>() && " memory ref must be an array " );
4617
4724
auto shape = builder.createShape (loc, extMemref);
4618
4725
mlir::Value slice;
4619
4726
if (inSlice) {
@@ -4898,7 +5005,7 @@ class ArrayExprLowering {
4898
5005
if (isArray (x)) {
4899
5006
auto e = toEvExpr (x);
4900
5007
auto sh = Fortran::evaluate::GetShape (converter.getFoldingContext (), e);
4901
- return {lowerSomeNewArrayExpression (converter, symMap, stmtCtx, sh, e),
5008
+ return {lowerNewArrayExpression (converter, symMap, stmtCtx, sh, e),
4902
5009
/* needCopy=*/ true };
4903
5010
}
4904
5011
return {asScalar (x), /* needCopy=*/ true };
@@ -5429,7 +5536,10 @@ class ArrayExprLowering {
5429
5536
Fortran::lower::StatementContext &stmtCtx;
5430
5537
Fortran::lower::SymMap &symMap;
5431
5538
/// The continuation to generate code to update the destination.
5432
- llvm::Optional<CC> ccDest;
5539
+ llvm::Optional<CC> ccStoreToDest;
5540
+ llvm::Optional<std::function<void (llvm::ArrayRef<mlir::Value>)>> ccPrelude;
5541
+ llvm::Optional<std::function<fir::ArrayLoadOp(llvm::ArrayRef<mlir::Value>)>>
5542
+ ccLoadDest;
5433
5543
// / The destination is the loaded array into which the results will be
5434
5544
// / merged.
5435
5545
fir::ArrayLoadOp destination;
@@ -5539,8 +5649,18 @@ fir::ExtendedValue Fortran::lower::createSomeArrayTempValue(
5539
5649
const Fortran::evaluate::Expr<Fortran::evaluate::SomeType> &expr,
5540
5650
Fortran::lower::SymMap &symMap, Fortran::lower::StatementContext &stmtCtx) {
5541
5651
LLVM_DEBUG (expr.AsFortran (llvm::dbgs () << " array value: " ) << ' \n ' );
5542
- return ArrayExprLowering::lowerSomeNewArrayExpression (converter, symMap,
5543
- stmtCtx, shape, expr);
5652
+ return ArrayExprLowering::lowerNewArrayExpression (converter, symMap, stmtCtx,
5653
+ shape, expr);
5654
+ }
5655
+
5656
+ fir::ExtendedValue Fortran::lower::createLazyArrayTempValue (
5657
+ Fortran::lower::AbstractConverter &converter,
5658
+ const Fortran::evaluate::Expr<Fortran::evaluate::SomeType> &expr,
5659
+ mlir::Value var, Fortran::lower::SymMap &symMap,
5660
+ Fortran::lower::StatementContext &stmtCtx) {
5661
+ LLVM_DEBUG (expr.AsFortran (llvm::dbgs () << " array value: " ) << ' \n ' );
5662
+ return ArrayExprLowering::lowerLazyArrayExpression (converter, symMap, stmtCtx,
5663
+ expr, var);
5544
5664
}
5545
5665
5546
5666
fir::ExtendedValue Fortran::lower::createSomeArrayBox (
@@ -5637,6 +5757,9 @@ void Fortran::lower::createArrayMergeStores(
5637
5757
builder.create <fir::ArrayMergeStoreOp>(
5638
5758
loc, load, i.value (), load.memref (), load.slice (), load.typeparams ());
5639
5759
}
5760
+ // Cleanup any residual mask buffers.
5761
+ esp.outermostContext ().finalize ();
5762
+ esp.outermostContext ().reset ();
5640
5763
}
5641
5764
esp.outerLoopStack .pop_back ();
5642
5765
esp.innerArgsStack .pop_back ();
0 commit comments