Skip to content

Commit 3ef51bf

Browse files
authored
Merge pull request #1049 from schweitzpgi/ch-fw-lazy
Move the scope of the mask buffers to just outside the outermost FORA…
2 parents df66885 + f1cecf5 commit 3ef51bf

File tree

7 files changed

+402
-175
lines changed

7 files changed

+402
-175
lines changed

flang/include/flang/Lower/ConvertExpr.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,15 @@ createSomeArrayTempValue(AbstractConverter &converter,
170170
const evaluate::Expr<evaluate::SomeType> &expr,
171171
SymMap &symMap, StatementContext &stmtCtx);
172172

173+
/// Like createSomeArrayTempValue, but the temporary buffer is allocated lazily
174+
/// (inside the loops instead of before the loops). This can be useful if a
175+
/// loop's bounds are functions of other loop indices, for example.
/// \p var is the memory location that holds the heap pointer to the buffer.
/// The lowering loads the pointer from \p var and only allocates (and stores
/// the new address back to \p var) when the loaded pointer is null.
176+
fir::ExtendedValue
177+
createLazyArrayTempValue(AbstractConverter &converter,
178+
const evaluate::Expr<evaluate::SomeType> &expr,
179+
mlir::Value var, SymMap &symMap,
180+
StatementContext &stmtCtx);
181+
173182
/// Lower an array expression to a value of type box. The expression must be a
174183
/// variable.
175184
fir::ExtendedValue

flang/include/flang/Optimizer/Builder/FIRBuilder.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,9 @@ class FirOpBuilder : public mlir::OpBuilder {
336336

337337
/// Generate code testing \p addr is not a null address.
338338
mlir::Value genIsNotNull(mlir::Location loc, mlir::Value addr);
339+
340+
/// Generate code testing \p addr is a null address.
341+
mlir::Value genIsNull(mlir::Location loc, mlir::Value addr);
339342

340343
private:
341344
const KindMapping &kindMap;

flang/lib/Lower/Bridge.cpp

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,10 @@
3131
#include "flang/Optimizer/Builder/BoxValue.h"
3232
#include "flang/Optimizer/Builder/Character.h"
3333
#include "flang/Optimizer/Builder/FIRBuilder.h"
34+
#include "flang/Optimizer/Builder/Runtime/Character.h"
3435
#include "flang/Optimizer/Dialect/FIRAttr.h"
3536
#include "flang/Optimizer/Dialect/FIRDialect.h"
3637
#include "flang/Optimizer/Dialect/FIROps.h"
37-
#include "flang/Optimizer/Builder/Runtime/Character.h"
3838
#include "flang/Optimizer/Support/FIRContext.h"
3939
#include "flang/Optimizer/Support/FatalError.h"
4040
#include "flang/Optimizer/Support/InternalNames.h"
@@ -2578,8 +2578,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {
25782578
analyzeExplicitSpace(e.operator->());
25792579
}
25802580
void analyzeExplicitSpace(const Fortran::parser::WhereConstructStmt &ws) {
2581-
analyzeExplicitSpace(*Fortran::semantics::GetExpr(
2582-
std::get<Fortran::parser::LogicalExpr>(ws.t)));
2581+
auto *exp = Fortran::semantics::GetExpr(
2582+
std::get<Fortran::parser::LogicalExpr>(ws.t));
2583+
addMaskVariable(exp);
2584+
analyzeExplicitSpace(*exp);
25832585
}
25842586
void analyzeExplicitSpace(
25852587
const Fortran::parser::WhereConstruct::MaskedElsewhere &ew) {
@@ -2602,8 +2604,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {
26022604
body.u);
26032605
}
26042606
void analyzeExplicitSpace(const Fortran::parser::MaskedElsewhereStmt &stmt) {
2605-
analyzeExplicitSpace(*Fortran::semantics::GetExpr(
2606-
std::get<Fortran::parser::LogicalExpr>(stmt.t)));
2607+
auto *exp = Fortran::semantics::GetExpr(
2608+
std::get<Fortran::parser::LogicalExpr>(stmt.t));
2609+
addMaskVariable(exp);
2610+
analyzeExplicitSpace(*exp);
26072611
}
26082612
void
26092613
analyzeExplicitSpace(const Fortran::parser::WhereConstruct::Elsewhere *ew) {
@@ -2612,8 +2616,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {
26122616
analyzeExplicitSpace(e);
26132617
}
26142618
void analyzeExplicitSpace(const Fortran::parser::WhereStmt &stmt) {
2615-
analyzeExplicitSpace(*Fortran::semantics::GetExpr(
2616-
std::get<Fortran::parser::LogicalExpr>(stmt.t)));
2619+
auto *exp = Fortran::semantics::GetExpr(
2620+
std::get<Fortran::parser::LogicalExpr>(stmt.t));
2621+
addMaskVariable(exp);
2622+
analyzeExplicitSpace(*exp);
26172623
const auto &assign =
26182624
std::get<Fortran::parser::AssignmentStmt>(stmt.t).typedAssignment->v;
26192625
assert(assign.has_value() && "WHERE has no statement");
@@ -2659,8 +2665,29 @@ class FirConverter : public Fortran::lower::AbstractConverter {
26592665
}
26602666
analyzeExplicitSpacePop();
26612667
}
2668+
26622669
void analyzeExplicitSpacePop() { explicitIterSpace.popLevel(); }
26632670

2671+
/// Create the variable that tracks the mask buffer for the mask expression
/// \p exp. The variable is a temporary holding a heap pointer; it is
/// initialized to null, registered with the implicit iteration space under
/// \p exp, and a cleanup is attached to the outermost explicit-space context
/// that frees the buffer if the pointer is no longer null.
void addMaskVariable(Fortran::lower::FrontEndExpr exp) {
2672+
// Note: use i8 to store bool values. This avoids round-down behavior found
2673+
// with sequences of i1. That is, an array of i1 will be truncated in size
2674+
// and be too small. For example, a buffer of type fir.array<7xi1> will have
2675+
// 0 size.
2676+
auto ty = fir::HeapType::get(builder->getIntegerType(8));
2677+
auto loc = toLocation();
2678+
auto var = builder->createTemporary(loc, ty);
2679+
// Start out with a null pointer so the cleanup below can tell whether a
// buffer was ever stored into the variable.
auto nil = builder->createNullConstant(loc, ty);
2680+
builder->create<fir::StoreOp>(loc, nil, var);
2681+
implicitIterSpace.addMaskVariable(exp, var);
2682+
// The cleanup captures loc and var by copy; it reloads the pointer and
// guards the free with a not-null test.
explicitIterSpace.outermostContext().attachCleanup([=]() {
2683+
auto load = builder->create<fir::LoadOp>(loc, var);
2684+
auto cmp = builder->genIsNotNull(loc, load);
2685+
builder->genIfThen(loc, cmp)
2686+
.genThen([&]() { builder->create<fir::FreeMemOp>(loc, load); })
2687+
.end();
2688+
});
2689+
}
2690+
26642691
//===--------------------------------------------------------------------===//
26652692

26662693
Fortran::lower::LoweringBridge &bridge;

flang/lib/Lower/ConvertExpr.cpp

Lines changed: 139 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3108,10 +3108,10 @@ class ArrayExprLowering {
31083108
void lowerArrayAssignment(const TL &lhs, const TR &rhs) {
31093109
auto loc = getLoc();
31103110
/// Here the target subspace is not necessarily contiguous. The ArrayUpdate
3111-
/// continuation is implicitly returned in `ccDest` and the ArrayLoad in
3112-
/// `destination`.
3111+
/// continuation is implicitly returned in `ccStoreToDest` and the ArrayLoad
3112+
/// in `destination`.
31133113
PushSemantics(ConstituentSemantics::ProjectedCopyInCopyOut);
3114-
ccDest = genarr(lhs);
3114+
ccStoreToDest = genarr(lhs);
31153115
determineShapeOfDest(lhs);
31163116
semant = ConstituentSemantics::RefTransparent;
31173117
auto exv = lowerArrayExpression(rhs);
@@ -3143,7 +3143,7 @@ class ArrayExprLowering {
31433143
newIters.prependIndexValue(i);
31443144
return newIters;
31453145
};
3146-
ccDest = [=](IterSpace iters) { return lambda(pc(iters)); };
3146+
ccStoreToDest = [=](IterSpace iters) { return lambda(pc(iters)); };
31473147
destShape.assign(extents.begin(), extents.end());
31483148
semant = ConstituentSemantics::RefTransparent;
31493149
auto exv = lowerArrayExpression(rhs);
@@ -3246,7 +3246,8 @@ class ArrayExprLowering {
32463246
destShape, lengthParams);
32473247
// Create ArrayLoad for the mutable box and save it into `destination`.
32483248
PushSemantics(ConstituentSemantics::ProjectedCopyInCopyOut);
3249-
ccDest = genarr(fir::factory::genMutableBoxRead(builder, loc, mutableBox));
3249+
ccStoreToDest =
3250+
genarr(fir::factory::genMutableBoxRead(builder, loc, mutableBox));
32503251
// If the rhs is scalar, get shape from the allocatable ArrayLoad.
32513252
if (destShape.empty())
32523253
destShape = getShape(destination);
@@ -3310,7 +3311,7 @@ class ArrayExprLowering {
33103311

33113312
/// Entry point into lowering an expression with rank. This entry point is for
33123313
/// lowering a rhs expression, for example. (RefTransparent semantics.)
3313-
static ExtValue lowerSomeNewArrayExpression(
3314+
static ExtValue lowerNewArrayExpression(
33143315
Fortran::lower::AbstractConverter &converter,
33153316
Fortran::lower::SymMap &symMap, Fortran::lower::StatementContext &stmtCtx,
33163317
const std::optional<Fortran::evaluate::Shape> &shape,
@@ -3331,7 +3332,7 @@ class ArrayExprLowering {
33313332
fir::dyn_cast_ptrEleTy(tempRes.getType()).cast<fir::SequenceType>();
33323333
if (auto charTy =
33333334
arrTy.getEleTy().template dyn_cast<fir::CharacterType>()) {
3334-
if (charTy.getLen() <= 0)
3335+
if (fir::characterWithDynamicLen(charTy))
33353336
TODO(loc, "CHARACTER does not have constant LEN");
33363337
auto len = builder.createIntegerConstant(
33373338
loc, builder.getCharacterLengthType(), charTy.getLen());
@@ -3340,6 +3341,98 @@ class ArrayExprLowering {
33403341
return fir::ArrayBoxValue(tempRes, dest.getExtents());
33413342
}
33423343

3344+
/// Static entry point for lowering \p expr into a lazily allocated buffer
/// reached through \p var. Constructs an ArrayExprLowering from the given
/// converter, statement context, and symbol map, and forwards to the member
/// function of the same name.
static ExtValue lowerLazyArrayExpression(
3345+
Fortran::lower::AbstractConverter &converter,
3346+
Fortran::lower::SymMap &symMap, Fortran::lower::StatementContext &stmtCtx,
3347+
const Fortran::evaluate::Expr<Fortran::evaluate::SomeType> &expr,
3348+
mlir::Value var) {
3349+
ArrayExprLowering ael{converter, stmtCtx, symMap};
3350+
return ael.lowerLazyArrayExpression(expr, var);
3351+
}
3352+
3353+
/// Lower the array expression \p expr, storing its elements into a heap
/// buffer whose address is kept in the variable \p var. Three continuations
/// are installed before the expression is lowered:
///  - ccPrelude: allocate the buffer if the pointer in \p var is null;
///  - ccLoadDest: produce the array_load of the buffer;
///  - ccStoreToDest: store one element into the buffer.
/// Returns a CharArrayBoxValue (CHARACTER element type) or an ArrayBoxValue
/// over the buffer pointer reloaded from \p var, with the extents of
/// `destination`.
/// NOTE(review): the load / unknown-extent sequence type / convert sequence
/// is repeated four times below; a shared helper may be worth considering.
ExtValue lowerLazyArrayExpression(
3354+
const Fortran::evaluate::Expr<Fortran::evaluate::SomeType> &expr,
3355+
mlir::Value var) {
3356+
auto loc = getLoc();
3357+
// Once the loop extents have been computed, which may require being inside
3358+
// some explicit loops, lazily allocate the expression on the heap.
3359+
ccPrelude = [=](llvm::ArrayRef<mlir::Value> shape) {
3360+
auto load = builder.create<fir::LoadOp>(loc, var);
3361+
auto eleTy = fir::unwrapRefType(load.getType());
3362+
// View the stored pointer as a heap pointer to an array with one
// unknown extent per entry in `shape`.
auto unknown = fir::SequenceType::getUnknownExtent();
3363+
fir::SequenceType::Shape extents(shape.size(), unknown);
3364+
auto seqTy = fir::SequenceType::get(extents, eleTy);
3365+
auto toTy = fir::HeapType::get(seqTy);
3366+
auto castTo = builder.createConvert(loc, toTy, load);
3367+
// Allocate only when the pointer is still null, then store the new
// address back through `var` (converted to the variable's own type).
auto cmp = builder.genIsNull(loc, castTo);
3368+
builder.genIfThen(loc, cmp)
3369+
.genThen([&]() {
3370+
auto mem = builder.create<fir::AllocMemOp>(loc, seqTy, ".lazy.mask",
3371+
llvm::None, shape);
3372+
auto uncast = builder.createConvert(loc, load.getType(), mem);
3373+
builder.create<fir::StoreOp>(loc, uncast, var);
3374+
})
3375+
.end();
3376+
};
3377+
// Create a dummy array_load before the loop. We're storing to a lazy
3378+
// temporary, so there will be no conflict and no copy-in.
3379+
ccLoadDest = [=](llvm::ArrayRef<mlir::Value> shape) -> fir::ArrayLoadOp {
3380+
auto load = builder.create<fir::LoadOp>(loc, var);
3381+
auto eleTy = fir::unwrapRefType(load.getType());
3382+
auto unknown = fir::SequenceType::getUnknownExtent();
3383+
fir::SequenceType::Shape extents(shape.size(), unknown);
3384+
auto seqTy = fir::SequenceType::get(extents, eleTy);
3385+
auto toTy = fir::HeapType::get(seqTy);
3386+
auto castTo = builder.createConvert(loc, toTy, load);
3387+
auto shapeOp = builder.consShape(loc, shape);
3388+
return builder.create<fir::ArrayLoadOp>(
3389+
loc, seqTy, castTo, shapeOp, /*slice=*/mlir::Value{}, llvm::None);
3390+
};
3391+
// Custom lowering of the element store to deal with the extra indirection
3392+
// to the lazy allocated buffer.
3393+
ccStoreToDest = [=](IterSpace iters) {
3394+
// The buffer pointer is reloaded from `var` for every element store.
auto load = builder.create<fir::LoadOp>(loc, var);
3395+
auto eleTy = fir::unwrapRefType(load.getType());
3396+
auto unknown = fir::SequenceType::getUnknownExtent();
3397+
fir::SequenceType::Shape extents(iters.iterVec().size(), unknown);
3398+
auto seqTy = fir::SequenceType::get(extents, eleTy);
3399+
auto toTy = fir::HeapType::get(seqTy);
3400+
auto castTo = builder.createConvert(loc, toTy, load);
3401+
auto shape = builder.consShape(loc, genIterationShape());
3402+
auto indices = fir::factory::originateIndices(
3403+
loc, builder, castTo.getType(), shape, iters.iterVec());
3404+
auto eleAddr = builder.create<fir::ArrayCoorOp>(
3405+
loc, builder.getRefType(eleTy), castTo, shape,
3406+
/*slice=*/mlir::Value{}, indices, destination.typeparams());
3407+
// Convert the computed element to the buffer's element type and store it
// directly to the element address.
auto eleVal = builder.createConvert(loc, eleTy, iters.getElement());
3408+
builder.create<fir::StoreOp>(loc, eleVal, eleAddr);
3409+
// The inner argument is passed through unchanged.
return iters.innerArgument();
3410+
};
3411+
auto loopRes = lowerArrayExpression(expr);
3412+
// Reload the buffer pointer from `var` to build the merge store and the
// returned value.
auto load = builder.create<fir::LoadOp>(loc, var);
3413+
auto eleTy = fir::unwrapRefType(load.getType());
3414+
auto unknown = fir::SequenceType::getUnknownExtent();
3415+
fir::SequenceType::Shape extents(genIterationShape().size(), unknown);
3416+
auto seqTy = fir::SequenceType::get(extents, eleTy);
3417+
auto toTy = fir::HeapType::get(seqTy);
3418+
auto tempRes = builder.createConvert(loc, toTy, load);
3419+
builder.create<fir::ArrayMergeStoreOp>(
3420+
loc, destination, fir::getBase(loopRes), tempRes, destination.slice(),
3421+
destination.typeparams());
3422+
auto tempTy = fir::dyn_cast_ptrEleTy(tempRes.getType());
3423+
assert(tempTy && tempTy.isa<fir::SequenceType>() &&
3424+
"must be a reference to an array");
3425+
auto ety = fir::unwrapSequenceType(tempTy);
3426+
// CHARACTER results also carry a length; only constant LEN is handled.
if (auto charTy = ety.dyn_cast<fir::CharacterType>()) {
3427+
if (fir::characterWithDynamicLen(charTy))
3428+
TODO(loc, "CHARACTER does not have constant LEN");
3429+
auto len = builder.createIntegerConstant(
3430+
loc, builder.getCharacterLengthType(), charTy.getLen());
3431+
return fir::CharArrayBoxValue(tempRes, len, destination.getExtents());
3432+
}
3433+
return fir::ArrayBoxValue(tempRes, destination.getExtents());
3434+
}
3435+
33433436
void determineShapeOfDest(const fir::ExtendedValue &lhs) {
33443437
destShape = fir::factory::getExtents(builder, getLoc(), lhs);
33453438
}
@@ -3416,9 +3509,9 @@ class ArrayExprLowering {
34163509
auto innerArg = iterSpace.innerArgument();
34173510
auto exv = f(iterSpace);
34183511
mlir::Value upd;
3419-
if (ccDest.hasValue()) {
3512+
if (ccStoreToDest.hasValue()) {
34203513
iterSpace.setElement(std::move(exv));
3421-
upd = fir::getBase(ccDest.getValue()(iterSpace));
3514+
upd = fir::getBase(ccStoreToDest.getValue()(iterSpace));
34223515
} else {
34233516
auto resTy = adjustedArrayElementType(innerArg.getType());
34243517
auto element = adjustedArrayElement(loc, builder, fir::getBase(exv),
@@ -3509,6 +3602,14 @@ class ArrayExprLowering {
35093602
// Mask expressions are array expressions too.
35103603
for (const auto *e : implicitSpace->getExprs())
35113604
if (e && !implicitSpace->isLowered(e)) {
3605+
if (auto var = implicitSpace->lookupMaskVariable(e)) {
3606+
// Allocate the mask buffer lazily.
3607+
auto tmp = Fortran::lower::createLazyArrayTempValue(
3608+
converter, *e, var, symMap, stmtCtx);
3609+
auto shape = builder.createShape(loc, tmp);
3610+
implicitSpace->bind(e, fir::getBase(tmp), shape);
3611+
continue;
3612+
}
35123613
auto optShape =
35133614
Fortran::evaluate::GetShape(converter.getFoldingContext(), *e);
35143615
auto tmp = Fortran::lower::createSomeArrayTempValue(
@@ -3557,6 +3658,10 @@ class ArrayExprLowering {
35573658
const auto loopDepth = loopUppers.size();
35583659
llvm::SmallVector<mlir::Value> ivars;
35593660
if (loopDepth > 0) {
3661+
// Generate the lazy mask allocation, if one was given.
3662+
if (ccPrelude.hasValue())
3663+
ccPrelude.getValue()(shape);
3664+
35603665
auto *startBlock = builder.getBlock();
35613666
for (auto i : llvm::enumerate(llvm::reverse(loopUppers))) {
35623667
if (i.index() > 0) {
@@ -3600,7 +3705,7 @@ class ArrayExprLowering {
36003705
// explicit masks, which are interleaved, these mask expressions appear in
36013706
// the innermost loop.
36023707
if (implicitSpaceHasMasks()) {
3603-
auto prependAsNeeded = [&](auto &&indices) {
3708+
auto appendAsNeeded = [&](auto &&indices) {
36043709
llvm::SmallVector<mlir::Value> result;
36053710
result.append(indices.begin(), indices.end());
36063711
return result;
@@ -3614,7 +3719,7 @@ class ArrayExprLowering {
36143719
auto eleRefTy = builder.getRefType(eleTy);
36153720
auto i1Ty = builder.getI1Type();
36163721
// Adjust indices for any shift of the origin of the array.
3617-
auto indexes = prependAsNeeded(fir::factory::originateIndices(
3722+
auto indexes = appendAsNeeded(fir::factory::originateIndices(
36183723
loc, builder, tmp.getType(), shape, iters.iterVec()));
36193724
auto addr = builder.create<fir::ArrayCoorOp>(
36203725
loc, eleRefTy, tmp, shape, /*slice=*/mlir::Value{}, indexes,
@@ -3664,6 +3769,8 @@ class ArrayExprLowering {
36643769
fir::ArrayLoadOp
36653770
createAndLoadSomeArrayTemp(mlir::Type type,
36663771
llvm::ArrayRef<mlir::Value> shape) {
3772+
if (ccLoadDest.hasValue())
3773+
return ccLoadDest.getValue()(shape);
36673774
auto seqTy = type.dyn_cast<fir::SequenceType>();
36683775
assert(seqTy && "must be an array");
36693776
auto loc = getLoc();
@@ -4613,7 +4720,7 @@ class ArrayExprLowering {
46134720
auto loc = getLoc();
46144721
auto memref = fir::getBase(extMemref);
46154722
auto arrTy = fir::dyn_cast_ptrOrBoxEleTy(memref.getType());
4616-
assert(arrTy.isa<fir::SequenceType>());
4723+
assert(arrTy.isa<fir::SequenceType>() && "memory ref must be an array");
46174724
auto shape = builder.createShape(loc, extMemref);
46184725
mlir::Value slice;
46194726
if (inSlice) {
@@ -4898,7 +5005,7 @@ class ArrayExprLowering {
48985005
if (isArray(x)) {
48995006
auto e = toEvExpr(x);
49005007
auto sh = Fortran::evaluate::GetShape(converter.getFoldingContext(), e);
4901-
return {lowerSomeNewArrayExpression(converter, symMap, stmtCtx, sh, e),
5008+
return {lowerNewArrayExpression(converter, symMap, stmtCtx, sh, e),
49025009
/*needCopy=*/true};
49035010
}
49045011
return {asScalar(x), /*needCopy=*/true};
@@ -5429,7 +5536,10 @@ class ArrayExprLowering {
54295536
Fortran::lower::StatementContext &stmtCtx;
54305537
Fortran::lower::SymMap &symMap;
54315538
/// The continuation to generate code to update the destination.
5432-
llvm::Optional<CC> ccDest;
5539+
llvm::Optional<CC> ccStoreToDest;
5540+
llvm::Optional<std::function<void(llvm::ArrayRef<mlir::Value>)>> ccPrelude;
5541+
llvm::Optional<std::function<fir::ArrayLoadOp(llvm::ArrayRef<mlir::Value>)>>
5542+
ccLoadDest;
54335543
/// The destination is the loaded array into which the results will be
54345544
/// merged.
54355545
fir::ArrayLoadOp destination;
@@ -5539,8 +5649,18 @@ fir::ExtendedValue Fortran::lower::createSomeArrayTempValue(
55395649
const Fortran::evaluate::Expr<Fortran::evaluate::SomeType> &expr,
55405650
Fortran::lower::SymMap &symMap, Fortran::lower::StatementContext &stmtCtx) {
55415651
LLVM_DEBUG(expr.AsFortran(llvm::dbgs() << "array value: ") << '\n');
5542-
return ArrayExprLowering::lowerSomeNewArrayExpression(converter, symMap,
5543-
stmtCtx, shape, expr);
5652+
return ArrayExprLowering::lowerNewArrayExpression(converter, symMap, stmtCtx,
5653+
shape, expr);
5654+
}
5655+
5656+
/// Public entry point for lowering \p expr into a lazily allocated temporary
/// reached through \p var. Emits a debug trace of the expression and forwards
/// to ArrayExprLowering::lowerLazyArrayExpression.
fir::ExtendedValue Fortran::lower::createLazyArrayTempValue(
5657+
Fortran::lower::AbstractConverter &converter,
5658+
const Fortran::evaluate::Expr<Fortran::evaluate::SomeType> &expr,
5659+
mlir::Value var, Fortran::lower::SymMap &symMap,
5660+
Fortran::lower::StatementContext &stmtCtx) {
5661+
LLVM_DEBUG(expr.AsFortran(llvm::dbgs() << "array value: ") << '\n');
5662+
return ArrayExprLowering::lowerLazyArrayExpression(converter, symMap, stmtCtx,
5663+
expr, var);
55445664
}
55455665

55465666
fir::ExtendedValue Fortran::lower::createSomeArrayBox(
@@ -5637,6 +5757,9 @@ void Fortran::lower::createArrayMergeStores(
56375757
builder.create<fir::ArrayMergeStoreOp>(
56385758
loc, load, i.value(), load.memref(), load.slice(), load.typeparams());
56395759
}
5760+
// Cleanup any residual mask buffers.
5761+
esp.outermostContext().finalize();
5762+
esp.outermostContext().reset();
56405763
}
56415764
esp.outerLoopStack.pop_back();
56425765
esp.innerArgsStack.pop_back();

0 commit comments

Comments
 (0)