Skip to content

Commit 8130cad

Browse files
committed
Move the scope of the mask buffers to just outside the outermost FORALL context.
Add the ability to allocate array buffers in a lazy manner. This can be helpful if the size of a temporary buffer isn't easily computed before the loop nest begins execution.
1 parent df66885 commit 8130cad

File tree

7 files changed

+406
-175
lines changed

7 files changed

+406
-175
lines changed

flang/include/flang/Lower/ConvertExpr.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,12 @@ createSomeArrayTempValue(AbstractConverter &converter,
170170
const evaluate::Expr<evaluate::SomeType> &expr,
171171
SymMap &symMap, StatementContext &stmtCtx);
172172

173+
fir::ExtendedValue
174+
createLazyArrayTempValue(AbstractConverter &converter,
175+
const evaluate::Expr<evaluate::SomeType> &expr,
176+
mlir::Value var, SymMap &symMap,
177+
StatementContext &stmtCtx);
178+
173179
/// Lower an array expression to a value of type box. The expression must be a
174180
/// variable.
175181
fir::ExtendedValue

flang/include/flang/Optimizer/Builder/FIRBuilder.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,9 @@ class FirOpBuilder : public mlir::OpBuilder {
336336

337337
/// Generate code testing \p addr is not a null address.
338338
mlir::Value genIsNotNull(mlir::Location loc, mlir::Value addr);
339+
340+
/// Generate code testing \p addr is a null address.
341+
mlir::Value genIsNull(mlir::Location loc, mlir::Value addr);
339342

340343
private:
341344
const KindMapping &kindMap;

flang/lib/Lower/Bridge.cpp

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,10 @@
3131
#include "flang/Optimizer/Builder/BoxValue.h"
3232
#include "flang/Optimizer/Builder/Character.h"
3333
#include "flang/Optimizer/Builder/FIRBuilder.h"
34+
#include "flang/Optimizer/Builder/Runtime/Character.h"
3435
#include "flang/Optimizer/Dialect/FIRAttr.h"
3536
#include "flang/Optimizer/Dialect/FIRDialect.h"
3637
#include "flang/Optimizer/Dialect/FIROps.h"
37-
#include "flang/Optimizer/Builder/Runtime/Character.h"
3838
#include "flang/Optimizer/Support/FIRContext.h"
3939
#include "flang/Optimizer/Support/FatalError.h"
4040
#include "flang/Optimizer/Support/InternalNames.h"
@@ -2578,8 +2578,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {
25782578
analyzeExplicitSpace(e.operator->());
25792579
}
25802580
void analyzeExplicitSpace(const Fortran::parser::WhereConstructStmt &ws) {
2581-
analyzeExplicitSpace(*Fortran::semantics::GetExpr(
2582-
std::get<Fortran::parser::LogicalExpr>(ws.t)));
2581+
auto *exp = Fortran::semantics::GetExpr(
2582+
std::get<Fortran::parser::LogicalExpr>(ws.t));
2583+
addMaskVariable(exp);
2584+
analyzeExplicitSpace(*exp);
25832585
}
25842586
void analyzeExplicitSpace(
25852587
const Fortran::parser::WhereConstruct::MaskedElsewhere &ew) {
@@ -2602,8 +2604,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {
26022604
body.u);
26032605
}
26042606
void analyzeExplicitSpace(const Fortran::parser::MaskedElsewhereStmt &stmt) {
2605-
analyzeExplicitSpace(*Fortran::semantics::GetExpr(
2606-
std::get<Fortran::parser::LogicalExpr>(stmt.t)));
2607+
auto *exp = Fortran::semantics::GetExpr(
2608+
std::get<Fortran::parser::LogicalExpr>(stmt.t));
2609+
addMaskVariable(exp);
2610+
analyzeExplicitSpace(*exp);
26072611
}
26082612
void
26092613
analyzeExplicitSpace(const Fortran::parser::WhereConstruct::Elsewhere *ew) {
@@ -2612,8 +2616,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {
26122616
analyzeExplicitSpace(e);
26132617
}
26142618
void analyzeExplicitSpace(const Fortran::parser::WhereStmt &stmt) {
2615-
analyzeExplicitSpace(*Fortran::semantics::GetExpr(
2616-
std::get<Fortran::parser::LogicalExpr>(stmt.t)));
2619+
auto *exp = Fortran::semantics::GetExpr(
2620+
std::get<Fortran::parser::LogicalExpr>(stmt.t));
2621+
addMaskVariable(exp);
2622+
analyzeExplicitSpace(*exp);
26172623
const auto &assign =
26182624
std::get<Fortran::parser::AssignmentStmt>(stmt.t).typedAssignment->v;
26192625
assert(assign.has_value() && "WHERE has no statement");
@@ -2659,8 +2665,32 @@ class FirConverter : public Fortran::lower::AbstractConverter {
26592665
}
26602666
analyzeExplicitSpacePop();
26612667
}
2668+
26622669
void analyzeExplicitSpacePop() { explicitIterSpace.popLevel(); }
26632670

2671+
void addMaskVariable(Fortran::lower::FrontEndExpr exp) {
2672+
// Note: use i8 to store bool values. This avoids round-down behavior found
2673+
// with sequences of i1. That is, an array of i1 will be truncated in size
2674+
// and be too small. For example, a buffer of type fir.array<7xi1> will have
2675+
// 0 size.
2676+
auto ty = fir::HeapType::get(builder->getIntegerType(8));
2677+
auto loc = toLocation();
2678+
auto var = builder->createTemporary(loc, ty);
2679+
auto nil = builder->createNullConstant(loc, ty);
2680+
builder->create<fir::StoreOp>(loc, nil, var);
2681+
implicitIterSpace.addMaskVariable(exp, var);
2682+
explicitIterSpace.outermostContext().attachCleanup([=]() {
2683+
auto load = builder->create<fir::LoadOp>(loc, var);
2684+
auto cmp = builder->genIsNotNull(loc, load);
2685+
auto ifOp =
2686+
builder->create<fir::IfOp>(loc, cmp, /*withElseRegion=*/false);
2687+
auto insPt = builder->saveInsertionPoint();
2688+
builder->setInsertionPointToStart(&ifOp.thenRegion().front());
2689+
builder->create<fir::FreeMemOp>(loc, load);
2690+
builder->restoreInsertionPoint(insPt);
2691+
});
2692+
}
2693+
26642694
//===--------------------------------------------------------------------===//
26652695

26662696
Fortran::lower::LoweringBridge &bridge;

flang/lib/Lower/ConvertExpr.cpp

Lines changed: 143 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3108,10 +3108,10 @@ class ArrayExprLowering {
31083108
void lowerArrayAssignment(const TL &lhs, const TR &rhs) {
31093109
auto loc = getLoc();
31103110
/// Here the target subspace is not necessarily contiguous. The ArrayUpdate
3111-
/// continuation is implicitly returned in `ccDest` and the ArrayLoad in
3112-
/// `destination`.
3111+
/// continuation is implicitly returned in `ccStoreToDest` and the ArrayLoad
3112+
/// in `destination`.
31133113
PushSemantics(ConstituentSemantics::ProjectedCopyInCopyOut);
3114-
ccDest = genarr(lhs);
3114+
ccStoreToDest = genarr(lhs);
31153115
determineShapeOfDest(lhs);
31163116
semant = ConstituentSemantics::RefTransparent;
31173117
auto exv = lowerArrayExpression(rhs);
@@ -3143,7 +3143,7 @@ class ArrayExprLowering {
31433143
newIters.prependIndexValue(i);
31443144
return newIters;
31453145
};
3146-
ccDest = [=](IterSpace iters) { return lambda(pc(iters)); };
3146+
ccStoreToDest = [=](IterSpace iters) { return lambda(pc(iters)); };
31473147
destShape.assign(extents.begin(), extents.end());
31483148
semant = ConstituentSemantics::RefTransparent;
31493149
auto exv = lowerArrayExpression(rhs);
@@ -3246,7 +3246,8 @@ class ArrayExprLowering {
32463246
destShape, lengthParams);
32473247
// Create ArrayLoad for the mutable box and save it into `destination`.
32483248
PushSemantics(ConstituentSemantics::ProjectedCopyInCopyOut);
3249-
ccDest = genarr(fir::factory::genMutableBoxRead(builder, loc, mutableBox));
3249+
ccStoreToDest =
3250+
genarr(fir::factory::genMutableBoxRead(builder, loc, mutableBox));
32503251
// If the rhs is scalar, get shape from the allocatable ArrayLoad.
32513252
if (destShape.empty())
32523253
destShape = getShape(destination);
@@ -3310,7 +3311,7 @@ class ArrayExprLowering {
33103311

33113312
/// Entry point into lowering an expression with rank. This entry point is for
33123313
/// lowering a rhs expression, for example. (RefTransparent semantics.)
3313-
static ExtValue lowerSomeNewArrayExpression(
3314+
static ExtValue lowerNewArrayExpression(
33143315
Fortran::lower::AbstractConverter &converter,
33153316
Fortran::lower::SymMap &symMap, Fortran::lower::StatementContext &stmtCtx,
33163317
const std::optional<Fortran::evaluate::Shape> &shape,
@@ -3331,7 +3332,7 @@ class ArrayExprLowering {
33313332
fir::dyn_cast_ptrEleTy(tempRes.getType()).cast<fir::SequenceType>();
33323333
if (auto charTy =
33333334
arrTy.getEleTy().template dyn_cast<fir::CharacterType>()) {
3334-
if (charTy.getLen() <= 0)
3335+
if (fir::characterWithDynamicLen(charTy))
33353336
TODO(loc, "CHARACTER does not have constant LEN");
33363337
auto len = builder.createIntegerConstant(
33373338
loc, builder.getCharacterLengthType(), charTy.getLen());
@@ -3340,6 +3341,99 @@ class ArrayExprLowering {
33403341
return fir::ArrayBoxValue(tempRes, dest.getExtents());
33413342
}
33423343

3344+
static ExtValue lowerLazyArrayExpression(
3345+
Fortran::lower::AbstractConverter &converter,
3346+
Fortran::lower::SymMap &symMap, Fortran::lower::StatementContext &stmtCtx,
3347+
const Fortran::evaluate::Expr<Fortran::evaluate::SomeType> &expr,
3348+
mlir::Value var) {
3349+
ArrayExprLowering ael{converter, stmtCtx, symMap};
3350+
return ael.lowerLazyArrayExpression(expr, var);
3351+
}
3352+
3353+
ExtValue lowerLazyArrayExpression(
3354+
const Fortran::evaluate::Expr<Fortran::evaluate::SomeType> &expr,
3355+
mlir::Value var) {
3356+
auto loc = getLoc();
3357+
// Once the loop extents have been computed, which may require being inside
3358+
// some explicit loops, lazily allocate the expression on the heap.
3359+
ccPrelude = [=](llvm::ArrayRef<mlir::Value> shape) -> mlir::Value {
3360+
auto load = builder.create<fir::LoadOp>(loc, var);
3361+
auto eleTy = fir::unwrapRefType(load.getType());
3362+
auto unknown = fir::SequenceType::getUnknownExtent();
3363+
fir::SequenceType::Shape extents(shape.size(), unknown);
3364+
auto seqTy = fir::SequenceType::get(extents, eleTy);
3365+
auto toTy = fir::HeapType::get(seqTy);
3366+
auto castTo = builder.createConvert(loc, toTy, load);
3367+
auto cmp = builder.genIsNull(loc, castTo);
3368+
auto ifOp = builder.create<fir::IfOp>(loc, cmp, /*withElseRegion=*/false);
3369+
auto insPt = builder.saveInsertionPoint();
3370+
builder.setInsertionPointToStart(&ifOp.thenRegion().front());
3371+
auto mem = builder.create<fir::AllocMemOp>(loc, seqTy, ".lazy.mask",
3372+
llvm::None, shape);
3373+
auto uncast = builder.createConvert(loc, load.getType(), mem);
3374+
builder.create<fir::StoreOp>(loc, uncast, var);
3375+
builder.restoreInsertionPoint(insPt);
3376+
return mem;
3377+
};
3378+
// Create a dummy array_load before the loop. We're storing to a lazy
3379+
// temporary, so there will be no conflict and no copy-in.
3380+
ccLoadDest = [=](llvm::ArrayRef<mlir::Value> shape) -> fir::ArrayLoadOp {
3381+
auto load = builder.create<fir::LoadOp>(loc, var);
3382+
auto eleTy = fir::unwrapRefType(load.getType());
3383+
auto unknown = fir::SequenceType::getUnknownExtent();
3384+
fir::SequenceType::Shape extents(shape.size(), unknown);
3385+
auto seqTy = fir::SequenceType::get(extents, eleTy);
3386+
auto toTy = fir::HeapType::get(seqTy);
3387+
auto castTo = builder.createConvert(loc, toTy, load);
3388+
auto shapeOp = builder.consShape(loc, shape);
3389+
return builder.create<fir::ArrayLoadOp>(
3390+
loc, seqTy, castTo, shapeOp, /*slice=*/mlir::Value{}, llvm::None);
3391+
};
3392+
// Custom lowering of the element store to deal with the extra indirection
3393+
// to the lazy allocated buffer.
3394+
ccStoreToDest = [=](IterSpace iters) {
3395+
auto load = builder.create<fir::LoadOp>(loc, var);
3396+
auto eleTy = fir::unwrapRefType(load.getType());
3397+
auto unknown = fir::SequenceType::getUnknownExtent();
3398+
fir::SequenceType::Shape extents(iters.iterVec().size(), unknown);
3399+
auto seqTy = fir::SequenceType::get(extents, eleTy);
3400+
auto toTy = fir::HeapType::get(seqTy);
3401+
auto castTo = builder.createConvert(loc, toTy, load);
3402+
auto shape = builder.consShape(loc, genIterationShape());
3403+
auto indices = fir::factory::originateIndices(
3404+
loc, builder, castTo.getType(), shape, iters.iterVec());
3405+
auto eleAddr = builder.create<fir::ArrayCoorOp>(
3406+
loc, builder.getRefType(eleTy), castTo, shape,
3407+
/*slice=*/mlir::Value{}, indices, destination.typeparams());
3408+
auto eleVal = builder.createConvert(loc, eleTy, iters.getElement());
3409+
builder.create<fir::StoreOp>(loc, eleVal, eleAddr);
3410+
return iters.innerArgument();
3411+
};
3412+
auto loopRes = lowerArrayExpression(expr);
3413+
auto load = builder.create<fir::LoadOp>(loc, var);
3414+
auto eleTy = fir::unwrapRefType(load.getType());
3415+
auto unknown = fir::SequenceType::getUnknownExtent();
3416+
fir::SequenceType::Shape extents(genIterationShape().size(), unknown);
3417+
auto seqTy = fir::SequenceType::get(extents, eleTy);
3418+
auto toTy = fir::HeapType::get(seqTy);
3419+
auto tempRes = builder.createConvert(loc, toTy, load);
3420+
builder.create<fir::ArrayMergeStoreOp>(
3421+
loc, destination, fir::getBase(loopRes), tempRes, destination.slice(),
3422+
destination.typeparams());
3423+
auto tempTy = fir::dyn_cast_ptrEleTy(tempRes.getType());
3424+
assert(tempTy && tempTy.isa<fir::SequenceType>() &&
3425+
"must be a reference to an array");
3426+
auto ety = fir::unwrapSequenceType(tempTy);
3427+
if (auto charTy = ety.dyn_cast<fir::CharacterType>()) {
3428+
if (fir::characterWithDynamicLen(charTy))
3429+
TODO(loc, "CHARACTER does not have constant LEN");
3430+
auto len = builder.createIntegerConstant(
3431+
loc, builder.getCharacterLengthType(), charTy.getLen());
3432+
return fir::CharArrayBoxValue(tempRes, len, destination.getExtents());
3433+
}
3434+
return fir::ArrayBoxValue(tempRes, destination.getExtents());
3435+
}
3436+
33433437
void determineShapeOfDest(const fir::ExtendedValue &lhs) {
33443438
destShape = fir::factory::getExtents(builder, getLoc(), lhs);
33453439
}
@@ -3416,9 +3510,9 @@ class ArrayExprLowering {
34163510
auto innerArg = iterSpace.innerArgument();
34173511
auto exv = f(iterSpace);
34183512
mlir::Value upd;
3419-
if (ccDest.hasValue()) {
3513+
if (ccStoreToDest.hasValue()) {
34203514
iterSpace.setElement(std::move(exv));
3421-
upd = fir::getBase(ccDest.getValue()(iterSpace));
3515+
upd = fir::getBase(ccStoreToDest.getValue()(iterSpace));
34223516
} else {
34233517
auto resTy = adjustedArrayElementType(innerArg.getType());
34243518
auto element = adjustedArrayElement(loc, builder, fir::getBase(exv),
@@ -3509,6 +3603,14 @@ class ArrayExprLowering {
35093603
// Mask expressions are array expressions too.
35103604
for (const auto *e : implicitSpace->getExprs())
35113605
if (e && !implicitSpace->isLowered(e)) {
3606+
if (auto var = implicitSpace->lookupVariable(e)) {
3607+
// Allocate the mask buffer lazily.
3608+
auto tmp = Fortran::lower::createLazyArrayTempValue(
3609+
converter, *e, var, symMap, stmtCtx);
3610+
auto shape = builder.createShape(loc, tmp);
3611+
implicitSpace->bind(e, fir::getBase(tmp), shape);
3612+
continue;
3613+
}
35123614
auto optShape =
35133615
Fortran::evaluate::GetShape(converter.getFoldingContext(), *e);
35143616
auto tmp = Fortran::lower::createSomeArrayTempValue(
@@ -3557,6 +3659,12 @@ class ArrayExprLowering {
35573659
const auto loopDepth = loopUppers.size();
35583660
llvm::SmallVector<mlir::Value> ivars;
35593661
if (loopDepth > 0) {
3662+
// Generate the lazy mask allocation, if one was given.
3663+
if (ccPrelude.hasValue()) {
3664+
[[maybe_unused]] auto allocMem = ccPrelude.getValue()(shape);
3665+
assert(allocMem && "mask buffer allocation failure");
3666+
}
3667+
35603668
auto *startBlock = builder.getBlock();
35613669
for (auto i : llvm::enumerate(llvm::reverse(loopUppers))) {
35623670
if (i.index() > 0) {
@@ -3600,7 +3708,7 @@ class ArrayExprLowering {
36003708
// explicit masks, which are interleaved, these mask expression appear in
36013709
// the innermost loop.
36023710
if (implicitSpaceHasMasks()) {
3603-
auto prependAsNeeded = [&](auto &&indices) {
3711+
auto appendAsNeeded = [&](auto &&indices) {
36043712
llvm::SmallVector<mlir::Value> result;
36053713
result.append(indices.begin(), indices.end());
36063714
return result;
@@ -3614,7 +3722,7 @@ class ArrayExprLowering {
36143722
auto eleRefTy = builder.getRefType(eleTy);
36153723
auto i1Ty = builder.getI1Type();
36163724
// Adjust indices for any shift of the origin of the array.
3617-
auto indexes = prependAsNeeded(fir::factory::originateIndices(
3725+
auto indexes = appendAsNeeded(fir::factory::originateIndices(
36183726
loc, builder, tmp.getType(), shape, iters.iterVec()));
36193727
auto addr = builder.create<fir::ArrayCoorOp>(
36203728
loc, eleRefTy, tmp, shape, /*slice=*/mlir::Value{}, indexes,
@@ -3664,6 +3772,8 @@ class ArrayExprLowering {
36643772
fir::ArrayLoadOp
36653773
createAndLoadSomeArrayTemp(mlir::Type type,
36663774
llvm::ArrayRef<mlir::Value> shape) {
3775+
if (ccLoadDest.hasValue())
3776+
return ccLoadDest.getValue()(shape);
36673777
auto seqTy = type.dyn_cast<fir::SequenceType>();
36683778
assert(seqTy && "must be an array");
36693779
auto loc = getLoc();
@@ -4613,7 +4723,7 @@ class ArrayExprLowering {
46134723
auto loc = getLoc();
46144724
auto memref = fir::getBase(extMemref);
46154725
auto arrTy = fir::dyn_cast_ptrOrBoxEleTy(memref.getType());
4616-
assert(arrTy.isa<fir::SequenceType>());
4726+
assert(arrTy.isa<fir::SequenceType>() && "memory ref must be an array");
46174727
auto shape = builder.createShape(loc, extMemref);
46184728
mlir::Value slice;
46194729
if (inSlice) {
@@ -4898,7 +5008,7 @@ class ArrayExprLowering {
48985008
if (isArray(x)) {
48995009
auto e = toEvExpr(x);
49005010
auto sh = Fortran::evaluate::GetShape(converter.getFoldingContext(), e);
4901-
return {lowerSomeNewArrayExpression(converter, symMap, stmtCtx, sh, e),
5011+
return {lowerNewArrayExpression(converter, symMap, stmtCtx, sh, e),
49025012
/*needCopy=*/true};
49035013
}
49045014
return {asScalar(x), /*needCopy=*/true};
@@ -5429,7 +5539,11 @@ class ArrayExprLowering {
54295539
Fortran::lower::StatementContext &stmtCtx;
54305540
Fortran::lower::SymMap &symMap;
54315541
/// The continuation to generate code to update the destination.
5432-
llvm::Optional<CC> ccDest;
5542+
llvm::Optional<CC> ccStoreToDest;
5543+
llvm::Optional<std::function<mlir::Value(llvm::ArrayRef<mlir::Value>)>>
5544+
ccPrelude;
5545+
llvm::Optional<std::function<fir::ArrayLoadOp(llvm::ArrayRef<mlir::Value>)>>
5546+
ccLoadDest;
54335547
/// The destination is the loaded array into which the results will be
54345548
/// merged.
54355549
fir::ArrayLoadOp destination;
@@ -5539,8 +5653,18 @@ fir::ExtendedValue Fortran::lower::createSomeArrayTempValue(
55395653
const Fortran::evaluate::Expr<Fortran::evaluate::SomeType> &expr,
55405654
Fortran::lower::SymMap &symMap, Fortran::lower::StatementContext &stmtCtx) {
55415655
LLVM_DEBUG(expr.AsFortran(llvm::dbgs() << "array value: ") << '\n');
5542-
return ArrayExprLowering::lowerSomeNewArrayExpression(converter, symMap,
5543-
stmtCtx, shape, expr);
5656+
return ArrayExprLowering::lowerNewArrayExpression(converter, symMap, stmtCtx,
5657+
shape, expr);
5658+
}
5659+
5660+
fir::ExtendedValue Fortran::lower::createLazyArrayTempValue(
5661+
Fortran::lower::AbstractConverter &converter,
5662+
const Fortran::evaluate::Expr<Fortran::evaluate::SomeType> &expr,
5663+
mlir::Value var, Fortran::lower::SymMap &symMap,
5664+
Fortran::lower::StatementContext &stmtCtx) {
5665+
LLVM_DEBUG(expr.AsFortran(llvm::dbgs() << "array value: ") << '\n');
5666+
return ArrayExprLowering::lowerLazyArrayExpression(converter, symMap, stmtCtx,
5667+
expr, var);
55445668
}
55455669

55465670
fir::ExtendedValue Fortran::lower::createSomeArrayBox(
@@ -5637,6 +5761,9 @@ void Fortran::lower::createArrayMergeStores(
56375761
builder.create<fir::ArrayMergeStoreOp>(
56385762
loc, load, i.value(), load.memref(), load.slice(), load.typeparams());
56395763
}
5764+
// Cleanup any residual mask buffers.
5765+
esp.outermostContext().finalize();
5766+
esp.outermostContext().reset();
56405767
}
56415768
esp.outerLoopStack.pop_back();
56425769
esp.innerArgsStack.pop_back();

0 commit comments

Comments
 (0)