Skip to content

Commit ae9218c

Browse files
authored
Merge pull request #1098 from schweitzpgi/ch-where
Implement the correct splitting of FORALL and array expression loops
2 parents ca781ae + 153231e commit ae9218c

File tree

8 files changed

+584
-451
lines changed

8 files changed

+584
-451
lines changed

flang/include/flang/Lower/ConvertExpr.h

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
#ifndef FORTRAN_LOWER_CONVERTEXPR_H
1818
#define FORTRAN_LOWER_CONVERTEXPR_H
1919

20-
#include "flang/Evaluate/shape.h"
20+
#include "flang/Evaluate/expression.h"
2121
#include "flang/Optimizer/Builder/BoxValue.h"
2222
#include "flang/Optimizer/Builder/FIRBuilder.h"
2323

@@ -169,10 +169,24 @@ createSomeArrayTempValue(AbstractConverter &converter,
169169
const evaluate::Expr<evaluate::SomeType> &expr,
170170
SymMap &symMap, StatementContext &stmtCtx);
171171

172+
// Lambda to reload the dynamically allocated pointers to a lazy buffer and its
173+
// extents. This is used to introduce these SSA values in a place that will
174+
// dominate any/all subsequent uses after the loop that created the lazy buffer.
175+
using LoadLazyBufferLambda =
176+
std::function<std::pair<fir::ExtendedValue, mlir::Value>(
177+
fir::FirOpBuilder &)>;
178+
179+
// Creating a lazy array temporary returns a pair of values. The first is an
180+
// extended value which is a pointer to the buffer, of array type, with the
181+
// appropriate dynamic extents. The second element is a continuation to reload
182+
// the buffer at some future point in the code gen.
183+
using CreateLazyArrayResult =
184+
std::pair<fir::ExtendedValue, LoadLazyBufferLambda>;
185+
172186
/// Like createSomeArrayTempValue, but the temporary buffer is allocated lazily
173187
/// (inside the loops instead of before the loops). This can be useful if a
174188
/// loop's bounds are functions of other loop indices, for example.
175-
fir::ExtendedValue
189+
CreateLazyArrayResult
176190
createLazyArrayTempValue(AbstractConverter &converter,
177191
const evaluate::Expr<evaluate::SomeType> &expr,
178192
mlir::Value var, mlir::Value shapeBuffer,

flang/lib/Lower/Bridge.cpp

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1235,6 +1235,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
12351235
mlir::Value by;
12361236
if (outermost) {
12371237
assert(headerIndex < lows.size());
1238+
if (headerIndex == 0)
1239+
explicitIterSpace.resetInnerArgs();
12381240
lb = lows[headerIndex];
12391241
ub = highs[headerIndex];
12401242
by = steps[headerIndex++];
@@ -1261,7 +1263,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
12611263
forceControlVariableBinding(ctrlVar, lp.getInductionVar());
12621264
loops.push_back(lp);
12631265
}
1264-
explicitIterSpace.setOuterLoop(loops[0]);
1266+
if (outermost)
1267+
explicitIterSpace.setOuterLoop(loops[0]);
12651268
if (const auto &mask =
12661269
std::get<std::optional<Fortran::parser::ScalarLogicalExpr>>(
12671270
header.t);
@@ -2730,13 +2733,14 @@ class FirConverter : public Fortran::lower::AbstractConverter {
27302733
auto nilSh = builder->createNullConstant(loc, shTy);
27312734
builder->create<fir::StoreOp>(loc, nilSh, shape);
27322735
implicitIterSpace.addMaskVariable(exp, var, shape);
2733-
explicitIterSpace.outermostContext().attachCleanup([=]() {
2734-
auto load = builder->create<fir::LoadOp>(loc, var);
2735-
auto cmp = builder->genIsNotNull(loc, load);
2736-
builder->genIfThen(loc, cmp)
2737-
.genThen([&]() { builder->create<fir::FreeMemOp>(loc, load); })
2738-
.end();
2739-
});
2736+
explicitIterSpace.outermostContext().attachCleanup(
2737+
[builder = this->builder, loc, var]() {
2738+
auto load = builder->create<fir::LoadOp>(loc, var);
2739+
auto cmp = builder->genIsNotNull(loc, load);
2740+
builder->genIfThen(loc, cmp)
2741+
.genThen([&]() { builder->create<fir::FreeMemOp>(loc, load); })
2742+
.end();
2743+
});
27402744
}
27412745

27422746
//===--------------------------------------------------------------------===//

flang/lib/Lower/ConvertExpr.cpp

Lines changed: 113 additions & 49 deletions
Large diffs are not rendered by default.

flang/lib/Lower/IterationSpace.h

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -255,11 +255,13 @@ class ExplicitIterSpace {
255255
innerArgs.push_back(arg);
256256
}
257257

258-
void setOuterLoop(fir::DoLoopOp loop) {
259-
if (!outerLoop.hasValue())
260-
outerLoop = loop;
261-
}
258+
/// Reset the outermost `array_load` arguments to the loop nest.
259+
void resetInnerArgs() { innerArgs = initialArgs; }
260+
261+
/// Capture the current outermost loop.
262+
void setOuterLoop(fir::DoLoopOp loop) { outerLoop = loop; }
262263

264+
/// Sets the inner loop argument at position \p offset to \p val.
263265
void setInnerArg(size_t offset, mlir::Value val) {
264266
assert(offset < innerArgs.size());
265267
innerArgs[offset] = val;
@@ -385,6 +387,7 @@ class ExplicitIterSpace {
385387
// Assignment statement context (inside the loop nest).
386388
StatementContext stmtCtx;
387389
llvm::SmallVector<mlir::Value> innerArgs;
390+
llvm::SmallVector<mlir::Value> initialArgs;
388391
llvm::Optional<fir::DoLoopOp> outerLoop;
389392
llvm::Optional<std::function<void(fir::FirOpBuilder &)>> loopCleanup;
390393
std::size_t forallContextOpen = 0;

flang/lib/Lower/SymbolMap.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,8 @@ Fortran::lower::operator<<(llvm::raw_ostream &os,
7171
for (auto i : llvm::enumerate(symMap.symbolMapStack)) {
7272
os << " level " << i.index() << "<{\n";
7373
for (auto iter : i.value())
74-
os << " symbol [" << *iter.first << "] ->\n " << iter.second;
74+
os << " symbol @" << (void *)iter.first << " [" << *iter.first
75+
<< "] ->\n " << iter.second;
7576
os << " }>\n";
7677
}
7778
return os;

flang/test/Lower/forall-2.f90

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -96,23 +96,23 @@ subroutine slice_with_explicit_iters
9696
! CHECK: %[[VAL_25:.*]] = divi_signed %[[VAL_24]], %[[VAL_19]] : index
9797
! CHECK: %[[VAL_26:.*]] = cmpi sgt, %[[VAL_25]], %[[VAL_20]] : index
9898
! CHECK: %[[VAL_27:.*]] = select %[[VAL_26]], %[[VAL_25]], %[[VAL_20]] : index
99-
! CHECK: %[[VAL_28:.*]] = fir.load %[[VAL_0]] : !fir.ref<i32>
100-
! CHECK: %[[VAL_29:.*]] = constant 1 : index
101-
! CHECK: %[[VAL_30:.*]] = constant 0 : index
102-
! CHECK: %[[VAL_31:.*]] = subi %[[VAL_27]], %[[VAL_29]] : index
103-
! CHECK: %[[VAL_32:.*]] = fir.do_loop %[[VAL_33:.*]] = %[[VAL_30]] to %[[VAL_31]] step %[[VAL_29]] unordered iter_args(%[[VAL_34:.*]] = %[[VAL_10]]) -> (!fir.array<10x10xi32>) {
99+
! CHECK: %[[VAL_28:.*]] = constant 1 : index
100+
! CHECK: %[[VAL_29:.*]] = constant 0 : index
101+
! CHECK: %[[VAL_30:.*]] = subi %[[VAL_27]], %[[VAL_28]] : index
102+
! CHECK: %[[VAL_31:.*]] = fir.do_loop %[[VAL_32:.*]] = %[[VAL_29]] to %[[VAL_30]] step %[[VAL_28]] unordered iter_args(%[[VAL_33:.*]] = %[[VAL_13]]) -> (!fir.array<10x10xi32>) {
103+
! CHECK: %[[VAL_34:.*]] = fir.load %[[VAL_0]] : !fir.ref<i32>
104104
! CHECK: %[[VAL_35:.*]] = constant 0 : i32
105-
! CHECK: %[[VAL_36:.*]] = subi %[[VAL_35]], %[[VAL_28]] : i32
105+
! CHECK: %[[VAL_36:.*]] = subi %[[VAL_35]], %[[VAL_34]] : i32
106106
! CHECK: %[[VAL_37:.*]] = constant 1 : i64
107107
! CHECK: %[[VAL_38:.*]] = fir.convert %[[VAL_37]] : (i64) -> index
108108
! CHECK: %[[VAL_39:.*]] = constant 1 : i64
109109
! CHECK: %[[VAL_40:.*]] = fir.convert %[[VAL_39]] : (i64) -> index
110-
! CHECK: %[[VAL_41:.*]] = muli %[[VAL_33]], %[[VAL_40]] : index
110+
! CHECK: %[[VAL_41:.*]] = muli %[[VAL_32]], %[[VAL_40]] : index
111111
! CHECK: %[[VAL_42:.*]] = addi %[[VAL_38]], %[[VAL_41]] : index
112112
! CHECK: %[[VAL_43:.*]] = fir.load %[[VAL_0]] : !fir.ref<i32>
113113
! CHECK: %[[VAL_44:.*]] = fir.convert %[[VAL_43]] : (i32) -> i64
114114
! CHECK: %[[VAL_45:.*]] = fir.convert %[[VAL_44]] : (i64) -> index
115-
! CHECK: %[[VAL_46:.*]] = fir.array_update %[[VAL_13]], %[[VAL_36]], %[[VAL_42]], %[[VAL_45]] {Fortran.offsets} : (!fir.array<10x10xi32>, i32, index, index) -> !fir.array<10x10xi32>
115+
! CHECK: %[[VAL_46:.*]] = fir.array_update %[[VAL_33]], %[[VAL_36]], %[[VAL_42]], %[[VAL_45]] {Fortran.offsets} : (!fir.array<10x10xi32>, i32, index, index) -> !fir.array<10x10xi32>
116116
! CHECK: fir.result %[[VAL_46]] : !fir.array<10x10xi32>
117117
! CHECK: }
118118
! CHECK: fir.result %[[VAL_47:.*]] : !fir.array<10x10xi32>

0 commit comments

Comments
 (0)