Skip to content

Commit 126aa69

Browse files
committed
Fix evaluation semantics of FORALL constructs per 10.2.4.2.4.
When lowering of FORALL was refactored to split the work between the bridge and array expression lowering, the individual statements in a forall construct body were no longer being lowered one at a time. These changes add back the proper lowering. Changes include: - Factoring the rhs array base analysis to be by statement. - Fix bug with finalization of context stack. - Fix bug with rerunning the analysis. - Merge aspects of cleanup code gen. - Restructure loop nest lowering such that each assignment resides in its own copy of a loop nest. - Add a lazy shape buffer. Thread lazy mask buffers through the loop nest so that cached results are available. - Regenerate checks for tests. - Fixes for p9. Make forall-2.f90 source file name independent. Fix for k6 test, among others. Using incorrect statement context. Fix bug exposed by m7. Fix m7 bug. review comments
1 parent fbb3d2e commit 126aa69

File tree

9 files changed

+944
-664
lines changed

9 files changed

+944
-664
lines changed

flang/include/flang/Lower/ConvertExpr.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,8 +176,8 @@ createSomeArrayTempValue(AbstractConverter &converter,
176176
fir::ExtendedValue
177177
createLazyArrayTempValue(AbstractConverter &converter,
178178
const evaluate::Expr<evaluate::SomeType> &expr,
179-
mlir::Value var, SymMap &symMap,
180-
StatementContext &stmtCtx);
179+
mlir::Value var, mlir::Value shapeBuffer,
180+
SymMap &symMap, StatementContext &stmtCtx);
181181

182182
/// Lower an array expression to a value of type box. The expression must be a
183183
/// variable.

flang/include/flang/Optimizer/Builder/FIRBuilder.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -248,10 +248,10 @@ class FirOpBuilder : public mlir::OpBuilder {
248248
}
249249

250250
/// Construct one of the two forms of shape op from an array box.
251-
mlir::Value consShape(mlir::Location loc, const fir::AbstractArrayBox &arr);
252-
mlir::Value consShape(mlir::Location loc, llvm::ArrayRef<mlir::Value> shift,
251+
mlir::Value genShape(mlir::Location loc, const fir::AbstractArrayBox &arr);
252+
mlir::Value genShape(mlir::Location loc, llvm::ArrayRef<mlir::Value> shift,
253253
llvm::ArrayRef<mlir::Value> exts);
254-
mlir::Value consShape(mlir::Location loc, llvm::ArrayRef<mlir::Value> exts);
254+
mlir::Value genShape(mlir::Location loc, llvm::ArrayRef<mlir::Value> exts);
255255

256256
/// Create one of the shape ops given an extended value. For a boxed value,
257257
/// this may create a `fir.shift` op.

flang/lib/Lower/Bridge.cpp

Lines changed: 115 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1211,80 +1211,123 @@ class FirConverter : public Fortran::lower::AbstractConverter {
12111211
/// Process a concurrent header for a FORALL. (Concurrent headers for DO
12121212
/// CONCURRENT loops are lowered elsewhere.)
12131213
void genFIR(const Fortran::parser::ConcurrentHeader &header) {
1214-
// Create our iteration space from the header spec.
1215-
localSymbols.pushScope();
1216-
auto idxTy = builder->getIndexType();
1217-
auto loc = toLocation();
1218-
llvm::SmallVector<fir::DoLoopOp> loops;
1219-
for (auto &ctrl :
1220-
std::get<std::list<Fortran::parser::ConcurrentControl>>(header.t)) {
1221-
const auto *ctrlVar = std::get<Fortran::parser::Name>(ctrl.t).symbol;
1222-
const auto *lo = Fortran::semantics::GetExpr(std::get<1>(ctrl.t));
1223-
const auto *hi = Fortran::semantics::GetExpr(std::get<2>(ctrl.t));
1224-
auto &optStep =
1225-
std::get<std::optional<Fortran::parser::ScalarIntExpr>>(ctrl.t);
1226-
auto lb = builder->createConvert(
1227-
loc, idxTy,
1228-
fir::getBase(genExprValue(*lo, explicitIterSpace.stmtContext())));
1229-
auto ub = builder->createConvert(
1230-
loc, idxTy,
1231-
fir::getBase(genExprValue(*hi, explicitIterSpace.stmtContext())));
1232-
auto by = optStep.has_value()
1233-
? builder->createConvert(
1234-
loc, idxTy,
1235-
fir::getBase(genExprValue(
1236-
*Fortran::semantics::GetExpr(*optStep),
1237-
explicitIterSpace.stmtContext())))
1238-
: builder->createIntegerConstant(loc, idxTy, 1);
1239-
auto lp = builder->create<fir::DoLoopOp>(
1240-
loc, lb, ub, by, /*unordered=*/true,
1241-
/*finalCount=*/false, explicitIterSpace.getInnerArgs());
1242-
if (!loops.empty())
1243-
builder->create<fir::ResultOp>(loc, lp.getResults());
1244-
explicitIterSpace.setInnerArgs(lp.getRegionIterArgs());
1245-
builder->setInsertionPointToStart(lp.getBody());
1246-
forceControlVariableBinding(ctrlVar, lp.getInductionVar());
1247-
loops.push_back(lp);
1248-
}
1249-
explicitIterSpace.setOuterLoop(loops[0]);
1250-
if (const auto &mask =
1251-
std::get<std::optional<Fortran::parser::ScalarLogicalExpr>>(
1252-
header.t);
1253-
mask.has_value()) {
1254-
auto i1Ty = builder->getI1Type();
1255-
auto maskExv = genExprValue(*Fortran::semantics::GetExpr(mask.value()),
1256-
explicitIterSpace.stmtContext());
1257-
auto cond = builder->createConvert(loc, i1Ty, fir::getBase(maskExv));
1258-
auto ifOp = builder->create<fir::IfOp>(
1259-
loc, explicitIterSpace.innerArgTypes(), cond,
1260-
/*withElseRegion=*/true);
1261-
builder->create<fir::ResultOp>(loc, ifOp.getResults());
1262-
builder->setInsertionPointToStart(&ifOp.elseRegion().front());
1263-
builder->create<fir::ResultOp>(loc, explicitIterSpace.getInnerArgs());
1264-
builder->setInsertionPointToStart(&ifOp.thenRegion().front());
1214+
llvm::SmallVector<mlir::Value> lows;
1215+
llvm::SmallVector<mlir::Value> highs;
1216+
llvm::SmallVector<mlir::Value> steps;
1217+
if (explicitIterSpace.isOutermostForall()) {
1218+
// For the outermost forall, we evaluate the bounds expressions once.
1219+
// Contrastingly, if this forall is nested, the bounds expressions are
1220+
// assumed to be pure, possibly dependent on outer concurrent control
1221+
// variables, possibly variant with respect to arguments, and will be
1222+
// re-evaluated.
1223+
auto loc = toLocation();
1224+
auto idxTy = builder->getIndexType();
1225+
auto &stmtCtx = explicitIterSpace.stmtContext();
1226+
auto lowerExpr = [&](auto &e) {
1227+
return fir::getBase(genExprValue(e, stmtCtx));
1228+
};
1229+
for (auto &ctrl :
1230+
std::get<std::list<Fortran::parser::ConcurrentControl>>(header.t)) {
1231+
const auto *lo = Fortran::semantics::GetExpr(std::get<1>(ctrl.t));
1232+
const auto *hi = Fortran::semantics::GetExpr(std::get<2>(ctrl.t));
1233+
auto &optStep =
1234+
std::get<std::optional<Fortran::parser::ScalarIntExpr>>(ctrl.t);
1235+
lows.push_back(builder->createConvert(loc, idxTy, lowerExpr(*lo)));
1236+
highs.push_back(builder->createConvert(loc, idxTy, lowerExpr(*hi)));
1237+
steps.push_back(
1238+
optStep.has_value()
1239+
? builder->createConvert(
1240+
loc, idxTy,
1241+
lowerExpr(*Fortran::semantics::GetExpr(*optStep)))
1242+
: builder->createIntegerConstant(loc, idxTy, 1));
1243+
}
12651244
}
1245+
auto lambda = [&, lows, highs, steps]() {
1246+
// Create our iteration space from the header spec.
1247+
auto loc = toLocation();
1248+
auto idxTy = builder->getIndexType();
1249+
llvm::SmallVector<fir::DoLoopOp> loops;
1250+
auto &stmtCtx = explicitIterSpace.stmtContext();
1251+
auto lowerExpr = [&](auto &e) {
1252+
return fir::getBase(genExprValue(e, stmtCtx));
1253+
};
1254+
const auto outermost = !lows.empty();
1255+
std::size_t headerIndex = 0;
1256+
for (auto &ctrl :
1257+
std::get<std::list<Fortran::parser::ConcurrentControl>>(header.t)) {
1258+
const auto *ctrlVar = std::get<Fortran::parser::Name>(ctrl.t).symbol;
1259+
mlir::Value lb;
1260+
mlir::Value ub;
1261+
mlir::Value by;
1262+
if (outermost) {
1263+
assert(headerIndex < lows.size());
1264+
lb = lows[headerIndex];
1265+
ub = highs[headerIndex];
1266+
by = steps[headerIndex++];
1267+
} else {
1268+
const auto *lo = Fortran::semantics::GetExpr(std::get<1>(ctrl.t));
1269+
const auto *hi = Fortran::semantics::GetExpr(std::get<2>(ctrl.t));
1270+
auto &optStep =
1271+
std::get<std::optional<Fortran::parser::ScalarIntExpr>>(ctrl.t);
1272+
lb = builder->createConvert(loc, idxTy, lowerExpr(*lo));
1273+
ub = builder->createConvert(loc, idxTy, lowerExpr(*hi));
1274+
by = optStep.has_value()
1275+
? builder->createConvert(
1276+
loc, idxTy,
1277+
lowerExpr(*Fortran::semantics::GetExpr(*optStep)))
1278+
: builder->createIntegerConstant(loc, idxTy, 1);
1279+
}
1280+
auto lp = builder->create<fir::DoLoopOp>(
1281+
loc, lb, ub, by, /*unordered=*/true,
1282+
/*finalCount=*/false, explicitIterSpace.getInnerArgs());
1283+
if (!loops.empty() || !outermost)
1284+
builder->create<fir::ResultOp>(loc, lp.getResults());
1285+
explicitIterSpace.setInnerArgs(lp.getRegionIterArgs());
1286+
builder->setInsertionPointToStart(lp.getBody());
1287+
forceControlVariableBinding(ctrlVar, lp.getInductionVar());
1288+
loops.push_back(lp);
1289+
}
1290+
explicitIterSpace.setOuterLoop(loops[0]);
1291+
if (const auto &mask =
1292+
std::get<std::optional<Fortran::parser::ScalarLogicalExpr>>(
1293+
header.t);
1294+
mask.has_value()) {
1295+
auto i1Ty = builder->getI1Type();
1296+
auto maskExv =
1297+
genExprValue(*Fortran::semantics::GetExpr(mask.value()), stmtCtx);
1298+
auto cond = builder->createConvert(loc, i1Ty, fir::getBase(maskExv));
1299+
auto ifOp = builder->create<fir::IfOp>(
1300+
loc, explicitIterSpace.innerArgTypes(), cond,
1301+
/*withElseRegion=*/true);
1302+
builder->create<fir::ResultOp>(loc, ifOp.getResults());
1303+
builder->setInsertionPointToStart(&ifOp.elseRegion().front());
1304+
builder->create<fir::ResultOp>(loc, explicitIterSpace.getInnerArgs());
1305+
builder->setInsertionPointToStart(&ifOp.thenRegion().front());
1306+
}
1307+
};
1308+
// Push the lambda to gen the loop nest context.
1309+
explicitIterSpace.pushLoopNest(lambda);
12661310
}
12671311

12681312
void genFIR(const Fortran::parser::ForallAssignmentStmt &stmt) {
12691313
std::visit([&](const auto &x) { genFIR(x); }, stmt.u);
12701314
}
12711315

12721316
void genFIR(const Fortran::parser::EndForallStmt &) {
1273-
explicitIterSpace.finalize();
12741317
cleanupExplicitSpace();
12751318
}
12761319

12771320
template <typename A>
12781321
void prepareExplicitSpace(const A &forall) {
1279-
analyzeExplicitSpace(forall);
1322+
if (!explicitIterSpace.isActive())
1323+
analyzeExplicitSpace(forall);
1324+
localSymbols.pushScope();
12801325
explicitIterSpace.enter();
1281-
Fortran::lower::createArrayLoads(*this, explicitIterSpace, localSymbols);
12821326
}
12831327

12841328
/// Cleanup all the FORALL context information when we exit.
12851329
void cleanupExplicitSpace() {
1286-
Fortran::lower::createArrayMergeStores(*this, explicitIterSpace);
1287-
explicitIterSpace.conditionalCleanup();
1330+
explicitIterSpace.leave();
12881331
localSymbols.popScope();
12891332
}
12901333

@@ -1824,6 +1867,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {
18241867
void genAssignment(const Fortran::evaluate::Assignment &assign) {
18251868
Fortran::lower::StatementContext stmtCtx;
18261869
auto loc = toLocation();
1870+
if (explicitIterationSpace()) {
1871+
Fortran::lower::createArrayLoads(*this, explicitIterSpace, localSymbols);
1872+
explicitIterSpace.genLoopNest();
1873+
}
18271874
std::visit(
18281875
Fortran::common::visitors{
18291876
// [1] Plain old assignment.
@@ -1920,7 +1967,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {
19201967
if (implicitIterationSpace())
19211968
TODO(loc, "user defined assignment within WHERE");
19221969
Fortran::semantics::SomeExpr expr{procRef};
1923-
createFIRExpr(toLocation(), &expr, stmtCtx);
1970+
createFIRExpr(toLocation(), &expr,
1971+
explicitIterationSpace()
1972+
? explicitIterSpace.stmtContext()
1973+
: stmtCtx);
19241974
},
19251975

19261976
// [3] Pointer assignment with possibly empty bounds-spec. R1035: a
@@ -1981,6 +2031,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
19812031
},
19822032
},
19832033
assign.u);
2034+
if (explicitIterationSpace())
2035+
Fortran::lower::createArrayMergeStores(*this, explicitIterSpace);
19842036
}
19852037

19862038
void genFIR(const Fortran::parser::WhereConstruct &c) {
@@ -2563,6 +2615,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
25632615
void analyzeExplicitSpace(const Fortran::evaluate::Assignment *assign) {
25642616
analyzeExplicitSpace</*LHS=*/true>(assign->lhs);
25652617
analyzeExplicitSpace(assign->rhs);
2618+
explicitIterSpace.endAssign();
25662619
}
25672620
void analyzeExplicitSpace(const Fortran::parser::ForallAssignmentStmt &stmt) {
25682621
std::visit([&](const auto &s) { analyzeExplicitSpace(s); }, stmt.u);
@@ -2693,7 +2746,11 @@ class FirConverter : public Fortran::lower::AbstractConverter {
26932746
auto var = builder->createTemporary(loc, ty);
26942747
auto nil = builder->createNullConstant(loc, ty);
26952748
builder->create<fir::StoreOp>(loc, nil, var);
2696-
implicitIterSpace.addMaskVariable(exp, var);
2749+
auto shTy = fir::HeapType::get(builder->getIndexType());
2750+
auto shape = builder->createTemporary(loc, shTy);
2751+
auto nilSh = builder->createNullConstant(loc, shTy);
2752+
builder->create<fir::StoreOp>(loc, nilSh, shape);
2753+
implicitIterSpace.addMaskVariable(exp, var, shape);
26972754
explicitIterSpace.outermostContext().attachCleanup([=]() {
26982755
auto load = builder->create<fir::LoadOp>(loc, var);
26992756
auto cmp = builder->genIsNotNull(loc, load);

0 commit comments

Comments
 (0)