Skip to content

Commit 124bba4

Browse files
committed
Implement the correct splitting of FORALL and array expression loops
such that the WHERE conditional value buffer is lowered lazily and threaded forward into the ensuing loops properly. This implements a structure that will evaluate the WHERE conditions within the FORALL context, save the results, and propagate those cached results to the subsequent loop nests which evaluate the array expression assignments under the WHERE guards. These changes do not properly account for dynamic reshaping of the WHERE mask buffer if it is parametric on the enclosing FORALL concurrent header variables.
1 parent 982e852 commit 124bba4

File tree

8 files changed

+808
-677
lines changed

8 files changed

+808
-677
lines changed

flang/include/flang/Lower/ConvertExpr.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
#ifndef FORTRAN_LOWER_CONVERTEXPR_H
1818
#define FORTRAN_LOWER_CONVERTEXPR_H
1919

20-
#include "flang/Evaluate/shape.h"
20+
#include "flang/Evaluate/expression.h"
2121
#include "flang/Optimizer/Builder/BoxValue.h"
2222
#include "flang/Optimizer/Builder/FIRBuilder.h"
2323

@@ -172,7 +172,9 @@ createSomeArrayTempValue(AbstractConverter &converter,
172172
/// Like createSomeArrayTempValue, but the temporary buffer is allocated lazily
173173
/// (inside the loops instead of before the loops). This can be useful if a
174174
/// loop's bounds are functions of other loop indices, for example.
175-
fir::ExtendedValue
175+
std::pair<fir::ExtendedValue,
176+
std::function<
177+
std::pair<fir::ExtendedValue, mlir::Value>(fir::FirOpBuilder &)>>
176178
createLazyArrayTempValue(AbstractConverter &converter,
177179
const evaluate::Expr<evaluate::SomeType> &expr,
178180
mlir::Value var, mlir::Value shapeBuffer,

flang/lib/Lower/Bridge.cpp

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1235,6 +1235,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
12351235
mlir::Value by;
12361236
if (outermost) {
12371237
assert(headerIndex < lows.size());
1238+
if (headerIndex == 0)
1239+
explicitIterSpace.resetInnerArgs();
12381240
lb = lows[headerIndex];
12391241
ub = highs[headerIndex];
12401242
by = steps[headerIndex++];
@@ -1261,7 +1263,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
12611263
forceControlVariableBinding(ctrlVar, lp.getInductionVar());
12621264
loops.push_back(lp);
12631265
}
1264-
explicitIterSpace.setOuterLoop(loops[0]);
1266+
if (outermost)
1267+
explicitIterSpace.setOuterLoop(loops[0]);
12651268
if (const auto &mask =
12661269
std::get<std::optional<Fortran::parser::ScalarLogicalExpr>>(
12671270
header.t);
@@ -2730,13 +2733,14 @@ class FirConverter : public Fortran::lower::AbstractConverter {
27302733
auto nilSh = builder->createNullConstant(loc, shTy);
27312734
builder->create<fir::StoreOp>(loc, nilSh, shape);
27322735
implicitIterSpace.addMaskVariable(exp, var, shape);
2733-
explicitIterSpace.outermostContext().attachCleanup([=]() {
2734-
auto load = builder->create<fir::LoadOp>(loc, var);
2735-
auto cmp = builder->genIsNotNull(loc, load);
2736-
builder->genIfThen(loc, cmp)
2737-
.genThen([&]() { builder->create<fir::FreeMemOp>(loc, load); })
2738-
.end();
2739-
});
2736+
explicitIterSpace.outermostContext().attachCleanup(
2737+
[builder = this->builder, loc, var]() {
2738+
auto load = builder->create<fir::LoadOp>(loc, var);
2739+
auto cmp = builder->genIsNotNull(loc, load);
2740+
builder->genIfThen(loc, cmp)
2741+
.genThen([&]() { builder->create<fir::FreeMemOp>(loc, load); })
2742+
.end();
2743+
});
27402744
}
27412745

27422746
//===--------------------------------------------------------------------===//

flang/lib/Lower/ConvertExpr.cpp

Lines changed: 125 additions & 51 deletions
Large diffs are not rendered by default.

flang/lib/Lower/IterationSpace.h

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -255,11 +255,13 @@ class ExplicitIterSpace {
255255
innerArgs.push_back(arg);
256256
}
257257

258-
void setOuterLoop(fir::DoLoopOp loop) {
259-
if (!outerLoop.hasValue())
260-
outerLoop = loop;
261-
}
258+
/// Reset the outermost `array_load` arguments to the loop nest.
259+
void resetInnerArgs() { innerArgs = initialArgs; }
260+
261+
/// Capture the current outermost loop.
262+
void setOuterLoop(fir::DoLoopOp loop) { outerLoop = loop; }
262263

264+
/// Sets the inner loop argument at position \p offset to \p val.
263265
void setInnerArg(size_t offset, mlir::Value val) {
264266
assert(offset < innerArgs.size());
265267
innerArgs[offset] = val;
@@ -385,6 +387,7 @@ class ExplicitIterSpace {
385387
// Assignment statement context (inside the loop nest).
386388
StatementContext stmtCtx;
387389
llvm::SmallVector<mlir::Value> innerArgs;
390+
llvm::SmallVector<mlir::Value> initialArgs;
388391
llvm::Optional<fir::DoLoopOp> outerLoop;
389392
llvm::Optional<std::function<void(fir::FirOpBuilder &)>> loopCleanup;
390393
std::size_t forallContextOpen = 0;

flang/lib/Lower/SymbolMap.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,8 @@ Fortran::lower::operator<<(llvm::raw_ostream &os,
7171
for (auto i : llvm::enumerate(symMap.symbolMapStack)) {
7272
os << " level " << i.index() << "<{\n";
7373
for (auto iter : i.value())
74-
os << " symbol [" << *iter.first << "] ->\n " << iter.second;
74+
os << " symbol @" << (void *)iter.first << " [" << *iter.first
75+
<< "] ->\n " << iter.second;
7576
os << " }>\n";
7677
}
7778
return os;

flang/test/Lower/forall-2.f90

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -80,39 +80,39 @@ subroutine slice_with_explicit_iters
8080
! CHECK: %[[VAL_8:.*]] = constant 1 : index
8181
! CHECK: %[[VAL_9:.*]] = fir.shape %[[VAL_1]], %[[VAL_2]] : (index, index) -> !fir.shape<2>
8282
! CHECK: %[[VAL_10:.*]] = fir.array_load %[[VAL_3]](%[[VAL_9]]) : (!fir.ref<!fir.array<10x10xi32>>, !fir.shape<2>) -> !fir.array<10x10xi32>
83-
! CHECK: %[[VAL_11:.*]] = fir.do_loop %[[VAL_12:.*]] = %[[VAL_5]] to %[[VAL_7]] step %[[VAL_8]] unordered iter_args(%[[VAL_13:.*]] = %[[VAL_10]]) -> (!fir.array<10x10xi32>) {
84-
! CHECK: %[[VAL_14:.*]] = fir.convert %[[VAL_12]] : (index) -> i32
85-
! CHECK: fir.store %[[VAL_14]] to %[[VAL_0]] : !fir.ref<i32>
86-
! CHECK: %[[VAL_15:.*]] = constant 1 : i64
87-
! CHECK: %[[VAL_16:.*]] = fir.load %[[VAL_0]] : !fir.ref<i32>
88-
! CHECK: %[[VAL_17:.*]] = fir.convert %[[VAL_16]] : (i32) -> i64
89-
! CHECK: %[[VAL_18:.*]] = constant 1 : i64
90-
! CHECK: %[[VAL_19:.*]] = fir.convert %[[VAL_18]] : (i64) -> index
91-
! CHECK: %[[VAL_20:.*]] = constant 0 : index
92-
! CHECK: %[[VAL_21:.*]] = fir.convert %[[VAL_15]] : (i64) -> index
93-
! CHECK: %[[VAL_22:.*]] = fir.convert %[[VAL_17]] : (i64) -> index
94-
! CHECK: %[[VAL_23:.*]] = subi %[[VAL_22]], %[[VAL_21]] : index
95-
! CHECK: %[[VAL_24:.*]] = addi %[[VAL_23]], %[[VAL_19]] : index
96-
! CHECK: %[[VAL_25:.*]] = divi_signed %[[VAL_24]], %[[VAL_19]] : index
97-
! CHECK: %[[VAL_26:.*]] = cmpi sgt, %[[VAL_25]], %[[VAL_20]] : index
98-
! CHECK: %[[VAL_27:.*]] = select %[[VAL_26]], %[[VAL_25]], %[[VAL_20]] : index
99-
! CHECK: %[[VAL_28:.*]] = fir.load %[[VAL_0]] : !fir.ref<i32>
100-
! CHECK: %[[VAL_29:.*]] = constant 1 : index
101-
! CHECK: %[[VAL_30:.*]] = constant 0 : index
102-
! CHECK: %[[VAL_31:.*]] = subi %[[VAL_27]], %[[VAL_29]] : index
103-
! CHECK: %[[VAL_32:.*]] = fir.do_loop %[[VAL_33:.*]] = %[[VAL_30]] to %[[VAL_31]] step %[[VAL_29]] unordered iter_args(%[[VAL_34:.*]] = %[[VAL_10]]) -> (!fir.array<10x10xi32>) {
104-
! CHECK: %[[VAL_35:.*]] = constant 0 : i32
105-
! CHECK: %[[VAL_36:.*]] = subi %[[VAL_35]], %[[VAL_28]] : i32
106-
! CHECK: %[[VAL_37:.*]] = constant 1 : i64
107-
! CHECK: %[[VAL_38:.*]] = fir.convert %[[VAL_37]] : (i64) -> index
108-
! CHECK: %[[VAL_39:.*]] = constant 1 : i64
109-
! CHECK: %[[VAL_40:.*]] = fir.convert %[[VAL_39]] : (i64) -> index
110-
! CHECK: %[[VAL_41:.*]] = muli %[[VAL_33]], %[[VAL_40]] : index
111-
! CHECK: %[[VAL_42:.*]] = addi %[[VAL_38]], %[[VAL_41]] : index
83+
! CHECK: %[[VAL_11:.*]] = constant 1 : i64
84+
! CHECK: %[[VAL_12:.*]] = constant 1 : i64
85+
! CHECK: %[[VAL_13:.*]] = fir.do_loop %[[VAL_14:.*]] = %[[VAL_5]] to %[[VAL_7]] step %[[VAL_8]] unordered iter_args(%[[VAL_15:.*]] = %[[VAL_10]]) -> (!fir.array<10x10xi32>) {
86+
! CHECK: %[[VAL_16:.*]] = fir.convert %[[VAL_14]] : (index) -> i32
87+
! CHECK: fir.store %[[VAL_16]] to %[[VAL_0]] : !fir.ref<i32>
88+
! CHECK: %[[VAL_17:.*]] = constant 1 : i64
89+
! CHECK: %[[VAL_18:.*]] = fir.load %[[VAL_0]] : !fir.ref<i32>
90+
! CHECK: %[[VAL_19:.*]] = fir.convert %[[VAL_18]] : (i32) -> i64
91+
! CHECK: %[[VAL_20:.*]] = constant 1 : i64
92+
! CHECK: %[[VAL_21:.*]] = fir.convert %[[VAL_20]] : (i64) -> index
93+
! CHECK: %[[VAL_22:.*]] = constant 0 : index
94+
! CHECK: %[[VAL_23:.*]] = fir.convert %[[VAL_17]] : (i64) -> index
95+
! CHECK: %[[VAL_24:.*]] = fir.convert %[[VAL_19]] : (i64) -> index
96+
! CHECK: %[[VAL_25:.*]] = subi %[[VAL_24]], %[[VAL_23]] : index
97+
! CHECK: %[[VAL_26:.*]] = addi %[[VAL_25]], %[[VAL_21]] : index
98+
! CHECK: %[[VAL_27:.*]] = divi_signed %[[VAL_26]], %[[VAL_21]] : index
99+
! CHECK: %[[VAL_28:.*]] = cmpi sgt, %[[VAL_27]], %[[VAL_22]] : index
100+
! CHECK: %[[VAL_29:.*]] = select %[[VAL_28]], %[[VAL_27]], %[[VAL_22]] : index
101+
! CHECK: %[[VAL_30:.*]] = constant 1 : index
102+
! CHECK: %[[VAL_31:.*]] = constant 0 : index
103+
! CHECK: %[[VAL_32:.*]] = subi %[[VAL_29]], %[[VAL_30]] : index
104+
! CHECK: %[[VAL_33:.*]] = fir.do_loop %[[VAL_34:.*]] = %[[VAL_31]] to %[[VAL_32]] step %[[VAL_30]] unordered iter_args(%[[VAL_35:.*]] = %[[VAL_15]]) -> (!fir.array<10x10xi32>) {
105+
! CHECK: %[[VAL_36:.*]] = fir.load %[[VAL_0]] : !fir.ref<i32>
106+
! CHECK: %[[VAL_37:.*]] = constant 0 : i32
107+
! CHECK: %[[VAL_38:.*]] = subi %[[VAL_37]], %[[VAL_36]] : i32
108+
! CHECK: %[[VAL_39:.*]] = fir.convert %[[VAL_11]] : (i64) -> index
109+
! CHECK: %[[VAL_40:.*]] = fir.convert %[[VAL_12]] : (i64) -> index
110+
! CHECK: %[[VAL_41:.*]] = muli %[[VAL_34]], %[[VAL_40]] : index
111+
! CHECK: %[[VAL_42:.*]] = addi %[[VAL_39]], %[[VAL_41]] : index
112112
! CHECK: %[[VAL_43:.*]] = fir.load %[[VAL_0]] : !fir.ref<i32>
113113
! CHECK: %[[VAL_44:.*]] = fir.convert %[[VAL_43]] : (i32) -> i64
114114
! CHECK: %[[VAL_45:.*]] = fir.convert %[[VAL_44]] : (i64) -> index
115-
! CHECK: %[[VAL_46:.*]] = fir.array_update %[[VAL_13]], %[[VAL_36]], %[[VAL_42]], %[[VAL_45]] {Fortran.offsets} : (!fir.array<10x10xi32>, i32, index, index) -> !fir.array<10x10xi32>
115+
! CHECK: %[[VAL_46:.*]] = fir.array_update %[[VAL_35]], %[[VAL_38]], %[[VAL_42]], %[[VAL_45]] {Fortran.offsets} : (!fir.array<10x10xi32>, i32, index, index) -> !fir.array<10x10xi32>
116116
! CHECK: fir.result %[[VAL_46]] : !fir.array<10x10xi32>
117117
! CHECK: }
118118
! CHECK: fir.result %[[VAL_47:.*]] : !fir.array<10x10xi32>

0 commit comments

Comments
 (0)