Skip to content

Commit dd6dd79

Browse files
committed
[lowering] Use fir.array_modify to lower elemental defined assignments
- Add a CustomCopyInCopyOut ConstituentSemantics to create fir.array_modify in the expression lowering. - Add a helper function to generate user assignment function call (packaging the arguments) based on an LHS and RHS ExtendedValue for the elements. - Use the address returned by the fir.array_modify modify to build the LHS ExtendedValue and used the added helper function to generate the call. - Add some better support to get the length form a memref mlir::Value (do not only try to get a constant length, but also read it from the fir.box if the memref is a fir.box). Move this code to CharacterExprHelper. This is so far only enabled outside of Forall/Where contexts, another place creating fir.array_update in these contexts will need to be updated to handle the CustomCopyInCopyOut.
1 parent 4033624 commit dd6dd79

File tree

4 files changed

+375
-17
lines changed

4 files changed

+375
-17
lines changed

flang/include/flang/Optimizer/Builder/Character.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,12 @@ class CharacterExprHelper {
153153
/// Returns integer value held in a character singleton.
154154
mlir::Value extractCodeFromSingleton(mlir::Value singleton);
155155

156+
/// Create a value for the length of a character based on its memory reference
157+
/// that may be a boxchar, box or !fir.[ptr|ref|heap]<fir.char<kind, len>>. If
158+
/// the memref is a simple address and the length is not constant in type, the
159+
/// returned length will be empty.
160+
mlir::Value getLength(mlir::Value memref);
161+
156162
/// Compute length given a fir.box describing a character entity.
157163
/// It adjusts the length from the number of bytes per the descriptor
158164
/// to the number of characters per the Fortran KIND.

flang/lib/Lower/ConvertExpr.cpp

Lines changed: 159 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,10 @@ enum class ConstituentSemantics {
136136
// Similar to CopyInCopyOut but `a` may be a transient projection (rather than
137137
// a whole array).
138138
ProjectedCopyInCopyOut,
139+
// Similar to ProjectedCopyInCopyOut, except the merge value is not assigned
140+
// automatically by the framework. Instead, and address for `[xs]` is made
141+
// accessible so that custom assignments to `[xs]` can be implemented.
142+
CustomCopyInCopyOut,
139143
// Referentially opaque. Refers to the address of `x_i`.
140144
RefOpaque
141145
};
@@ -219,18 +223,19 @@ static fir::ExtendedValue arrayLoadExtValue(fir::FirOpBuilder &builder,
219223
// Recover the extended value from the load.
220224
assert(!load.slice() && "slice is not allowed");
221225
auto arrTy = load.getType();
222-
auto idxTy = builder.getIndexType();
223226
if (!path.empty()) {
224227
auto ty = fir::applyPathToType(arrTy, path);
225228
if (!ty)
226229
fir::emitFatalError(loc, "path does not apply to type");
227230
if (!ty.isa<fir::SequenceType>()) {
228-
if (auto charTy = ty.dyn_cast<fir::CharacterType>()) {
229-
// ???: Is this in CharacterExprHelper?
230-
auto len = charTy.hasConstantLen()
231-
? builder.createIntegerConstant(
232-
loc, idxTy, ty.cast<fir::CharacterType>().getLen())
233-
: load.typeparams()[0];
231+
if (fir::isa_char(ty)) {
232+
auto len = fir::factory::CharacterExprHelper{builder, loc}.getLength(
233+
load.memref());
234+
if (!len) {
235+
assert(load.typeparams().size() == 1 &&
236+
"length must be in array_load");
237+
len = load.typeparams()[0];
238+
}
234239
return fir::CharBoxValue{newBase, len};
235240
}
236241
return newBase;
@@ -250,11 +255,13 @@ static fir::ExtendedValue arrayLoadExtValue(fir::FirOpBuilder &builder,
250255
}
251256
auto extents = fir::factory::getExtents(load.shape());
252257
auto lbounds = fir::factory::getOrigins(load.shape());
253-
if (auto charTy = eleTy.dyn_cast<fir::CharacterType>()) {
254-
auto len = charTy.hasConstantLen()
255-
? builder.createIntegerConstant(
256-
loc, idxTy, eleTy.cast<fir::CharacterType>().getLen())
257-
: load.typeparams()[0];
258+
if (fir::isa_char(eleTy)) {
259+
auto len = fir::factory::CharacterExprHelper{builder, loc}.getLength(
260+
load.memref());
261+
if (!len) {
262+
assert(load.typeparams().size() == 1 && "length must be in array_load");
263+
len = load.typeparams()[0];
264+
}
258265
return fir::CharArrayBoxValue{newBase, len, extents, lbounds};
259266
}
260267
if (load.typeparams().empty()) {
@@ -264,6 +271,25 @@ static fir::ExtendedValue arrayLoadExtValue(fir::FirOpBuilder &builder,
264271
"properties are explicit, assumed, deferred, or ?");
265272
}
266273

274+
/// Convert the result of a fir.array_modify to an ExtendedValue given the
275+
/// related fir.array_load.
276+
static fir::ExtendedValue arrayModifyToExv(fir::FirOpBuilder &builder,
277+
mlir::Location loc,
278+
fir::ArrayLoadOp load,
279+
mlir::Value elementAddr) {
280+
auto eleTy = fir::unwrapPassByRefType(elementAddr.getType());
281+
if (fir::isa_char(eleTy)) {
282+
auto len = fir::factory::CharacterExprHelper{builder, loc}.getLength(
283+
load.memref());
284+
if (!len) {
285+
assert(load.typeparams().size() == 1 && "length must be in array_load");
286+
len = load.typeparams()[0];
287+
}
288+
return fir::CharBoxValue{elementAddr, len};
289+
}
290+
return elementAddr;
291+
}
292+
267293
/// Is this a call to an elemental procedure with at least one array argument ?
268294
static bool
269295
isElementalProcWithArrayArgs(const Fortran::evaluate::ProcedureRef &procRef) {
@@ -1835,7 +1861,7 @@ class ScalarExprLowering {
18351861
});
18361862
/// Result lengths parameters should not be provided to box storage
18371863
/// allocation and save_results, but they are still useful information to
1838-
/// keep in the ExtentdedValue if non-deferred.
1864+
/// keep in the ExtendedValue if non-deferred.
18391865
if (!type.isa<fir::BoxType>())
18401866
resultLengths = lengths;
18411867
auto temp =
@@ -2418,6 +2444,50 @@ static mlir::Value adjustedArrayElement(mlir::Location loc,
24182444
return builder.createConvert(loc, adjustedArrayElementType(eleTy), val);
24192445
}
24202446

2447+
/// Helper to generate calls to scalar user defined assignment procedures.
2448+
static void genScalarUserDefinedAssignmentCall(fir::FirOpBuilder &builder,
2449+
mlir::Location loc,
2450+
mlir::FuncOp func,
2451+
const fir::ExtendedValue &lhs,
2452+
const fir::ExtendedValue &rhs) {
2453+
auto prepareUserDefinedArg =
2454+
[](fir::FirOpBuilder &builder, mlir::Location loc,
2455+
const fir::ExtendedValue &value, mlir::Type argType) -> mlir::Value {
2456+
if (argType.isa<fir::BoxCharType>()) {
2457+
const auto *charBox = value.getCharBox();
2458+
assert(charBox && "argument type mismatch in elemental user assignment");
2459+
return fir::factory::CharacterExprHelper{builder, loc}.createEmbox(
2460+
*charBox);
2461+
}
2462+
if (argType.isa<fir::BoxType>()) {
2463+
auto box = builder.createBox(loc, value);
2464+
return builder.createConvert(loc, argType, box);
2465+
}
2466+
// Simple pass by address.
2467+
auto argBaseType = fir::unwrapRefType(argType);
2468+
assert(!fir::hasDynamicSize(argBaseType));
2469+
auto from = fir::getBase(value);
2470+
if (argBaseType != fir::unwrapRefType(from.getType())) {
2471+
// With logicals, it is possible that from is i1 here.
2472+
if (fir::isa_ref_type(from.getType()))
2473+
from = builder.create<fir::LoadOp>(loc, from);
2474+
from = builder.createConvert(loc, argBaseType, from);
2475+
}
2476+
if (!fir::isa_ref_type(from.getType())) {
2477+
auto temp = builder.createTemporary(loc, argBaseType);
2478+
builder.create<fir::StoreOp>(loc, from, temp);
2479+
from = temp;
2480+
}
2481+
return builder.createConvert(loc, argType, from);
2482+
};
2483+
assert(func.getNumArguments() == 2);
2484+
auto lhsType = func.getType().getInput(0);
2485+
auto rhsType = func.getType().getInput(1);
2486+
auto lhsArg = prepareUserDefinedArg(builder, loc, lhs, lhsType);
2487+
auto rhsArg = prepareUserDefinedArg(builder, loc, rhs, rhsType);
2488+
builder.create<fir::CallOp>(loc, func, mlir::ValueRange{lhsArg, rhsArg});
2489+
}
2490+
24212491
//===----------------------------------------------------------------------===//
24222492
//
24232493
// Lowering of scalar expressions in an explicit iteration space context.
@@ -2477,6 +2547,12 @@ class ScalarArrayExprLowering {
24772547
auto offset = expSpace.argPosition(oldInnerArg);
24782548
expSpace.setInnerArg(offset, fir::getBase(lexv));
24792549
builder.create<fir::ResultOp>(getLoc(), fir::getBase(lexv));
2550+
} else if (auto updateOp = mlir::dyn_cast<fir::ArrayModifyOp>(
2551+
fir::getBase(lexv).getDefiningOp())) {
2552+
auto oldInnerArg = updateOp.sequence();
2553+
auto offset = expSpace.argPosition(oldInnerArg);
2554+
expSpace.setInnerArg(offset, fir::getBase(lexv));
2555+
builder.create<fir::ResultOp>(getLoc(), fir::getBase(lexv));
24802556
}
24812557
return lexv;
24822558
}
@@ -3619,6 +3695,8 @@ class ArrayExprLowering {
36193695
return abstractArrayExtValue(iterSpace.outerResult());
36203696
}
36213697

3698+
/// Lower an elemental subroutine call with at least one array argument.
3699+
/// Not for user defined assignments.
36223700
static void lowerArrayElementalSubroutine(
36233701
Fortran::lower::AbstractConverter &converter,
36243702
Fortran::lower::SymMap &symMap, Fortran::lower::StatementContext &stmtCtx,
@@ -3628,7 +3706,6 @@ class ArrayExprLowering {
36283706
ael.lowerArrayElementalSubroutine(call);
36293707
}
36303708

3631-
// ! Not for user defined assignment elemental subroutine.
36323709
void lowerArrayElementalSubroutine(
36333710
const Fortran::evaluate::Expr<Fortran::evaluate::SomeType> &call) {
36343711
auto f = genarr(call);
@@ -3638,6 +3715,48 @@ class ArrayExprLowering {
36383715
builder.restoreInsertionPoint(insPt);
36393716
}
36403717

3718+
static void lowerElementalUserAssignment(
3719+
Fortran::lower::AbstractConverter &converter,
3720+
Fortran::lower::SymMap &symMap, Fortran::lower::StatementContext &stmtCtx,
3721+
const Fortran::evaluate::Expr<Fortran::evaluate::SomeType> &call) {
3722+
ArrayExprLowering ael(converter, stmtCtx, symMap,
3723+
ConstituentSemantics::CustomCopyInCopyOut);
3724+
auto procRef = std::get<Fortran::evaluate::ProcedureRef>(call.u);
3725+
assert(procRef.arguments().size() == 2);
3726+
const auto *lhs = procRef.arguments()[0].value().UnwrapExpr();
3727+
const auto *rhs = procRef.arguments()[1].value().UnwrapExpr();
3728+
assert(lhs && rhs &&
3729+
"user defined assignment arguments must be expressions");
3730+
auto func = Fortran::lower::CallerInterface(procRef, converter).getFuncOp();
3731+
ael.lowerElementalUserAssignment(func, *lhs, *rhs);
3732+
}
3733+
3734+
void lowerElementalUserAssignment(
3735+
mlir::FuncOp userAssignment,
3736+
const Fortran::evaluate::Expr<Fortran::evaluate::SomeType> &lhs,
3737+
const Fortran::evaluate::Expr<Fortran::evaluate::SomeType> &rhs) {
3738+
auto loc = getLoc();
3739+
PushSemantics(ConstituentSemantics::CustomCopyInCopyOut);
3740+
auto genArrayModify = genarr(lhs);
3741+
ccStoreToDest = [=](IterSpace iters) -> ExtValue {
3742+
auto modifiedArray = genArrayModify(iters);
3743+
auto arrayModify = mlir::dyn_cast_or_null<fir::ArrayModifyOp>(
3744+
fir::getBase(modifiedArray).getDefiningOp());
3745+
assert(arrayModify && "must be created by ArrayModifyOp");
3746+
auto lhs =
3747+
arrayModifyToExv(builder, loc, destination, arrayModify.getResult(0));
3748+
genScalarUserDefinedAssignmentCall(builder, loc, userAssignment, lhs,
3749+
iters.elementExv());
3750+
return modifiedArray;
3751+
};
3752+
determineShapeOfDest(lhs);
3753+
semant = ConstituentSemantics::RefTransparent;
3754+
auto exv = lowerArrayExpression(rhs);
3755+
builder.create<fir::ArrayMergeStoreOp>(
3756+
loc, destination, fir::getBase(exv), destination.memref(),
3757+
destination.slice(), destination.typeparams());
3758+
}
3759+
36413760
/// Compute the shape of a slice.
36423761
llvm::SmallVector<mlir::Value> computeSliceShape(mlir::Value slice) {
36433762
llvm::SmallVector<mlir::Value> slicedShape;
@@ -4994,6 +5113,23 @@ class ArrayExprLowering {
49945113
return abstractArrayExtValue(arrUpdate);
49955114
};
49965115
}
5116+
if (isCustomCopyInCopyOut()) {
5117+
// Create an array_modify to get the LHS element address and indicate
5118+
// the assignment, the actual assignment must be implemented in
5119+
// ccStoreToDest.
5120+
destination = arrLoad;
5121+
return [=](IterSpace iters) -> ExtValue {
5122+
auto innerArg = iters.innerArgument();
5123+
auto resTy = innerArg.getType();
5124+
auto eleTy = fir::applyPathToType(resTy, iters.iterVec());
5125+
auto refEleTy =
5126+
fir::isa_ref_type(eleTy) ? eleTy : builder.getRefType(eleTy);
5127+
auto arrModify = builder.create<fir::ArrayModifyOp>(
5128+
loc, mlir::TypeRange{refEleTy, resTy}, innerArg, iters.iterVec(),
5129+
destination.typeparams());
5130+
return abstractArrayExtValue(arrModify.getResult(1));
5131+
};
5132+
}
49975133
if (isCopyInCopyOut()) {
49985134
// Semantics are copy-in copy-out.
49995135
// The continuation simply forwards the result of the `array_load` Op,
@@ -5914,6 +6050,10 @@ class ArrayExprLowering {
59146050
return semant == ConstituentSemantics::ProjectedCopyInCopyOut;
59156051
}
59166052

6053+
bool isCustomCopyInCopyOut() {
6054+
return semant == ConstituentSemantics::CustomCopyInCopyOut;
6055+
}
6056+
59176057
/// Array appears in a context where it must be boxed.
59186058
bool isBoxValue() { return semant == ConstituentSemantics::BoxValue; }
59196059

@@ -6100,9 +6240,11 @@ mlir::Value Fortran::lower::createSubroutineCall(
61006240
// Elemental user defined assignment has special requirements to deal with
61016241
// LHS/RHS overlaps. See 10.2.1.5 p2.
61026242
if (isUserDefAssignment)
6103-
TODO(converter.getCurrentLocation(), "elemental user defined assignment");
6104-
ArrayExprLowering::lowerArrayElementalSubroutine(converter, symMap, stmtCtx,
6105-
call);
6243+
ArrayExprLowering::lowerElementalUserAssignment(converter, symMap,
6244+
stmtCtx, call);
6245+
else
6246+
ArrayExprLowering::lowerArrayElementalSubroutine(converter, symMap,
6247+
stmtCtx, call);
61066248
return mlir::Value{};
61076249
}
61086250
// FIXME: The non elemental user defined assignment case with array arguments

flang/lib/Optimizer/Builder/Character.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -705,3 +705,19 @@ fir::factory::CharacterExprHelper::readLengthFromBox(mlir::Value box) {
705705
}
706706
return size;
707707
}
708+
709+
mlir::Value fir::factory::CharacterExprHelper::getLength(mlir::Value memref) {
710+
auto memrefType = memref.getType();
711+
auto charType = recoverCharacterType(memrefType);
712+
assert(charType && "must be a character type");
713+
if (charType.hasConstantLen())
714+
return builder.createIntegerConstant(loc, builder.getCharacterLengthType(),
715+
charType.getLen());
716+
if (memrefType.isa<fir::BoxType>())
717+
return readLengthFromBox(memref);
718+
if (memrefType.isa<fir::BoxCharType>())
719+
return createUnboxChar(memref).second;
720+
721+
// Length cannot be deduced from memref.
722+
return {};
723+
}

0 commit comments

Comments
 (0)