Skip to content

Commit 6114f8f

Browse files
authored
Merge pull request #996 from schweitzpgi/ch-forall4
Fix assertion "key not already in map".
2 parents ba7b1b1 + 6b03818 commit 6114f8f

File tree

4 files changed

+334
-94
lines changed

4 files changed

+334
-94
lines changed

flang/lib/Lower/Bridge.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1912,6 +1912,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
19121912
},
19131913
[&](const Fortran::evaluate::ProcedureRef &procRef) {
19141914
// User defined assignment: call the procedure.
1915+
if (explicitIterationSpace())
1916+
TODO(loc, "user defined assignment within FORALL");
19151917
Fortran::semantics::SomeExpr expr{procRef};
19161918
createFIRExpr(toLocation(), &expr, stmtCtx);
19171919
},

flang/lib/Lower/ConvertExpr.cpp

Lines changed: 190 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -2412,10 +2412,17 @@ class ArrayExprLowering {
24122412
};
24132413

24142414
using ExtValue = fir::ExtendedValue;
2415-
using IterSpace = const IterationSpace &; // active iteration space
2416-
using CC = std::function<ExtValue(IterSpace)>; // current continuation
2417-
using PC =
2418-
std::function<IterationSpace(IterSpace)>; // projection continuation
2415+
/// Active iteration space.
2416+
using IterSpace = const IterationSpace &;
2417+
/// Current continuation. Function that will generate IR for a single
2418+
/// iteration of the pending iterative loop structure.
2419+
using CC = std::function<ExtValue(IterSpace)>;
2420+
/// Projection continuation. Function that will project one iteration space
2421+
/// into another.
2422+
using PC = std::function<IterationSpace(IterSpace)>;
2423+
/// Loop bounds continuation. Function that will generate IR to compute loop
2424+
/// bounds in a future context.
2425+
using LBC = std::function<llvm::SmallVector<mlir::Value>()>;
24192426
using ArrayBaseTy =
24202427
std::variant<std::monostate, const Fortran::evaluate::ArrayRef *,
24212428
const Fortran::evaluate::DataRef *>;
@@ -2686,13 +2693,15 @@ class ArrayExprLowering {
26862693
}
26872694
}
26882695

2696+
/// Returns true iff the Ev::Shape is constant.
26892697
static bool evalShapeIsConstant(const Fortran::evaluate::Shape &shape) {
26902698
for (const auto &s : shape)
26912699
if (!s || !Fortran::evaluate::IsConstantExpr(*s))
26922700
return false;
26932701
return true;
26942702
}
26952703

2704+
/// Convert an Ev::Shape to IR values.
26962705
void convertFEShape(const Fortran::evaluate::Shape &shape,
26972706
llvm::SmallVectorImpl<mlir::Value> &result) {
26982707
if (evalShapeIsConstant(shape)) {
@@ -2831,7 +2840,7 @@ class ArrayExprLowering {
28312840
/// this returns any implicit shape component, if it exists.
28322841
llvm::SmallVector<mlir::Value> genIterationShape() {
28332842
if (explicitSpace)
2834-
return explicitImpliedShape;
2843+
return {};
28352844
// Use the precomputed destination shape.
28362845
if (!destShape.empty())
28372846
return destShape;
@@ -3199,6 +3208,51 @@ class ArrayExprLowering {
31993208
return {indices, loops[0]};
32003209
}
32013210

3211+
llvm::SmallVector<mlir::Value> genImplicitLoopBounds(
3212+
const Fortran::evaluate::Expr<Fortran::evaluate::SomeType> &e) {
3213+
struct Filter : public Fortran::evaluate::AnyTraverse<
3214+
Filter, std::optional<llvm::SmallVector<mlir::Value>>> {
3215+
using Base = Fortran::evaluate::AnyTraverse<
3216+
Filter, std::optional<llvm::SmallVector<mlir::Value>>>;
3217+
using Base::operator();
3218+
3219+
Filter(const llvm::SmallVector<mlir::Value> init)
3220+
: Base(*this), bounds(init) {}
3221+
3222+
std::optional<llvm::SmallVector<mlir::Value>>
3223+
operator()(const Fortran::evaluate::ArrayRef &ref) {
3224+
if ((ref.base().IsSymbol() || ref.base().Rank() == 0) &&
3225+
ref.Rank() > 0 && !ref.subscript().empty()) {
3226+
assert(ref.subscript().size() == bounds.size());
3227+
llvm::SmallVector<mlir::Value> result;
3228+
auto bdIter = bounds.begin();
3229+
for (auto ss : ref.subscript()) {
3230+
mlir::Value bound = *bdIter++;
3231+
std::visit(Fortran::common::visitors{
3232+
[&](const Fortran::evaluate::Triplet &triple) {
3233+
result.push_back(bound);
3234+
},
3235+
[&](const auto &intExpr) {
3236+
if (intExpr.value().Rank() > 0)
3237+
result.push_back(bound);
3238+
}},
3239+
ss.u);
3240+
}
3241+
return {result};
3242+
}
3243+
return {};
3244+
}
3245+
3246+
llvm::SmallVector<mlir::Value> bounds;
3247+
};
3248+
3249+
auto originalShape = getShape(converter.genType(e));
3250+
Filter filter(originalShape);
3251+
if (auto res = filter(e))
3252+
return *res;
3253+
return originalShape;
3254+
}
3255+
32023256
void genMasks() {
32033257
auto loc = getLoc();
32043258
// Lower explicit mask expressions, if any.
@@ -3215,8 +3269,8 @@ class ArrayExprLowering {
32153269
for (const auto *e : masks->getExprs())
32163270
if (e && !masks->isLowered(e)) {
32173271
auto extents = genExplicitExtents();
3218-
extents.append(explicitImpliedShape.rbegin(),
3219-
explicitImpliedShape.rend());
3272+
auto loopBounds = genImplicitLoopBounds(*e);
3273+
extents.append(loopBounds.rbegin(), loopBounds.rend());
32203274
// Allocate a temporary to cache the mask results.
32213275
auto tmpShape = builder.consShape(loc, extents);
32223276
auto tmp = createAndLoadSomeArrayTemp(
@@ -3230,7 +3284,8 @@ class ArrayExprLowering {
32303284
// Evaluate like any other nested array expression.
32313285
ArrayExprLowering ael{converter, masks->stmtContext(), symMap,
32323286
ConstituentSemantics::ProjectedCopyInCopyOut};
3233-
ael.lowerArrayAssignment(tmp, *e, indices, explicitImpliedShape);
3287+
ael.lowerArrayAssignment(tmp, *e, indices,
3288+
explicitImpliedLoopBounds.getValue()());
32343289
masks->bind(e, tmp.memref(), tmpShape);
32353290
builder.setInsertionPointAfter(loop0);
32363291
builder.create<fir::ArrayMergeStoreOp>(loc, tmp, loop0.getResult(0),
@@ -3289,14 +3344,18 @@ class ArrayExprLowering {
32893344
llvm::SmallVector<fir::DoLoopOp> loops;
32903345
llvm::SmallVector<mlir::Value> explicitOffsets;
32913346
// FORALL loops are outermost.
3292-
if (explicitSpace)
3347+
if (explicitSpace) {
32933348
genExplicitIterSpace(loops, explicitOffsets, innerArg);
3349+
if (explicitImpliedLoopBounds.hasValue())
3350+
loopUppers = explicitImpliedLoopBounds.getValue()();
3351+
}
32943352

32953353
// Now handle the implicit loops.
32963354
const auto loopFirst = loops.size();
32973355
const auto loopDepth = loopUppers.size();
32983356
llvm::SmallVector<mlir::Value> ivars;
32993357
if (loopDepth > 0) {
3358+
auto *startBlock = builder.getBlock();
33003359
for (auto i : llvm::enumerate(llvm::reverse(loopUppers))) {
33013360
if (i.index() > 0) {
33023361
assert(!loops.empty());
@@ -3311,8 +3370,11 @@ class ArrayExprLowering {
33113370
}
33123371
// Add the fir.result for all loops except the innermost one. We must also
33133372
// terminate the innermost explicit bounds loop here as well.
3314-
for (std::remove_const_t<decltype(loopFirst)> i =
3315-
loopFirst ? loopFirst - 1 : 0;
3373+
if (loopFirst > 0) {
3374+
builder.setInsertionPointToEnd(startBlock);
3375+
builder.create<fir::ResultOp>(loc, loops[loopFirst].getResult(0));
3376+
}
3377+
for (std::remove_const_t<decltype(loopFirst)> i = loopFirst;
33163378
i + 1 < loopFirst + loopDepth; ++i) {
33173379
builder.setInsertionPointToEnd(loops[i].getBody());
33183380
builder.create<fir::ResultOp>(loc, loops[i + 1].getResult(0));
@@ -3365,17 +3427,18 @@ class ArrayExprLowering {
33653427
// structure is produced.
33663428
auto maskExprs = masks->getExprs();
33673429
const auto size = maskExprs.size() - 1;
3368-
for (std::remove_const_t<decltype(size)> i = 0; i < size; ++i) {
3369-
auto ifOp = builder.create<fir::IfOp>(
3370-
loc, mlir::TypeRange{innerArg.getType()},
3371-
fir::getBase(
3372-
genCond(masks->getBindingWithShape(maskExprs[i]), iters)),
3373-
/*withElseRegion=*/true);
3374-
builder.create<fir::ResultOp>(loc, ifOp.getResult(0));
3375-
builder.setInsertionPointToStart(&ifOp.thenRegion().front());
3376-
builder.create<fir::ResultOp>(loc, innerArg);
3377-
builder.setInsertionPointToStart(&ifOp.elseRegion().front());
3378-
}
3430+
for (std::remove_const_t<decltype(size)> i = 0; i < size; ++i)
3431+
if (maskExprs[i]) {
3432+
auto ifOp = builder.create<fir::IfOp>(
3433+
loc, mlir::TypeRange{innerArg.getType()},
3434+
fir::getBase(
3435+
genCond(masks->getBindingWithShape(maskExprs[i]), iters)),
3436+
/*withElseRegion=*/true);
3437+
builder.create<fir::ResultOp>(loc, ifOp.getResult(0));
3438+
builder.setInsertionPointToStart(&ifOp.thenRegion().front());
3439+
builder.create<fir::ResultOp>(loc, innerArg);
3440+
builder.setInsertionPointToStart(&ifOp.elseRegion().front());
3441+
}
33793442

33803443
// The last condition is either non-negated or unconditionally negated.
33813444
if (maskExprs[size]) {
@@ -4285,11 +4348,8 @@ class ArrayExprLowering {
42854348
template <typename A>
42864349
std::pair<CC, mlir::Type> raiseRankedBase(const A &x) {
42874350
auto result = raiseBase(x);
4288-
if (isProjectedCopyInCopyOut()) {
4289-
auto optShape = Fortran::evaluate::GetShape(x);
4290-
assert(optShape.has_value());
4291-
convertFEShape(*optShape, explicitImpliedShape);
4292-
}
4351+
if (isProjectedCopyInCopyOut())
4352+
explicitImpliedLoopBounds = [=]() { return getShape(x); };
42934353
return result;
42944354
}
42954355
template <typename A>
@@ -4320,11 +4380,8 @@ class ArrayExprLowering {
43204380
std::pair<CC, mlir::Type> raiseRankedComponent(llvm::Optional<CC> cc,
43214381
const A &x, mlir::Type inTy) {
43224382
auto result = raiseComponent(cc, x, inTy, false);
4323-
if (isProjectedCopyInCopyOut()) {
4324-
auto optShape = Fortran::evaluate::GetShape(x);
4325-
assert(optShape.has_value());
4326-
convertFEShape(*optShape, explicitImpliedShape);
4327-
}
4383+
if (isProjectedCopyInCopyOut())
4384+
explicitImpliedLoopBounds = [=]() { return getShape(x); };
43284385
return result;
43294386
}
43304387

@@ -4393,8 +4450,7 @@ class ArrayExprLowering {
43934450
auto &sym = base.GetFirstSymbol();
43944451
if (x.Rank() > 0 || accessUsesControlVariable()) {
43954452
auto [fopt2, ty2] = raiseBase(sym);
4396-
return RaiseRT{fopt2, fir::unwrapSequenceType(ty2), false,
4397-
x.Rank() > 0};
4453+
return RaiseRT{fopt2, ty2, false, x.Rank() > 0};
43984454
}
43994455
return RaiseRT{llvm::None, mlir::Type{}, false, false};
44004456
}
@@ -4411,11 +4467,53 @@ class ArrayExprLowering {
44114467
}(),
44124468
x);
44134469
}
4470+
static mlir::Type unwrapBoxEleTy(mlir::Type ty) {
4471+
if (auto boxTy = ty.dyn_cast<fir::BoxType>()) {
4472+
ty = boxTy.getEleTy();
4473+
if (auto refTy = fir::dyn_cast_ptrEleTy(ty))
4474+
ty = refTy;
4475+
}
4476+
return ty;
4477+
}
4478+
llvm::SmallVector<mlir::Value> getShape(mlir::Type ty) {
4479+
llvm::SmallVector<mlir::Value> result;
4480+
ty = unwrapBoxEleTy(ty);
4481+
auto loc = getLoc();
4482+
auto idxTy = builder.getIndexType();
4483+
for (auto extent : ty.cast<fir::SequenceType>().getShape()) {
4484+
auto v = extent == fir::SequenceType::getUnknownExtent()
4485+
? builder.create<fir::UndefOp>(loc, idxTy).getResult()
4486+
: builder.createIntegerConstant(loc, idxTy, extent);
4487+
result.push_back(v);
4488+
}
4489+
return result;
4490+
}
4491+
llvm::SmallVector<mlir::Value>
4492+
getShape(const Fortran::semantics::SymbolRef &x) {
4493+
if (x.get().Rank() == 0)
4494+
return {};
4495+
return getFrontEndShape(x);
4496+
}
4497+
template <typename A>
4498+
llvm::SmallVector<mlir::Value> getShape(const A &x) {
4499+
if (x.Rank() == 0)
4500+
return {};
4501+
return getFrontEndShape(x);
4502+
}
4503+
template <typename A>
4504+
llvm::SmallVector<mlir::Value> getFrontEndShape(const A &x) {
4505+
if (auto optShape = Fortran::evaluate::GetShape(x)) {
4506+
llvm::SmallVector<mlir::Value> result;
4507+
convertFEShape(*optShape, result);
4508+
return result;
4509+
}
4510+
return {};
4511+
}
44144512
RaiseRT raiseSubscript(const RaiseRT &tup,
44154513
const Fortran::evaluate::ArrayRef &x) {
44164514
auto fopt = std::get<llvm::Optional<CC>>(tup);
44174515
if (fopt.hasValue()) {
4418-
auto ty = std::get<mlir::Type>(tup);
4516+
auto arrTy = std::get<mlir::Type>(tup);
44194517
auto prevRanked = std::get<2>(tup);
44204518
auto ranked = std::get<3>(tup);
44214519
auto lambda = fopt.getValue();
@@ -4429,30 +4527,64 @@ class ArrayExprLowering {
44294527
// from the explicit space, then those dimensions should not be
44304528
// considered as contributing to the implied part of the iteration
44314529
// space.
4432-
if (explicitImpliedShape.empty()) {
4433-
assert(destination && "destination must be set");
4434-
auto feShape = getShape(destination);
4530+
if (!explicitImpliedLoopBounds.hasValue()) {
44354531
if (subs.empty()) {
4436-
explicitImpliedShape.assign(feShape);
4532+
explicitImpliedLoopBounds = [=]() { return getShape(x); };
44374533
} else {
4438-
unsigned ii = 0;
4534+
auto desShape = getShape(x);
44394535
unsigned vi = 0;
4440-
vectorCoor.resize(feShape.size());
4536+
vectorCoor.resize(desShape.size());
44414537
// Filter out subscripts that are scalar expressions. If it is a
4442-
// scalar expression it is either loop-invariant or a function of
4443-
// the explicit loop control variables.
4444-
for (const auto &ss : subs)
4538+
// scalar expression it is either loop-invariant or a function
4539+
// of the explicit loop control variables.
4540+
for (const auto &ss : subs) {
44454541
if (auto *intExpr = std::get_if<
4446-
Fortran::evaluate::IndirectSubscriptIntegerExpr>(&ss.u)) {
4447-
if (intExpr->value().Rank() > 0) {
4448-
explicitImpliedShape.push_back(feShape[ii++]);
4542+
Fortran::evaluate::IndirectSubscriptIntegerExpr>(&ss.u))
4543+
if (intExpr->value().Rank() > 0)
44494544
vectorCoor[vi++] = genarr(intExpr->value());
4450-
}
4451-
} else {
4452-
// This is a triple which may be using an explicit control
4453-
// variable.
4454-
explicitImpliedShape.push_back(feShape[ii++]);
4455-
}
4545+
}
4546+
explicitImpliedLoopBounds = [=]() {
4547+
llvm::SmallVector<mlir::Value> result;
4548+
unsigned ii = 0;
4549+
for (const auto &ss : subs)
4550+
std::visit(
4551+
Fortran::common::visitors{
4552+
[&](const Fortran::evaluate::
4553+
IndirectSubscriptIntegerExpr &intExpr) {
4554+
if (intExpr.value().Rank() > 0)
4555+
result.push_back(builder.createConvert(
4556+
loc, idxTy, desShape[ii++]));
4557+
},
4558+
[&](const Fortran::evaluate::Triplet &triple) {
4559+
// This is a triple which may be using an
4560+
// explicit control variable.
4561+
auto ou = triple.upper();
4562+
auto up = builder.createConvert(
4563+
loc, idxTy,
4564+
ou.has_value() ? fir::getBase(asScalar(*ou))
4565+
: desShape[ii]);
4566+
auto ol = triple.lower();
4567+
auto lo =
4568+
ol.has_value()
4569+
? builder.createConvert(
4570+
loc, idxTy, fir::getBase(asScalar(*ol)))
4571+
: builder.createIntegerConstant(loc, idxTy,
4572+
1);
4573+
auto step = builder.createConvert(
4574+
loc, idxTy,
4575+
fir::getBase(asScalar(triple.stride())));
4576+
auto diff = builder.create<mlir::SubIOp>(loc, up, lo);
4577+
auto sum =
4578+
builder.create<mlir::AddIOp>(loc, diff, step);
4579+
mlir::Value count =
4580+
builder.create<mlir::SignedDivIOp>(loc, sum,
4581+
step);
4582+
result.push_back(count);
4583+
ii++;
4584+
}},
4585+
ss.u);
4586+
return result;
4587+
};
44564588
}
44574589
}
44584590
}
@@ -4519,6 +4651,7 @@ class ArrayExprLowering {
45194651
}
45204652
return newIters;
45214653
};
4654+
auto ty = fir::unwrapSequenceType(unwrapBoxEleTy(arrTy));
45224655
return RaiseRT{[=](IterSpace iters) { return lambda(pc(iters)); }, ty,
45234656
prevRanked, ranked};
45244657
}
@@ -4605,7 +4738,9 @@ class ArrayExprLowering {
46054738

46064739
static mlir::Type adjustedArraySubtype(mlir::Type ty,
46074740
mlir::ValueRange indices) {
4608-
return adjustedArrayElementType(fir::applyPathToType(ty, indices));
4741+
auto pathTy = fir::applyPathToType(ty, indices);
4742+
assert(pathTy && "indices failed to apply to type");
4743+
return adjustedArrayElementType(pathTy);
46094744
}
46104745

46114746
/// Build an ExtendedValue from a fir.array<?x...?xT> without actually
@@ -5486,7 +5621,7 @@ class ArrayExprLowering {
54865621
/// Even in an explicitly defined iteration space, one can have an
54875622
/// assignment with rank > 0 and thus an implied shape on a component in the
54885623
/// path.
5489-
llvm::SmallVector<mlir::Value> explicitImpliedShape;
5624+
llvm::Optional<LBC> explicitImpliedLoopBounds;
54905625
Fortran::lower::ImplicitIterSpace *masks = nullptr;
54915626
ConstituentSemantics semant = ConstituentSemantics::RefTransparent;
54925627
bool inSlice = false;

0 commit comments

Comments
 (0)