@@ -136,6 +136,10 @@ enum class ConstituentSemantics {
136
136
// Similar to CopyInCopyOut but `a` may be a transient projection (rather than
137
137
// a whole array).
138
138
ProjectedCopyInCopyOut,
139
+ // Similar to ProjectedCopyInCopyOut, except the merge value is not assigned
140
+ // automatically by the framework. Instead, and address for `[xs]` is made
141
+ // accessible so that custom assignments to `[xs]` can be implemented.
142
+ CustomCopyInCopyOut,
139
143
// Referentially opaque. Refers to the address of `x_i`.
140
144
RefOpaque
141
145
};
@@ -219,18 +223,19 @@ static fir::ExtendedValue arrayLoadExtValue(fir::FirOpBuilder &builder,
219
223
// Recover the extended value from the load.
220
224
assert (!load.slice () && " slice is not allowed" );
221
225
auto arrTy = load.getType ();
222
- auto idxTy = builder.getIndexType ();
223
226
if (!path.empty ()) {
224
227
auto ty = fir::applyPathToType (arrTy, path);
225
228
if (!ty)
226
229
fir::emitFatalError (loc, " path does not apply to type" );
227
230
if (!ty.isa <fir::SequenceType>()) {
228
- if (auto charTy = ty.dyn_cast <fir::CharacterType>()) {
229
- // ???: Is this in CharacterExprHelper?
230
- auto len = charTy.hasConstantLen ()
231
- ? builder.createIntegerConstant (
232
- loc, idxTy, ty.cast <fir::CharacterType>().getLen ())
233
- : load.typeparams ()[0 ];
231
+ if (fir::isa_char (ty)) {
232
+ auto len = fir::factory::CharacterExprHelper{builder, loc}.getLength (
233
+ load.memref ());
234
+ if (!len) {
235
+ assert (load.typeparams ().size () == 1 &&
236
+ " length must be in array_load" );
237
+ len = load.typeparams ()[0 ];
238
+ }
234
239
return fir::CharBoxValue{newBase, len};
235
240
}
236
241
return newBase;
@@ -250,11 +255,13 @@ static fir::ExtendedValue arrayLoadExtValue(fir::FirOpBuilder &builder,
250
255
}
251
256
auto extents = fir::factory::getExtents (load.shape ());
252
257
auto lbounds = fir::factory::getOrigins (load.shape ());
253
- if (auto charTy = eleTy.dyn_cast <fir::CharacterType>()) {
254
- auto len = charTy.hasConstantLen ()
255
- ? builder.createIntegerConstant (
256
- loc, idxTy, eleTy.cast <fir::CharacterType>().getLen ())
257
- : load.typeparams ()[0 ];
258
+ if (fir::isa_char (eleTy)) {
259
+ auto len = fir::factory::CharacterExprHelper{builder, loc}.getLength (
260
+ load.memref ());
261
+ if (!len) {
262
+ assert (load.typeparams ().size () == 1 && " length must be in array_load" );
263
+ len = load.typeparams ()[0 ];
264
+ }
258
265
return fir::CharArrayBoxValue{newBase, len, extents, lbounds};
259
266
}
260
267
if (load.typeparams ().empty ()) {
@@ -264,6 +271,25 @@ static fir::ExtendedValue arrayLoadExtValue(fir::FirOpBuilder &builder,
264
271
" properties are explicit, assumed, deferred, or ?" );
265
272
}
266
273
274
+ // / Convert the result of a fir.array_modify to an ExtendedValue given the
275
+ // / related fir.array_load.
276
+ static fir::ExtendedValue arrayModifyToExv (fir::FirOpBuilder &builder,
277
+ mlir::Location loc,
278
+ fir::ArrayLoadOp load,
279
+ mlir::Value elementAddr) {
280
+ auto eleTy = fir::unwrapPassByRefType (elementAddr.getType ());
281
+ if (fir::isa_char (eleTy)) {
282
+ auto len = fir::factory::CharacterExprHelper{builder, loc}.getLength (
283
+ load.memref ());
284
+ if (!len) {
285
+ assert (load.typeparams ().size () == 1 && " length must be in array_load" );
286
+ len = load.typeparams ()[0 ];
287
+ }
288
+ return fir::CharBoxValue{elementAddr, len};
289
+ }
290
+ return elementAddr;
291
+ }
292
+
267
293
// / Is this a call to an elemental procedure with at least one array argument ?
268
294
static bool
269
295
isElementalProcWithArrayArgs (const Fortran::evaluate::ProcedureRef &procRef) {
@@ -1835,7 +1861,7 @@ class ScalarExprLowering {
1835
1861
});
1836
1862
// / Result lengths parameters should not be provided to box storage
1837
1863
// / allocation and save_results, but they are still useful information to
1838
- // / keep in the ExtentdedValue if non-deferred.
1864
+ // / keep in the ExtendedValue if non-deferred.
1839
1865
if (!type.isa <fir::BoxType>())
1840
1866
resultLengths = lengths;
1841
1867
auto temp =
@@ -2418,6 +2444,50 @@ static mlir::Value adjustedArrayElement(mlir::Location loc,
2418
2444
return builder.createConvert (loc, adjustedArrayElementType (eleTy), val);
2419
2445
}
2420
2446
2447
+ // / Helper to generate calls to scalar user defined assignment procedures.
2448
+ static void genScalarUserDefinedAssignmentCall (fir::FirOpBuilder &builder,
2449
+ mlir::Location loc,
2450
+ mlir::FuncOp func,
2451
+ const fir::ExtendedValue &lhs,
2452
+ const fir::ExtendedValue &rhs) {
2453
+ auto prepareUserDefinedArg =
2454
+ [](fir::FirOpBuilder &builder, mlir::Location loc,
2455
+ const fir::ExtendedValue &value, mlir::Type argType) -> mlir::Value {
2456
+ if (argType.isa <fir::BoxCharType>()) {
2457
+ const auto *charBox = value.getCharBox ();
2458
+ assert (charBox && " argument type mismatch in elemental user assignment" );
2459
+ return fir::factory::CharacterExprHelper{builder, loc}.createEmbox (
2460
+ *charBox);
2461
+ }
2462
+ if (argType.isa <fir::BoxType>()) {
2463
+ auto box = builder.createBox (loc, value);
2464
+ return builder.createConvert (loc, argType, box);
2465
+ }
2466
+ // Simple pass by address.
2467
+ auto argBaseType = fir::unwrapRefType (argType);
2468
+ assert (!fir::hasDynamicSize (argBaseType));
2469
+ auto from = fir::getBase (value);
2470
+ if (argBaseType != fir::unwrapRefType (from.getType ())) {
2471
+ // With logicals, it is possible that from is i1 here.
2472
+ if (fir::isa_ref_type (from.getType ()))
2473
+ from = builder.create <fir::LoadOp>(loc, from);
2474
+ from = builder.createConvert (loc, argBaseType, from);
2475
+ }
2476
+ if (!fir::isa_ref_type (from.getType ())) {
2477
+ auto temp = builder.createTemporary (loc, argBaseType);
2478
+ builder.create <fir::StoreOp>(loc, from, temp);
2479
+ from = temp;
2480
+ }
2481
+ return builder.createConvert (loc, argType, from);
2482
+ };
2483
+ assert (func.getNumArguments () == 2 );
2484
+ auto lhsType = func.getType ().getInput (0 );
2485
+ auto rhsType = func.getType ().getInput (1 );
2486
+ auto lhsArg = prepareUserDefinedArg (builder, loc, lhs, lhsType);
2487
+ auto rhsArg = prepareUserDefinedArg (builder, loc, rhs, rhsType);
2488
+ builder.create <fir::CallOp>(loc, func, mlir::ValueRange{lhsArg, rhsArg});
2489
+ }
2490
+
2421
2491
// ===----------------------------------------------------------------------===//
2422
2492
//
2423
2493
// Lowering of scalar expressions in an explicit iteration space context.
@@ -2477,6 +2547,12 @@ class ScalarArrayExprLowering {
2477
2547
auto offset = expSpace.argPosition (oldInnerArg);
2478
2548
expSpace.setInnerArg (offset, fir::getBase (lexv));
2479
2549
builder.create <fir::ResultOp>(getLoc (), fir::getBase (lexv));
2550
+ } else if (auto updateOp = mlir::dyn_cast<fir::ArrayModifyOp>(
2551
+ fir::getBase (lexv).getDefiningOp ())) {
2552
+ auto oldInnerArg = updateOp.sequence ();
2553
+ auto offset = expSpace.argPosition (oldInnerArg);
2554
+ expSpace.setInnerArg (offset, fir::getBase (lexv));
2555
+ builder.create <fir::ResultOp>(getLoc (), fir::getBase (lexv));
2480
2556
}
2481
2557
return lexv;
2482
2558
}
@@ -3619,6 +3695,8 @@ class ArrayExprLowering {
3619
3695
return abstractArrayExtValue (iterSpace.outerResult ());
3620
3696
}
3621
3697
3698
+ // / Lower an elemental subroutine call with at least one array argument.
3699
+ // / Not for user defined assignments.
3622
3700
static void lowerArrayElementalSubroutine (
3623
3701
Fortran::lower::AbstractConverter &converter,
3624
3702
Fortran::lower::SymMap &symMap, Fortran::lower::StatementContext &stmtCtx,
@@ -3628,7 +3706,6 @@ class ArrayExprLowering {
3628
3706
ael.lowerArrayElementalSubroutine (call);
3629
3707
}
3630
3708
3631
- // ! Not for user defined assignment elemental subroutine.
3632
3709
void lowerArrayElementalSubroutine (
3633
3710
const Fortran::evaluate::Expr<Fortran::evaluate::SomeType> &call) {
3634
3711
auto f = genarr (call);
@@ -3638,6 +3715,48 @@ class ArrayExprLowering {
3638
3715
builder.restoreInsertionPoint (insPt);
3639
3716
}
3640
3717
3718
+ static void lowerElementalUserAssignment (
3719
+ Fortran::lower::AbstractConverter &converter,
3720
+ Fortran::lower::SymMap &symMap, Fortran::lower::StatementContext &stmtCtx,
3721
+ const Fortran::evaluate::Expr<Fortran::evaluate::SomeType> &call) {
3722
+ ArrayExprLowering ael (converter, stmtCtx, symMap,
3723
+ ConstituentSemantics::CustomCopyInCopyOut);
3724
+ auto procRef = std::get<Fortran::evaluate::ProcedureRef>(call.u );
3725
+ assert (procRef.arguments ().size () == 2 );
3726
+ const auto *lhs = procRef.arguments ()[0 ].value ().UnwrapExpr ();
3727
+ const auto *rhs = procRef.arguments ()[1 ].value ().UnwrapExpr ();
3728
+ assert (lhs && rhs &&
3729
+ " user defined assignment arguments must be expressions" );
3730
+ auto func = Fortran::lower::CallerInterface (procRef, converter).getFuncOp ();
3731
+ ael.lowerElementalUserAssignment (func, *lhs, *rhs);
3732
+ }
3733
+
3734
+ void lowerElementalUserAssignment (
3735
+ mlir::FuncOp userAssignment,
3736
+ const Fortran::evaluate::Expr<Fortran::evaluate::SomeType> &lhs,
3737
+ const Fortran::evaluate::Expr<Fortran::evaluate::SomeType> &rhs) {
3738
+ auto loc = getLoc ();
3739
+ PushSemantics (ConstituentSemantics::CustomCopyInCopyOut);
3740
+ auto genArrayModify = genarr (lhs);
3741
+ ccStoreToDest = [=](IterSpace iters) -> ExtValue {
3742
+ auto modifiedArray = genArrayModify (iters);
3743
+ auto arrayModify = mlir::dyn_cast_or_null<fir::ArrayModifyOp>(
3744
+ fir::getBase (modifiedArray).getDefiningOp ());
3745
+ assert (arrayModify && " must be created by ArrayModifyOp" );
3746
+ auto lhs =
3747
+ arrayModifyToExv (builder, loc, destination, arrayModify.getResult (0 ));
3748
+ genScalarUserDefinedAssignmentCall (builder, loc, userAssignment, lhs,
3749
+ iters.elementExv ());
3750
+ return modifiedArray;
3751
+ };
3752
+ determineShapeOfDest (lhs);
3753
+ semant = ConstituentSemantics::RefTransparent;
3754
+ auto exv = lowerArrayExpression (rhs);
3755
+ builder.create <fir::ArrayMergeStoreOp>(
3756
+ loc, destination, fir::getBase (exv), destination.memref (),
3757
+ destination.slice (), destination.typeparams ());
3758
+ }
3759
+
3641
3760
// / Compute the shape of a slice.
3642
3761
llvm::SmallVector<mlir::Value> computeSliceShape (mlir::Value slice) {
3643
3762
llvm::SmallVector<mlir::Value> slicedShape;
@@ -4994,6 +5113,23 @@ class ArrayExprLowering {
4994
5113
return abstractArrayExtValue (arrUpdate);
4995
5114
};
4996
5115
}
5116
+ if (isCustomCopyInCopyOut ()) {
5117
+ // Create an array_modify to get the LHS element address and indicate
5118
+ // the assignment, the actual assignment must be implemented in
5119
+ // ccStoreToDest.
5120
+ destination = arrLoad;
5121
+ return [=](IterSpace iters) -> ExtValue {
5122
+ auto innerArg = iters.innerArgument ();
5123
+ auto resTy = innerArg.getType ();
5124
+ auto eleTy = fir::applyPathToType (resTy, iters.iterVec ());
5125
+ auto refEleTy =
5126
+ fir::isa_ref_type (eleTy) ? eleTy : builder.getRefType (eleTy);
5127
+ auto arrModify = builder.create <fir::ArrayModifyOp>(
5128
+ loc, mlir::TypeRange{refEleTy, resTy}, innerArg, iters.iterVec (),
5129
+ destination.typeparams ());
5130
+ return abstractArrayExtValue (arrModify.getResult (1 ));
5131
+ };
5132
+ }
4997
5133
if (isCopyInCopyOut ()) {
4998
5134
// Semantics are copy-in copy-out.
4999
5135
// The continuation simply forwards the result of the `array_load` Op,
@@ -5914,6 +6050,10 @@ class ArrayExprLowering {
5914
6050
return semant == ConstituentSemantics::ProjectedCopyInCopyOut;
5915
6051
}
5916
6052
6053
+ bool isCustomCopyInCopyOut () {
6054
+ return semant == ConstituentSemantics::CustomCopyInCopyOut;
6055
+ }
6056
+
5917
6057
// / Array appears in a context where it must be boxed.
5918
6058
bool isBoxValue () { return semant == ConstituentSemantics::BoxValue; }
5919
6059
@@ -6100,9 +6240,11 @@ mlir::Value Fortran::lower::createSubroutineCall(
6100
6240
// Elemental user defined assignment has special requirements to deal with
6101
6241
// LHS/RHS overlaps. See 10.2.1.5 p2.
6102
6242
if (isUserDefAssignment)
6103
- TODO (converter.getCurrentLocation (), " elemental user defined assignment" );
6104
- ArrayExprLowering::lowerArrayElementalSubroutine (converter, symMap, stmtCtx,
6105
- call);
6243
+ ArrayExprLowering::lowerElementalUserAssignment (converter, symMap,
6244
+ stmtCtx, call);
6245
+ else
6246
+ ArrayExprLowering::lowerArrayElementalSubroutine (converter, symMap,
6247
+ stmtCtx, call);
6106
6248
return mlir::Value{};
6107
6249
}
6108
6250
// FIXME: The non elemental user defined assignment case with array arguments
0 commit comments