Skip to content

Commit a812601

Browse files
committed
Fixes a bug with vector subscripts.
Was lowering a vector subscript to the wrong scalar code instead of using the contents of the vector.
1 parent d61310e commit a812601

File tree

2 files changed

+159
-162
lines changed

2 files changed

+159
-162
lines changed

flang/lib/Lower/ConvertExpr.cpp

Lines changed: 67 additions & 143 deletions
Original file line numberDiff line numberDiff line change
@@ -2407,7 +2407,7 @@ class ScalarExprLowering {
24072407
// Helper for changing the semantics in a given context. Preserves the current
24082408
// semantics which is resumed when the "push" goes out of scope.
24092409
#define PushSemantics(PushVal) \
2410-
[[maybe_unused]] auto pushSemanticsLocalVariable97201 = \
2410+
[[maybe_unused]] auto pushSemanticsLocalVariable##__LINE__ = \
24112411
Fortran::common::ScopedSet(semant, PushVal);
24122412

24132413
static bool isAdjustedArrayElementType(mlir::Type t) {
@@ -3231,65 +3231,6 @@ namespace {
32313231
class ArrayExprLowering {
32323232
using ExtValue = fir::ExtendedValue;
32333233

3234-
struct IterationSpace {
3235-
IterationSpace() = default;
3236-
3237-
template <typename A>
3238-
explicit IterationSpace(mlir::Value inArg, mlir::Value outRes,
3239-
llvm::iterator_range<A> range)
3240-
: inArg{inArg}, outRes{outRes}, indices{range.begin(), range.end()} {}
3241-
3242-
explicit IterationSpace(const IterationSpace &from,
3243-
llvm::ArrayRef<mlir::Value> idxs)
3244-
: inArg(from.inArg), outRes(from.outRes), element(from.element),
3245-
indices(idxs.begin(), idxs.end()) {}
3246-
3247-
bool empty() const { return indices.empty(); }
3248-
mlir::Value innerArgument() const { return inArg; }
3249-
mlir::Value outerResult() const { return outRes; }
3250-
llvm::SmallVector<mlir::Value> iterVec() const { return indices; }
3251-
mlir::Value iterValue(std::size_t i) const {
3252-
assert(i < indices.size());
3253-
return indices[i];
3254-
}
3255-
3256-
/// Set (rewrite) the Value at a given index.
3257-
void setIndexValue(std::size_t i, mlir::Value v) {
3258-
assert(i < indices.size());
3259-
indices[i] = v;
3260-
}
3261-
3262-
void setIndexValues(llvm::ArrayRef<mlir::Value> vals) {
3263-
indices.assign(vals.begin(), vals.end());
3264-
}
3265-
3266-
void insertIndexValue(std::size_t i, mlir::Value av) {
3267-
assert(i <= indices.size());
3268-
indices.insert(indices.begin() + i, av);
3269-
}
3270-
3271-
/// Set the `element` value. This is the SSA value that corresponds to an
3272-
/// element of the resultant array value.
3273-
void setElement(ExtValue &&ele) {
3274-
assert(!fir::getBase(element) && "result element already set");
3275-
element = ele;
3276-
}
3277-
3278-
/// Get the value that will be merged into the resultant array. This is the
3279-
/// computed value that will be stored to the lhs of the assignment.
3280-
mlir::Value getElement() const {
3281-
assert(fir::getBase(element) && "element must be set");
3282-
return fir::getBase(element);
3283-
}
3284-
ExtValue elementExv() const { return element; }
3285-
3286-
private:
3287-
mlir::Value inArg;
3288-
mlir::Value outRes;
3289-
ExtValue element;
3290-
llvm::SmallVector<mlir::Value> indices;
3291-
};
3292-
32933234
/// Structure to keep track of lowered array operands in the
32943235
/// array expression. Useful to later deduce the shape of the
32953236
/// array expression.
@@ -3311,10 +3252,13 @@ class ArrayExprLowering {
33113252
EndOfSubscripts, ImplicitSubscripts>;
33123253

33133254
/// Active iteration space.
3314-
using IterSpace = const IterationSpace &;
3255+
using IterationSpace = Fortran::lower::IterationSpace;
3256+
using IterSpace = const Fortran::lower::IterationSpace &;
3257+
33153258
/// Current continuation. Function that will generate IR for a single
33163259
/// iteration of the pending iterative loop structure.
3317-
using CC = std::function<ExtValue(IterSpace)>;
3260+
using CC = Fortran::lower::GenerateElementalArrayFunc;
3261+
33183262
/// Projection continuation. Function that will project one iteration space
33193263
/// into another.
33203264
using PC = std::function<IterationSpace(IterSpace)>;
@@ -3957,39 +3901,22 @@ class ArrayExprLowering {
39573901
return implicitSpace && !implicitSpace->empty();
39583902
}
39593903

3960-
void addMaskRebind(Fortran::lower::FrontEndExpr e, mlir::Value var,
3961-
mlir::Value shapeBuffer, ExtValue tmp) {
3962-
// After this statement is completed, rebind the mask expression to some
3963-
// code that loads the mask result from the variable that was initialized
3964-
// lazily.
3965-
explicitSpace->attachLoopCleanup([e, implicit = implicitSpace,
3966-
loc = getLoc(), shapeBuffer,
3967-
size = tmp.rank(),
3968-
var](fir::FirOpBuilder &builder) {
3969-
auto load = builder.create<fir::LoadOp>(loc, var);
3970-
auto eleTy = fir::unwrapSequenceType(fir::unwrapRefType(load.getType()));
3971-
auto seqTy = fir::SequenceType::get(eleTy, size);
3972-
auto toTy = fir::HeapType::get(seqTy);
3973-
auto base = builder.createConvert(loc, toTy, load);
3974-
llvm::SmallVector<mlir::Value> shapeVec;
3975-
auto idxTy = builder.getIndexType();
3976-
auto refIdxTy = builder.getRefType(idxTy);
3977-
auto shEleTy = fir::unwrapSequenceType(
3978-
fir::unwrapRefType(fir::unwrapRefType(shapeBuffer.getType())));
3979-
// Cast shape array to the correct 1-D array with constant extent.
3980-
fir::SequenceType::Shape dim = {
3981-
static_cast<fir::SequenceType::Extent>(size)};
3982-
auto buffTy = builder.getRefType(fir::SequenceType::get(dim, shEleTy));
3983-
auto buffer = builder.createConvert(loc, buffTy, shapeBuffer);
3984-
for (std::remove_const_t<decltype(size)> i = 0; i < size; ++i) {
3985-
auto offset = builder.createIntegerConstant(loc, idxTy, i);
3986-
auto ele =
3987-
builder.create<fir::CoordinateOp>(loc, refIdxTy, buffer, offset);
3988-
shapeVec.push_back(builder.create<fir::LoadOp>(loc, ele));
3989-
}
3990-
auto shape = builder.genShape(loc, shapeVec);
3991-
implicit->replaceBinding(e, base, shape);
3992-
});
3904+
CC genMaskAccess(mlir::Value tmp, mlir::Value shape) {
3905+
auto loc = getLoc();
3906+
return [=, builder = &converter.getFirOpBuilder()](IterSpace iters) {
3907+
auto arrTy = fir::dyn_cast_ptrOrBoxEleTy(tmp.getType());
3908+
auto eleTy = arrTy.cast<fir::SequenceType>().getEleTy();
3909+
auto eleRefTy = builder->getRefType(eleTy);
3910+
auto i1Ty = builder->getI1Type();
3911+
// Adjust indices for any shift of the origin of the array.
3912+
auto indices = fir::factory::originateIndices(
3913+
loc, *builder, tmp.getType(), shape, iters.iterVec());
3914+
auto addr = builder->create<fir::ArrayCoorOp>(
3915+
loc, eleRefTy, tmp, shape, /*slice=*/mlir::Value{}, indices,
3916+
/*typeParams=*/llvm::None);
3917+
auto load = builder->create<fir::LoadOp>(loc, addr);
3918+
return builder->createConvert(loc, i1Ty, load);
3919+
};
39933920
}
39943921

39953922
/// Construct the incremental instantiations of the ragged array structure.
@@ -4050,7 +3977,7 @@ class ArrayExprLowering {
40503977
auto hdrSh = builder.create<fir::CoordinateOp>(loc, coorTy2, header, two);
40513978
auto shapePtr = builder.create<fir::LoadOp>(loc, hdrSh);
40523979
// Replace the binding.
4053-
implicitSpace->replaceBinding(expr, inVar, shapePtr);
3980+
implicitSpace->rebind(expr, genMaskAccess(inVar, shapePtr));
40543981
if (i < depth - 1)
40553982
builder.restoreInsertionPoint(insPt);
40563983
}
@@ -4086,7 +4013,7 @@ class ArrayExprLowering {
40864013
auto tmp = Fortran::lower::createSomeArrayTempValue(converter, *e,
40874014
symMap, stmtCtx);
40884015
auto shape = builder.createShape(loc, tmp);
4089-
implicitSpace->bind(e, fir::getBase(tmp), shape);
4016+
implicitSpace->bind(e, genMaskAccess(fir::getBase(tmp), shape));
40904017
}
40914018

40924019
// Set buffer from the header.
@@ -4129,7 +4056,7 @@ class ArrayExprLowering {
41294056
auto shapeOp = builder.genShape(loc, extents);
41304057

41314058
// Replace binding with the local result.
4132-
implicitSpace->replaceBinding(e, buff, shapeOp);
4059+
implicitSpace->rebind(e, genMaskAccess(buff, shapeOp));
41334060
}
41344061
}
41354062
}
@@ -4227,22 +4154,8 @@ class ArrayExprLowering {
42274154
// the innermost loop.
42284155
if (implicitSpaceHasMasks()) {
42294156
// Recover the cached condition from the mask buffer.
4230-
auto genCond = [&](Fortran::lower::MaskAddrAndShape &&mask,
4231-
IterSpace iters) {
4232-
auto tmp = mask.first;
4233-
auto shape = mask.second;
4234-
auto arrTy = fir::dyn_cast_ptrOrBoxEleTy(tmp.getType());
4235-
auto eleTy = arrTy.cast<fir::SequenceType>().getEleTy();
4236-
auto eleRefTy = builder.getRefType(eleTy);
4237-
auto i1Ty = builder.getI1Type();
4238-
// Adjust indices for any shift of the origin of the array.
4239-
auto indices = fir::factory::originateIndices(
4240-
loc, builder, tmp.getType(), shape, iters.iterVec());
4241-
auto addr = builder.create<fir::ArrayCoorOp>(
4242-
loc, eleRefTy, tmp, shape, /*slice=*/mlir::Value{}, indices,
4243-
/*typeParams=*/llvm::None);
4244-
auto load = builder.create<fir::LoadOp>(loc, addr);
4245-
return builder.createConvert(loc, i1Ty, load);
4157+
auto genCond = [&](Fortran::lower::FrontEndExpr e, IterSpace iters) {
4158+
return implicitSpace->getBoundClosure(e)(iters);
42464159
};
42474160

42484161
// Handle the negated conditions in topological order of the WHERE
@@ -4269,13 +4182,11 @@ class ArrayExprLowering {
42694182
};
42704183
for (std::remove_const_t<decltype(size)> i = 0; i < size; ++i)
42714184
if (const auto *e = maskExprs[i])
4272-
genFalseBlock(
4273-
e, genCond(implicitSpace->getBindingWithShape(e), iters));
4185+
genFalseBlock(e, genCond(e, iters));
42744186

42754187
// The last condition is either non-negated or unconditionally negated.
42764188
if (const auto *e = maskExprs[size])
4277-
genTrueBlock(e,
4278-
genCond(implicitSpace->getBindingWithShape(e), iters));
4189+
genTrueBlock(e, genCond(e, iters));
42794190
}
42804191
}
42814192

@@ -5035,6 +4946,24 @@ class ArrayExprLowering {
50354946
return helper(x);
50364947
}
50374948

4949+
template <typename A>
4950+
CC genVectorSubscriptArrayFetch(const A &expr) {
4951+
PushSemantics(ConstituentSemantics::RefTransparent);
4952+
auto saved = Fortran::common::ScopedSet(explicitSpace, nullptr);
4953+
return genarr(expr);
4954+
}
4955+
4956+
/// Generate an access by vector subscript using the index in the iteration
4957+
/// vector at `dim`.
4958+
mlir::Value genAccessByVector(mlir::Location loc, CC genArrFetch,
4959+
IterSpace iters, std::size_t dim) {
4960+
IterationSpace vecIters(iters,
4961+
llvm::ArrayRef<mlir::Value>{iters.iterValue(dim)});
4962+
auto fetch = genArrFetch(vecIters);
4963+
auto idxTy = builder.getIndexType();
4964+
return builder.createConvert(loc, idxTy, fir::getBase(fetch));
4965+
}
4966+
50384967
/// When we have an array reference, the expressions specified in each
50394968
/// dimension may be slice operations (e.g. `i:j:k`), vectors, or simple
50404969
/// (loop-invarianet) scalar expressions. This returns the base entity, the
@@ -5081,25 +5010,20 @@ class ArrayExprLowering {
50815010
auto base = x.base();
50825011
auto exv = genArrayBase(base);
50835012
auto arrExpr = ignoreEvConvert(e);
5084-
auto saveSemant = semant;
5085-
semant = ConstituentSemantics::RefTransparent;
5086-
auto genArrFetch = genarr(arrExpr);
5087-
semant = saveSemant;
5013+
auto genArrFetch = genVectorSubscriptArrayFetch(arrExpr);
50885014
auto currentPC = pc;
50895015
auto dim = sub.index();
50905016
auto lb =
50915017
fir::factory::readLowerBound(builder, loc, exv, dim, one);
50925018
pc = [=](IterSpace iters) {
50935019
IterationSpace newIters = currentPC(iters);
5094-
IterationSpace vecIters(
5095-
newIters,
5096-
llvm::ArrayRef<mlir::Value>{newIters.iterValue(dim)});
5097-
auto fetch = genArrFetch(vecIters);
5098-
auto cast =
5099-
builder.createConvert(loc, idxTy, fir::getBase(fetch));
5100-
auto val = builder.create<mlir::arith::SubIOp>(loc, idxTy,
5101-
cast, lb);
5102-
newIters.setIndexValue(dim, val);
5020+
auto val =
5021+
genAccessByVector(loc, genArrFetch, newIters, dim);
5022+
// Value read from vector subscript array and normalized
5023+
// using the base array's lower bound value.
5024+
auto origin = builder.create<mlir::arith::SubIOp>(
5025+
loc, idxTy, val, lb);
5026+
newIters.setIndexValue(dim, origin);
51035027
return newIters;
51045028
};
51055029
// Create a slice with the vector size so that the shape
@@ -6122,7 +6046,8 @@ class ArrayExprLowering {
61226046
std::tuple<llvm::SmallVector<mlir::Value>, mlir::Type,
61236047
llvm::SmallVector<mlir::Value>>
61246048
lowerPath(mlir::Location loc, llvm::ArrayRef<PathComponent> revPath,
6125-
mlir::Type ty, IterSpace iters) {
6049+
fir::ArrayLoadOp arrLd, IterSpace iters) {
6050+
mlir::Type ty = arrLd.getType();
61266051
auto fieldTy = fir::FieldType::get(builder.getContext());
61276052
auto idxTy = builder.getIndexType();
61286053
llvm::SmallVector<mlir::Value> result;
@@ -6139,14 +6064,15 @@ class ArrayExprLowering {
61396064
return memTy;
61406065
};
61416066
auto addSub = [&](const Fortran::evaluate::Subscript &sub) {
6142-
auto exv = std::visit(
6067+
auto indexValue = std::visit(
61436068
Fortran::common::visitors{
6144-
[&](const Fortran::evaluate::IndirectSubscriptIntegerExpr &e)
6145-
-> mlir::Value {
6069+
[&](const Fortran::evaluate::IndirectSubscriptIntegerExpr &e) {
61466070
if (e.value().Rank() == 0)
61476071
return fir::getBase(asScalarArray(e.value()));
61486072
dim++;
6149-
return fir::getBase(genarr(e.value())(iters));
6073+
auto arrExpr = ignoreEvConvert(e.value());
6074+
auto genArrFetch = genVectorSubscriptArrayFetch(arrExpr);
6075+
return genAccessByVector(loc, genArrFetch, iters, dim);
61506076
},
61516077
[&](const Fortran::evaluate::Triplet &t) -> mlir::Value {
61526078
auto impliedIter = iters.iterValue(dim++);
@@ -6161,12 +6087,10 @@ class ArrayExprLowering {
61616087
auto step = builder.createConvert(loc, idxTy, stride);
61626088
auto prod =
61636089
builder.create<mlir::arith::MulIOp>(loc, impliedIter, step);
6164-
auto trip =
6165-
builder.create<mlir::arith::AddIOp>(loc, initial, prod);
6166-
return trip;
6090+
return builder.create<mlir::arith::AddIOp>(loc, initial, prod);
61676091
}},
61686092
sub.u);
6169-
result.push_back(builder.createConvert(loc, idxTy, fir::getBase(exv)));
6093+
result.push_back(builder.createConvert(loc, idxTy, indexValue));
61706094
};
61716095
auto pushAllIters = [&]() {
61726096
// FIXME: Need to handle user-defined lower bound. Assume it is the
@@ -6217,7 +6141,7 @@ class ArrayExprLowering {
62176141
return [=, esp = this->explicitSpace](IterSpace iters) mutable {
62186142
auto innerArg = esp->findArgumentOfLoad(load);
62196143
auto [path, eleTy, substringBounds] =
6220-
lowerPath(loc, revPath, load.getType(), iters);
6144+
lowerPath(loc, revPath, load, iters);
62216145
if (isAdjustedArrayElementType(eleTy)) {
62226146
auto eleRefTy = builder.getRefType(eleTy);
62236147
auto arrayOp = builder.create<fir::ArrayAccessOp>(
@@ -6257,7 +6181,7 @@ class ArrayExprLowering {
62576181
destination = load;
62586182
auto innerArg = explicitSpace->findArgumentOfLoad(load);
62596183
return [=](IterSpace iters) mutable {
6260-
auto [path, eleTy, _] = lowerPath(loc, revPath, load.getType(), iters);
6184+
auto [path, eleTy, _] = lowerPath(loc, revPath, load, iters);
62616185
auto refEleTy =
62626186
fir::isa_ref_type(eleTy) ? eleTy : builder.getRefType(eleTy);
62636187
auto arrModify = builder.create<fir::ArrayModifyOp>(
@@ -6272,7 +6196,7 @@ class ArrayExprLowering {
62726196
}
62736197
return [=](IterSpace iters) mutable {
62746198
auto [path, eleTy, substringBounds] =
6275-
lowerPath(loc, revPath, load.getType(), iters);
6199+
lowerPath(loc, revPath, load, iters);
62766200
if (semant == ConstituentSemantics::RefOpaque ||
62776201
isAdjustedArrayElementType(eleTy)) {
62786202
auto resTy = builder.getRefType(eleTy);

0 commit comments

Comments
 (0)