Skip to content

Commit 9b5ca90

Browse files
committed
Lower user defined assignments inside WHERE and FORALL
- Update the ScalarArrayExprLowering to add a userAssign entry point that uses the CustomCopyinCopyOut semantics. - Thread explicit and implicit context in createSubroutineCall. - When the explicit context is active (inside a FORALL), and when the assignment LHS is a scalar, use the ScalarArrayExprLowering newly added userAssign entry point. - Thread implicit and explicit in lowerElementalUserAssignment to cover the array case inside WHERE and FORALL. - Ensure analyzeExplicitSpace (FORALL analysis) sees and caches the same expression address pointers as the one used in lowering. This requires looking at the ProcRef operands of evaluate:Assign, and not its LHS/RHS. This also requires not wrapping the ProcRef in an Expr when threading it, because it would make a copy. - Apply CustomCopyinCopyOut in ArrayExprLowering::applyPathToArrayLoad to create fir.array_modify for array assignments inside FORALL. The case where a non elemental user defined array assignment is used inside FORALL is left TODO, because it is not clear to me how one can operate on sub-array parts inside the loops.
1 parent dd6dd79 commit 9b5ca90

File tree

4 files changed

+812
-61
lines changed

4 files changed

+812
-61
lines changed

flang/include/flang/Lower/ConvertExpr.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,9 @@ createSomeArrayBox(AbstractConverter &converter,
205205
/// returns, the returned value indicates which label the code should jump to.
206206
/// The returned value is null otherwise.
207207
mlir::Value createSubroutineCall(AbstractConverter &converter,
208-
const evaluate::Expr<evaluate::SomeType> &call,
208+
const evaluate::ProcedureRef &call,
209+
ExplicitIterSpace &explicitIterSpace,
210+
ImplicitIterSpace &implicitIterSpace,
209211
SymMap &symMap, StatementContext &stmtCtx,
210212
bool isUserDefAssignment);
211213

flang/lib/Lower/Bridge.cpp

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -661,9 +661,9 @@ class FirConverter : public Fortran::lower::AbstractConverter {
661661
setCurrentPosition(stmt.v.source);
662662
assert(stmt.typedCall && "Call was not analyzed");
663663
// Call statement lowering shares code with function call lowering.
664-
Fortran::semantics::SomeExpr expr{*stmt.typedCall};
665664
auto res = Fortran::lower::createSubroutineCall(
666-
*this, expr, localSymbols, stmtCtx, /*isUserDefAssignment=*/false);
665+
*this, *stmt.typedCall, explicitIterSpace, implicitIterSpace,
666+
localSymbols, stmtCtx, /*isUserDefAssignment=*/false);
667667
if (!res)
668668
return; // "Normal" subroutine call.
669669
// Call with alternate return specifiers.
@@ -1941,18 +1941,12 @@ class FirConverter : public Fortran::lower::AbstractConverter {
19411941
// [2] User defined assignment. If the context is a scalar
19421942
// expression then call the procedure.
19431943
[&](const Fortran::evaluate::ProcedureRef &procRef) {
1944-
if (implicitIterationSpace())
1945-
TODO(loc, "user defined assignment within WHERE");
1946-
1947-
Fortran::semantics::SomeExpr expr{procRef};
19481944
auto &ctx = explicitIterationSpace()
19491945
? explicitIterSpace.stmtContext()
19501946
: stmtCtx;
19511947
Fortran::lower::createSubroutineCall(
1952-
*this, expr, localSymbols, ctx, /*isUserDefAssignment=*/true);
1953-
if (explicitIterationSpace())
1954-
builder->create<fir::ResultOp>(
1955-
loc, explicitIterSpace.getInnerArgs());
1948+
*this, procRef, explicitIterSpace, implicitIterSpace,
1949+
localSymbols, ctx, /*isUserDefAssignment=*/true);
19561950
},
19571951

19581952
// [3] Pointer assignment with possibly empty bounds-spec. R1035: a
@@ -2595,8 +2589,24 @@ class FirConverter : public Fortran::lower::AbstractConverter {
25952589
explicitIterSpace.exprBase(&e, LHS);
25962590
}
25972591
void analyzeExplicitSpace(const Fortran::evaluate::Assignment *assign) {
2598-
analyzeExplicitSpace</*LHS=*/true>(assign->lhs);
2599-
analyzeExplicitSpace(assign->rhs);
2592+
auto analyzeAssign = [&](const Fortran::lower::SomeExpr &lhs,
2593+
const Fortran::lower::SomeExpr &rhs) {
2594+
analyzeExplicitSpace</*LHS=*/true>(lhs);
2595+
analyzeExplicitSpace(rhs);
2596+
};
2597+
std::visit(
2598+
Fortran::common::visitors{
2599+
[&](const Fortran::evaluate::ProcedureRef &procRef) {
2600+
// Ensure the procRef expressions are the one being visited.
2601+
assert(procRef.arguments().size() == 2);
2602+
const auto *lhs = procRef.arguments()[0].value().UnwrapExpr();
2603+
const auto *rhs = procRef.arguments()[1].value().UnwrapExpr();
2604+
assert(lhs && rhs &&
2605+
"user defined assignment arguments must be expressions");
2606+
analyzeAssign(*lhs, *rhs);
2607+
},
2608+
[&](const auto &) { analyzeAssign(assign->lhs, assign->rhs); }},
2609+
assign->u);
26002610
explicitIterSpace.endAssign();
26012611
}
26022612
void analyzeExplicitSpace(const Fortran::parser::ForallAssignmentStmt &stmt) {

flang/lib/Lower/ConvertExpr.cpp

Lines changed: 162 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -271,25 +271,6 @@ static fir::ExtendedValue arrayLoadExtValue(fir::FirOpBuilder &builder,
271271
"properties are explicit, assumed, deferred, or ?");
272272
}
273273

274-
/// Convert the result of a fir.array_modify to an ExtendedValue given the
275-
/// related fir.array_load.
276-
static fir::ExtendedValue arrayModifyToExv(fir::FirOpBuilder &builder,
277-
mlir::Location loc,
278-
fir::ArrayLoadOp load,
279-
mlir::Value elementAddr) {
280-
auto eleTy = fir::unwrapPassByRefType(elementAddr.getType());
281-
if (fir::isa_char(eleTy)) {
282-
auto len = fir::factory::CharacterExprHelper{builder, loc}.getLength(
283-
load.memref());
284-
if (!len) {
285-
assert(load.typeparams().size() == 1 && "length must be in array_load");
286-
len = load.typeparams()[0];
287-
}
288-
return fir::CharBoxValue{elementAddr, len};
289-
}
290-
return elementAddr;
291-
}
292-
293274
/// Is this a call to an elemental procedure with at least one array argument ?
294275
static bool
295276
isElementalProcWithArrayArgs(const Fortran::evaluate::ProcedureRef &procRef) {
@@ -2488,6 +2469,25 @@ static void genScalarUserDefinedAssignmentCall(fir::FirOpBuilder &builder,
24882469
builder.create<fir::CallOp>(loc, func, mlir::ValueRange{lhsArg, rhsArg});
24892470
}
24902471

2472+
/// Convert the result of a fir.array_modify to an ExtendedValue given the
2473+
/// related fir.array_load.
2474+
static fir::ExtendedValue arrayModifyToExv(fir::FirOpBuilder &builder,
2475+
mlir::Location loc,
2476+
fir::ArrayLoadOp load,
2477+
mlir::Value elementAddr) {
2478+
auto eleTy = fir::unwrapPassByRefType(elementAddr.getType());
2479+
if (fir::isa_char(eleTy)) {
2480+
auto len = fir::factory::CharacterExprHelper{builder, loc}.getLength(
2481+
load.memref());
2482+
if (!len) {
2483+
assert(load.typeparams().size() == 1 && "length must be in array_load");
2484+
len = load.typeparams()[0];
2485+
}
2486+
return fir::CharBoxValue{elementAddr, len};
2487+
}
2488+
return elementAddr;
2489+
}
2490+
24912491
//===----------------------------------------------------------------------===//
24922492
//
24932493
// Lowering of scalar expressions in an explicit iteration space context.
@@ -2540,7 +2540,8 @@ class ScalarArrayExprLowering {
25402540
// 3) Finalize the inner context.
25412541
expSpace.finalizeContext();
25422542
// 4) Thread the array value updated forward. Note: the lhs might be
2543-
// ill-formed, in which case there is no array to thread.
2543+
// ill-formed (performing scalar assignment in an array context),
2544+
// in which case there is no array to thread.
25442545
if (auto updateOp = mlir::dyn_cast<fir::ArrayUpdateOp>(
25452546
fir::getBase(lexv).getDefiningOp())) {
25462547
auto oldInnerArg = updateOp.sequence();
@@ -2557,6 +2558,45 @@ class ScalarArrayExprLowering {
25572558
return lexv;
25582559
}
25592560

2561+
ExtValue userAssign(mlir::FuncOp userAssignment,
2562+
const Fortran::lower::SomeExpr &lhs,
2563+
const Fortran::lower::SomeExpr &rhs) {
2564+
auto loc = getLoc();
2565+
semant = ConstituentSemantics::RefTransparent;
2566+
// 1) Lower the rhs expression with array_fetch op(s).
2567+
auto rexv = lower(rhs);
2568+
// 2) Lower the lhs expression to an array_modify.
2569+
semant = ConstituentSemantics::CustomCopyInCopyOut;
2570+
auto lexv = lower(lhs);
2571+
bool isIllFormedLHS = false;
2572+
// 3) Insert the call
2573+
if (auto modifyOp = mlir::dyn_cast<fir::ArrayModifyOp>(
2574+
fir::getBase(lexv).getDefiningOp())) {
2575+
auto oldInnerArg = modifyOp.sequence();
2576+
auto offset = expSpace.argPosition(oldInnerArg);
2577+
expSpace.setInnerArg(offset, fir::getBase(lexv));
2578+
auto exv =
2579+
arrayModifyToExv(builder, loc, expSpace.getLhsLoad(0).getValue(),
2580+
modifyOp.getResult(0));
2581+
genScalarUserDefinedAssignmentCall(builder, loc, userAssignment, exv,
2582+
rexv);
2583+
} else {
2584+
// LHS is ill formed, it is a scalar with no references to FORALL
2585+
// subscripts, so there is actually no array assignment here. The user
2586+
// code is probably bad, but still insert user assignment call since it
2587+
// was not rejected by semantics (a warning was emitted).
2588+
isIllFormedLHS = true;
2589+
genScalarUserDefinedAssignmentCall(builder, getLoc(), userAssignment,
2590+
lexv, rexv);
2591+
}
2592+
// 4) Finalize the inner context.
2593+
expSpace.finalizeContext();
2594+
// 5). Thread the array value updated forward.
2595+
if (!isIllFormedLHS)
2596+
builder.create<fir::ResultOp>(getLoc(), fir::getBase(lexv));
2597+
return lexv;
2598+
}
2599+
25602600
private:
25612601
bool pathIsEmpty() { return reversePath.empty(); }
25622602

@@ -2605,9 +2645,9 @@ class ScalarArrayExprLowering {
26052645
ExtValue applyPathToArrayLoad(fir::ArrayLoadOp load) {
26062646
auto loc = getLoc();
26072647
ExtValue result;
2648+
auto path = lowerPath(load.getType());
26082649
if (semant == ConstituentSemantics::ProjectedCopyInCopyOut) {
26092650
auto innerArg = expSpace.findArgumentOfLoad(load);
2610-
auto path = lowerPath(load.getType());
26112651
auto eleTy = fir::applyPathToType(innerArg.getType(), path);
26122652
auto toTy = adjustedArrayElementType(eleTy);
26132653
auto castedElement = builder.createConvert(loc, toTy, elementalValue);
@@ -2618,8 +2658,22 @@ class ScalarArrayExprLowering {
26182658
update->setAttr(fir::factory::attrFortranArrayOffsets(),
26192659
builder.getUnitAttr());
26202660
result = arrayLoadExtValue(builder, loc, load, {}, update);
2661+
} else if (semant == ConstituentSemantics::CustomCopyInCopyOut) {
2662+
// Create an array_modify to get the LHS element address and indicate
2663+
// the assignment, and create the call to the user defined assignment.
2664+
auto innerArg = expSpace.findArgumentOfLoad(load);
2665+
auto eleTy = fir::applyPathToType(innerArg.getType(), path);
2666+
auto refEleTy =
2667+
fir::isa_ref_type(eleTy) ? eleTy : builder.getRefType(eleTy);
2668+
auto arrModify = builder.create<fir::ArrayModifyOp>(
2669+
loc, mlir::TypeRange{refEleTy, innerArg.getType()}, innerArg, path,
2670+
load.typeparams());
2671+
// Flag the offsets as "Fortran" as they are not zero-origin.
2672+
arrModify->setAttr(fir::factory::attrFortranArrayOffsets(),
2673+
builder.getUnitAttr());
2674+
result =
2675+
arrayLoadExtValue(builder, loc, load, {}, arrModify.getResult(1));
26212676
} else {
2622-
auto path = lowerPath(load.getType());
26232677
auto eleTy = fir::applyPathToType(load.getType(), path);
26242678
assert(eleTy && "path did not apply to type");
26252679
auto resTy = adjustedArrayElementType(eleTy);
@@ -3715,13 +3769,16 @@ class ArrayExprLowering {
37153769
builder.restoreInsertionPoint(insPt);
37163770
}
37173771

3718-
static void lowerElementalUserAssignment(
3719-
Fortran::lower::AbstractConverter &converter,
3720-
Fortran::lower::SymMap &symMap, Fortran::lower::StatementContext &stmtCtx,
3721-
const Fortran::evaluate::Expr<Fortran::evaluate::SomeType> &call) {
3772+
static void
3773+
lowerElementalUserAssignment(Fortran::lower::AbstractConverter &converter,
3774+
Fortran::lower::SymMap &symMap,
3775+
Fortran::lower::StatementContext &stmtCtx,
3776+
Fortran::lower::ExplicitIterSpace &explicitSpace,
3777+
Fortran::lower::ImplicitIterSpace &implicitSpace,
3778+
const Fortran::evaluate::ProcedureRef &procRef) {
37223779
ArrayExprLowering ael(converter, stmtCtx, symMap,
3723-
ConstituentSemantics::CustomCopyInCopyOut);
3724-
auto procRef = std::get<Fortran::evaluate::ProcedureRef>(call.u);
3780+
ConstituentSemantics::CustomCopyInCopyOut,
3781+
&explicitSpace, &implicitSpace);
37253782
assert(procRef.arguments().size() == 2);
37263783
const auto *lhs = procRef.arguments()[0].value().UnwrapExpr();
37273784
const auto *rhs = procRef.arguments()[1].value().UnwrapExpr();
@@ -3752,9 +3809,12 @@ class ArrayExprLowering {
37523809
determineShapeOfDest(lhs);
37533810
semant = ConstituentSemantics::RefTransparent;
37543811
auto exv = lowerArrayExpression(rhs);
3755-
builder.create<fir::ArrayMergeStoreOp>(
3756-
loc, destination, fir::getBase(exv), destination.memref(),
3757-
destination.slice(), destination.typeparams());
3812+
if (explicitSpaceIsActive())
3813+
builder.create<fir::ResultOp>(loc, fir::getBase(exv));
3814+
else
3815+
builder.create<fir::ArrayMergeStoreOp>(
3816+
loc, destination, fir::getBase(exv), destination.memref(),
3817+
destination.slice(), destination.typeparams());
37583818
}
37593819

37603820
/// Compute the shape of a slice.
@@ -5914,6 +5974,25 @@ class ArrayExprLowering {
59145974
return arrayLoadExtValue(builder, loc, load, {}, update);
59155975
};
59165976
}
5977+
if (semant == ConstituentSemantics::CustomCopyInCopyOut) {
5978+
// Create an array_modify to get the LHS element address and indicate
5979+
// the assignment, and create the call to the user defined assignment.
5980+
destination = load;
5981+
auto innerArg = explicitSpace->findArgumentOfLoad(load);
5982+
return [=](IterSpace iters) mutable {
5983+
auto [path, eleTy] = lowerPath(loc, revPath, load.getType(), iters);
5984+
auto refEleTy =
5985+
fir::isa_ref_type(eleTy) ? eleTy : builder.getRefType(eleTy);
5986+
auto arrModify = builder.create<fir::ArrayModifyOp>(
5987+
loc, mlir::TypeRange{refEleTy, innerArg.getType()}, innerArg, path,
5988+
load.typeparams());
5989+
// Flag the offsets as "Fortran" as they are not zero-origin.
5990+
arrModify->setAttr(fir::factory::attrFortranArrayOffsets(),
5991+
builder.getUnitAttr());
5992+
return arrayLoadExtValue(builder, loc, load, {},
5993+
arrModify.getResult(1));
5994+
};
5995+
}
59175996
return [=](IterSpace iters) mutable {
59185997
auto [path, eleTy] = lowerPath(loc, revPath, load.getType(), iters);
59195998
auto resTy = adjustedArrayElementType(eleTy);
@@ -6232,26 +6311,61 @@ fir::MutableBoxValue Fortran::lower::createMutableBox(
62326311
}
62336312

62346313
mlir::Value Fortran::lower::createSubroutineCall(
6235-
AbstractConverter &converter,
6236-
const evaluate::Expr<evaluate::SomeType> &call, SymMap &symMap,
6237-
StatementContext &stmtCtx, bool isUserDefAssignment) {
6314+
AbstractConverter &converter, const evaluate::ProcedureRef &call,
6315+
ExplicitIterSpace &explicitIterSpace, ImplicitIterSpace &implicitIterSpace,
6316+
SymMap &symMap, StatementContext &stmtCtx, bool isUserDefAssignment) {
62386317
auto loc = converter.getCurrentLocation();
6318+
6319+
if (isUserDefAssignment) {
6320+
assert(call.arguments().size() == 2);
6321+
const auto *lhs = call.arguments()[0].value().UnwrapExpr();
6322+
const auto *rhs = call.arguments()[1].value().UnwrapExpr();
6323+
assert(lhs && rhs &&
6324+
"user defined assignment arguments must be expressions");
6325+
if (call.IsElemental() && lhs->Rank() > 0) {
6326+
// Elemental user defined assignment has special requirements to deal with
6327+
// LHS/RHS overlaps. See 10.2.1.5 p2.
6328+
ArrayExprLowering::lowerElementalUserAssignment(
6329+
converter, symMap, stmtCtx, explicitIterSpace, implicitIterSpace,
6330+
call);
6331+
} else if (explicitIterSpace.isActive() && lhs->Rank() == 0) {
6332+
// Scalar defined assignment (elemental or not) in a FORALL context.
6333+
auto func = Fortran::lower::CallerInterface(call, converter).getFuncOp();
6334+
ScalarArrayExprLowering sael(converter, symMap, explicitIterSpace,
6335+
stmtCtx);
6336+
sael.userAssign(func, *lhs, *rhs);
6337+
} else if (explicitIterSpace.isActive()) {
6338+
// TODO: need to array fetch/modify sub-arrays ?
6339+
TODO(loc, "non elemental user defined array assignment inside FORALL");
6340+
} else {
6341+
if (!implicitIterSpace.empty())
6342+
fir::emitFatalError(
6343+
loc,
6344+
"C1032: user defined assignment inside WHERE must be elemental");
6345+
// Non elemental user defined assignment outside of FORALL and WHERE.
6346+
// FIXME: The non elemental user defined assignment case with array
6347+
// arguments must be take into account potential overlap. So far the front
6348+
// end does not add parentheses around the RHS argument in the call as it
6349+
// should according to 15.4.3.4.3 p2.
6350+
Fortran::semantics::SomeExpr expr{call};
6351+
Fortran::lower::createSomeExtendedExpression(loc, converter, expr, symMap,
6352+
stmtCtx);
6353+
}
6354+
return {};
6355+
}
6356+
6357+
assert(implicitIterSpace.empty() && !explicitIterSpace.isActive() &&
6358+
"subroutine calls are not allowed inside WHERE and FORALL");
6359+
62396360
if (isElementalProcWithArrayArgs(call)) {
6240-
// Elemental user defined assignment has special requirements to deal with
6241-
// LHS/RHS overlaps. See 10.2.1.5 p2.
6242-
if (isUserDefAssignment)
6243-
ArrayExprLowering::lowerElementalUserAssignment(converter, symMap,
6244-
stmtCtx, call);
6245-
else
6246-
ArrayExprLowering::lowerArrayElementalSubroutine(converter, symMap,
6247-
stmtCtx, call);
6248-
return mlir::Value{};
6361+
Fortran::semantics::SomeExpr expr{call};
6362+
ArrayExprLowering::lowerArrayElementalSubroutine(converter, symMap, stmtCtx,
6363+
expr);
6364+
return {};
62496365
}
6250-
// FIXME: The non elemental user defined assignment case with array arguments
6251-
// must be take into account potential overlap. So far the front end does not
6252-
// add parentheses around the RHS argument in the call as it should according
6253-
// to 15.4.3.4.3 p2.
6254-
auto res = Fortran::lower::createSomeExtendedExpression(loc, converter, call,
6366+
// Simple subroutine call, with potential alternate return.
6367+
Fortran::semantics::SomeExpr expr{call};
6368+
auto res = Fortran::lower::createSomeExtendedExpression(loc, converter, expr,
62556369
symMap, stmtCtx);
62566370
return fir::getBase(res);
62576371
}

0 commit comments

Comments
 (0)