@@ -3108,10 +3108,10 @@ class ArrayExprLowering {
3108
3108
void lowerArrayAssignment (const TL &lhs, const TR &rhs) {
3109
3109
auto loc = getLoc ();
3110
3110
// / Here the target subspace is not necessarily contiguous. The ArrayUpdate
3111
- // / continuation is implicitly returned in `ccDest ` and the ArrayLoad in
3112
- // / `destination`.
3111
+ // / continuation is implicitly returned in `ccStoreToDest ` and the ArrayLoad
3112
+ // / in `destination`.
3113
3113
PushSemantics (ConstituentSemantics::ProjectedCopyInCopyOut);
3114
- ccDest = genarr (lhs);
3114
+ ccStoreToDest = genarr (lhs);
3115
3115
determineShapeOfDest (lhs);
3116
3116
semant = ConstituentSemantics::RefTransparent;
3117
3117
auto exv = lowerArrayExpression (rhs);
@@ -3143,7 +3143,7 @@ class ArrayExprLowering {
3143
3143
newIters.prependIndexValue (i);
3144
3144
return newIters;
3145
3145
};
3146
- ccDest = [=](IterSpace iters) { return lambda (pc (iters)); };
3146
+ ccStoreToDest = [=](IterSpace iters) { return lambda (pc (iters)); };
3147
3147
destShape.assign (extents.begin (), extents.end ());
3148
3148
semant = ConstituentSemantics::RefTransparent;
3149
3149
auto exv = lowerArrayExpression (rhs);
@@ -3246,7 +3246,8 @@ class ArrayExprLowering {
3246
3246
destShape, lengthParams);
3247
3247
// Create ArrayLoad for the mutable box and save it into `destination`.
3248
3248
PushSemantics (ConstituentSemantics::ProjectedCopyInCopyOut);
3249
- ccDest = genarr (fir::factory::genMutableBoxRead (builder, loc, mutableBox));
3249
+ ccStoreToDest =
3250
+ genarr (fir::factory::genMutableBoxRead (builder, loc, mutableBox));
3250
3251
// If the rhs is scalar, get shape from the allocatable ArrayLoad.
3251
3252
if (destShape.empty ())
3252
3253
destShape = getShape (destination);
@@ -3310,7 +3311,7 @@ class ArrayExprLowering {
3310
3311
3311
3312
// / Entry point into lowering an expression with rank. This entry point is for
3312
3313
// / lowering a rhs expression, for example. (RefTransparent semantics.)
3313
- static ExtValue lowerSomeNewArrayExpression (
3314
+ static ExtValue lowerNewArrayExpression (
3314
3315
Fortran::lower::AbstractConverter &converter,
3315
3316
Fortran::lower::SymMap &symMap, Fortran::lower::StatementContext &stmtCtx,
3316
3317
const std::optional<Fortran::evaluate::Shape> &shape,
@@ -3331,7 +3332,7 @@ class ArrayExprLowering {
3331
3332
fir::dyn_cast_ptrEleTy (tempRes.getType ()).cast <fir::SequenceType>();
3332
3333
if (auto charTy =
3333
3334
arrTy.getEleTy ().template dyn_cast <fir::CharacterType>()) {
3334
- if (charTy. getLen () <= 0 )
3335
+ if (fir::characterWithDynamicLen (charTy) )
3335
3336
TODO (loc, " CHARACTER does not have constant LEN" );
3336
3337
auto len = builder.createIntegerConstant (
3337
3338
loc, builder.getCharacterLengthType (), charTy.getLen ());
@@ -3340,6 +3341,99 @@ class ArrayExprLowering {
3340
3341
return fir::ArrayBoxValue (tempRes, dest.getExtents ());
3341
3342
}
3342
3343
3344
+ static ExtValue lowerLazyArrayExpression (
3345
+ Fortran::lower::AbstractConverter &converter,
3346
+ Fortran::lower::SymMap &symMap, Fortran::lower::StatementContext &stmtCtx,
3347
+ const Fortran::evaluate::Expr<Fortran::evaluate::SomeType> &expr,
3348
+ mlir::Value var) {
3349
+ ArrayExprLowering ael{converter, stmtCtx, symMap};
3350
+ return ael.lowerLazyArrayExpression (expr, var);
3351
+ }
3352
+
3353
+ ExtValue lowerLazyArrayExpression (
3354
+ const Fortran::evaluate::Expr<Fortran::evaluate::SomeType> &expr,
3355
+ mlir::Value var) {
3356
+ auto loc = getLoc ();
3357
+ // Once the loop extents have been computed, which may require being inside
3358
+ // some explicit loops, lazily allocate the expression on the heap.
3359
+ ccPrelude = [=](llvm::ArrayRef<mlir::Value> shape) -> mlir::Value {
3360
+ auto load = builder.create <fir::LoadOp>(loc, var);
3361
+ auto eleTy = fir::unwrapRefType (load.getType ());
3362
+ auto unknown = fir::SequenceType::getUnknownExtent ();
3363
+ fir::SequenceType::Shape extents (shape.size (), unknown);
3364
+ auto seqTy = fir::SequenceType::get (extents, eleTy);
3365
+ auto toTy = fir::HeapType::get (seqTy);
3366
+ auto castTo = builder.createConvert (loc, toTy, load);
3367
+ auto cmp = builder.genIsNull (loc, castTo);
3368
+ auto ifOp = builder.create <fir::IfOp>(loc, cmp, /* withElseRegion=*/ false );
3369
+ auto insPt = builder.saveInsertionPoint ();
3370
+ builder.setInsertionPointToStart (&ifOp.thenRegion ().front ());
3371
+ auto mem = builder.create <fir::AllocMemOp>(loc, seqTy, " .lazy.mask" ,
3372
+ llvm::None, shape);
3373
+ auto uncast = builder.createConvert (loc, load.getType (), mem);
3374
+ builder.create <fir::StoreOp>(loc, uncast, var);
3375
+ builder.restoreInsertionPoint (insPt);
3376
+ return mem;
3377
+ };
3378
+ // Create a dummy array_load before the loop. We're storing to a lazy
3379
+ // temporary, so there will be no conflict and no copy-in.
3380
+ ccLoadDest = [=](llvm::ArrayRef<mlir::Value> shape) -> fir::ArrayLoadOp {
3381
+ auto load = builder.create <fir::LoadOp>(loc, var);
3382
+ auto eleTy = fir::unwrapRefType (load.getType ());
3383
+ auto unknown = fir::SequenceType::getUnknownExtent ();
3384
+ fir::SequenceType::Shape extents (shape.size (), unknown);
3385
+ auto seqTy = fir::SequenceType::get (extents, eleTy);
3386
+ auto toTy = fir::HeapType::get (seqTy);
3387
+ auto castTo = builder.createConvert (loc, toTy, load);
3388
+ auto shapeOp = builder.consShape (loc, shape);
3389
+ return builder.create <fir::ArrayLoadOp>(
3390
+ loc, seqTy, castTo, shapeOp, /* slice=*/ mlir::Value{}, llvm::None);
3391
+ };
3392
+ // Custom lowering of the element store to deal with the extra indirection
3393
+ // to the lazy allocated buffer.
3394
+ ccStoreToDest = [=](IterSpace iters) {
3395
+ auto load = builder.create <fir::LoadOp>(loc, var);
3396
+ auto eleTy = fir::unwrapRefType (load.getType ());
3397
+ auto unknown = fir::SequenceType::getUnknownExtent ();
3398
+ fir::SequenceType::Shape extents (iters.iterVec ().size (), unknown);
3399
+ auto seqTy = fir::SequenceType::get (extents, eleTy);
3400
+ auto toTy = fir::HeapType::get (seqTy);
3401
+ auto castTo = builder.createConvert (loc, toTy, load);
3402
+ auto shape = builder.consShape (loc, genIterationShape ());
3403
+ auto indices = fir::factory::originateIndices (
3404
+ loc, builder, castTo.getType (), shape, iters.iterVec ());
3405
+ auto eleAddr = builder.create <fir::ArrayCoorOp>(
3406
+ loc, builder.getRefType (eleTy), castTo, shape,
3407
+ /* slice=*/ mlir::Value{}, indices, destination.typeparams ());
3408
+ auto eleVal = builder.createConvert (loc, eleTy, iters.getElement ());
3409
+ builder.create <fir::StoreOp>(loc, eleVal, eleAddr);
3410
+ return iters.innerArgument ();
3411
+ };
3412
+ auto loopRes = lowerArrayExpression (expr);
3413
+ auto load = builder.create <fir::LoadOp>(loc, var);
3414
+ auto eleTy = fir::unwrapRefType (load.getType ());
3415
+ auto unknown = fir::SequenceType::getUnknownExtent ();
3416
+ fir::SequenceType::Shape extents (genIterationShape ().size (), unknown);
3417
+ auto seqTy = fir::SequenceType::get (extents, eleTy);
3418
+ auto toTy = fir::HeapType::get (seqTy);
3419
+ auto tempRes = builder.createConvert (loc, toTy, load);
3420
+ builder.create <fir::ArrayMergeStoreOp>(
3421
+ loc, destination, fir::getBase (loopRes), tempRes, destination.slice (),
3422
+ destination.typeparams ());
3423
+ auto tempTy = fir::dyn_cast_ptrEleTy (tempRes.getType ());
3424
+ assert (tempTy && tempTy.isa <fir::SequenceType>() &&
3425
+ " must be a reference to an array" );
3426
+ auto ety = fir::unwrapSequenceType (tempTy);
3427
+ if (auto charTy = ety.dyn_cast <fir::CharacterType>()) {
3428
+ if (fir::characterWithDynamicLen (charTy))
3429
+ TODO (loc, " CHARACTER does not have constant LEN" );
3430
+ auto len = builder.createIntegerConstant (
3431
+ loc, builder.getCharacterLengthType (), charTy.getLen ());
3432
+ return fir::CharArrayBoxValue (tempRes, len, destination.getExtents ());
3433
+ }
3434
+ return fir::ArrayBoxValue (tempRes, destination.getExtents ());
3435
+ }
3436
+
3343
3437
void determineShapeOfDest (const fir::ExtendedValue &lhs) {
3344
3438
destShape = fir::factory::getExtents (builder, getLoc (), lhs);
3345
3439
}
@@ -3416,9 +3510,9 @@ class ArrayExprLowering {
3416
3510
auto innerArg = iterSpace.innerArgument ();
3417
3511
auto exv = f (iterSpace);
3418
3512
mlir::Value upd;
3419
- if (ccDest .hasValue ()) {
3513
+ if (ccStoreToDest .hasValue ()) {
3420
3514
iterSpace.setElement (std::move (exv));
3421
- upd = fir::getBase (ccDest .getValue ()(iterSpace));
3515
+ upd = fir::getBase (ccStoreToDest .getValue ()(iterSpace));
3422
3516
} else {
3423
3517
auto resTy = adjustedArrayElementType (innerArg.getType ());
3424
3518
auto element = adjustedArrayElement (loc, builder, fir::getBase (exv),
@@ -3509,6 +3603,14 @@ class ArrayExprLowering {
3509
3603
// Mask expressions are array expressions too.
3510
3604
for (const auto *e : implicitSpace->getExprs ())
3511
3605
if (e && !implicitSpace->isLowered (e)) {
3606
+ if (auto var = implicitSpace->lookupVariable (e)) {
3607
+ // Allocate the mask buffer lazily.
3608
+ auto tmp = Fortran::lower::createLazyArrayTempValue (
3609
+ converter, *e, var, symMap, stmtCtx);
3610
+ auto shape = builder.createShape (loc, tmp);
3611
+ implicitSpace->bind (e, fir::getBase (tmp), shape);
3612
+ continue ;
3613
+ }
3512
3614
auto optShape =
3513
3615
Fortran::evaluate::GetShape (converter.getFoldingContext (), *e);
3514
3616
auto tmp = Fortran::lower::createSomeArrayTempValue (
@@ -3557,6 +3659,12 @@ class ArrayExprLowering {
3557
3659
const auto loopDepth = loopUppers.size ();
3558
3660
llvm::SmallVector<mlir::Value> ivars;
3559
3661
if (loopDepth > 0 ) {
3662
+ // Generate the lazy mask allocation, if one was given.
3663
+ if (ccPrelude.hasValue ()) {
3664
+ [[maybe_unused]] auto allocMem = ccPrelude.getValue ()(shape);
3665
+ assert (allocMem && " mask buffer allocation failure" );
3666
+ }
3667
+
3560
3668
auto *startBlock = builder.getBlock ();
3561
3669
for (auto i : llvm::enumerate (llvm::reverse (loopUppers))) {
3562
3670
if (i.index () > 0 ) {
@@ -3600,7 +3708,7 @@ class ArrayExprLowering {
3600
3708
// explicit masks, which are interleaved, these mask expression appear in
3601
3709
// the innermost loop.
3602
3710
if (implicitSpaceHasMasks ()) {
3603
- auto prependAsNeeded = [&](auto &&indices) {
3711
+ auto appendAsNeeded = [&](auto &&indices) {
3604
3712
llvm::SmallVector<mlir::Value> result;
3605
3713
result.append (indices.begin (), indices.end ());
3606
3714
return result;
@@ -3614,7 +3722,7 @@ class ArrayExprLowering {
3614
3722
auto eleRefTy = builder.getRefType (eleTy);
3615
3723
auto i1Ty = builder.getI1Type ();
3616
3724
// Adjust indices for any shift of the origin of the array.
3617
- auto indexes = prependAsNeeded (fir::factory::originateIndices (
3725
+ auto indexes = appendAsNeeded (fir::factory::originateIndices (
3618
3726
loc, builder, tmp.getType (), shape, iters.iterVec ()));
3619
3727
auto addr = builder.create <fir::ArrayCoorOp>(
3620
3728
loc, eleRefTy, tmp, shape, /* slice=*/ mlir::Value{}, indexes,
@@ -3664,6 +3772,8 @@ class ArrayExprLowering {
3664
3772
fir::ArrayLoadOp
3665
3773
createAndLoadSomeArrayTemp (mlir::Type type,
3666
3774
llvm::ArrayRef<mlir::Value> shape) {
3775
+ if (ccLoadDest.hasValue ())
3776
+ return ccLoadDest.getValue ()(shape);
3667
3777
auto seqTy = type.dyn_cast <fir::SequenceType>();
3668
3778
assert (seqTy && " must be an array" );
3669
3779
auto loc = getLoc ();
@@ -4613,7 +4723,7 @@ class ArrayExprLowering {
4613
4723
auto loc = getLoc ();
4614
4724
auto memref = fir::getBase (extMemref);
4615
4725
auto arrTy = fir::dyn_cast_ptrOrBoxEleTy (memref.getType ());
4616
- assert (arrTy.isa <fir::SequenceType>());
4726
+ assert (arrTy.isa <fir::SequenceType>() && " memory ref must be an array " );
4617
4727
auto shape = builder.createShape (loc, extMemref);
4618
4728
mlir::Value slice;
4619
4729
if (inSlice) {
@@ -4898,7 +5008,7 @@ class ArrayExprLowering {
4898
5008
if (isArray (x)) {
4899
5009
auto e = toEvExpr (x);
4900
5010
auto sh = Fortran::evaluate::GetShape (converter.getFoldingContext (), e);
4901
- return {lowerSomeNewArrayExpression (converter, symMap, stmtCtx, sh, e),
5011
+ return {lowerNewArrayExpression (converter, symMap, stmtCtx, sh, e),
4902
5012
/* needCopy=*/ true };
4903
5013
}
4904
5014
return {asScalar (x), /* needCopy=*/ true };
@@ -5429,7 +5539,11 @@ class ArrayExprLowering {
5429
5539
Fortran::lower::StatementContext &stmtCtx;
5430
5540
Fortran::lower::SymMap &symMap;
5431
5541
// / The continuation to generate code to update the destination.
5432
- llvm::Optional<CC> ccDest;
5542
+ llvm::Optional<CC> ccStoreToDest;
5543
+ llvm::Optional<std::function<mlir::Value(llvm::ArrayRef<mlir::Value>)>>
5544
+ ccPrelude;
5545
+ llvm::Optional<std::function<fir::ArrayLoadOp(llvm::ArrayRef<mlir::Value>)>>
5546
+ ccLoadDest;
5433
5547
// / The destination is the loaded array into which the results will be
5434
5548
// / merged.
5435
5549
fir::ArrayLoadOp destination;
@@ -5539,8 +5653,18 @@ fir::ExtendedValue Fortran::lower::createSomeArrayTempValue(
5539
5653
const Fortran::evaluate::Expr<Fortran::evaluate::SomeType> &expr,
5540
5654
Fortran::lower::SymMap &symMap, Fortran::lower::StatementContext &stmtCtx) {
5541
5655
LLVM_DEBUG (expr.AsFortran (llvm::dbgs () << " array value: " ) << ' \n ' );
5542
- return ArrayExprLowering::lowerSomeNewArrayExpression (converter, symMap,
5543
- stmtCtx, shape, expr);
5656
+ return ArrayExprLowering::lowerNewArrayExpression (converter, symMap, stmtCtx,
5657
+ shape, expr);
5658
+ }
5659
+
5660
+ fir::ExtendedValue Fortran::lower::createLazyArrayTempValue (
5661
+ Fortran::lower::AbstractConverter &converter,
5662
+ const Fortran::evaluate::Expr<Fortran::evaluate::SomeType> &expr,
5663
+ mlir::Value var, Fortran::lower::SymMap &symMap,
5664
+ Fortran::lower::StatementContext &stmtCtx) {
5665
+ LLVM_DEBUG (expr.AsFortran (llvm::dbgs () << " array value: " ) << ' \n ' );
5666
+ return ArrayExprLowering::lowerLazyArrayExpression (converter, symMap, stmtCtx,
5667
+ expr, var);
5544
5668
}
5545
5669
5546
5670
fir::ExtendedValue Fortran::lower::createSomeArrayBox (
@@ -5637,6 +5761,9 @@ void Fortran::lower::createArrayMergeStores(
5637
5761
builder.create <fir::ArrayMergeStoreOp>(
5638
5762
loc, load, i.value (), load.memref (), load.slice (), load.typeparams ());
5639
5763
}
5764
+ // Cleanup any residual mask buffers.
5765
+ esp.outermostContext ().finalize ();
5766
+ esp.outermostContext ().reset ();
5640
5767
}
5641
5768
esp.outerLoopStack .pop_back ();
5642
5769
esp.innerArgsStack .pop_back ();
0 commit comments