@@ -2416,10 +2416,17 @@ class ArrayExprLowering {
2416
2416
};
2417
2417
2418
2418
using ExtValue = fir::ExtendedValue;
2419
- using IterSpace = const IterationSpace &; // active iteration space
2420
- using CC = std::function<ExtValue(IterSpace)>; // current continuation
2421
- using PC =
2422
- std::function<IterationSpace(IterSpace)>; // projection continuation
2419
+ /// Active iteration space.
2420
+ using IterSpace = const IterationSpace &;
2421
+ /// Current continuation. Function that will generate IR for a single iteration
2422
+ /// of the pending iterative loop structure.
2423
+ using CC = std::function<ExtValue(IterSpace)>;
2424
+ /// Projection continuation. Function that will project one iteration space
2425
+ /// into another.
2426
+ using PC = std::function<IterationSpace(IterSpace)>;
2427
+ /// Loop bounds continuation. Function that will generate IR to compute loop
2428
+ /// bounds in a future context.
2429
+ using LBC = std::function<llvm::SmallVector<mlir::Value>()>;
2423
2430
using ArrayBaseTy =
2424
2431
std::variant<std::monostate, const Fortran::evaluate::ArrayRef *,
2425
2432
const Fortran::evaluate::DataRef *>;
@@ -2690,13 +2697,15 @@ class ArrayExprLowering {
2690
2697
}
2691
2698
}
2692
2699
2700
+ // / Returns true iff the Ev::Shape is constant.
2693
2701
static bool evalShapeIsConstant (const Fortran::evaluate::Shape &shape) {
2694
2702
for (const auto &s : shape)
2695
2703
if (!s || !Fortran::evaluate::IsConstantExpr (*s))
2696
2704
return false ;
2697
2705
return true ;
2698
2706
}
2699
2707
2708
+ /// Convert an Ev::Shape to IR values.
2700
2709
void convertFEShape (const Fortran::evaluate::Shape &shape,
2701
2710
llvm::SmallVectorImpl<mlir::Value> &result) {
2702
2711
if (evalShapeIsConstant (shape)) {
@@ -2835,7 +2844,7 @@ class ArrayExprLowering {
2835
2844
// / this returns any implicit shape component, if it exists.
2836
2845
llvm::SmallVector<mlir::Value> genIterationShape () {
2837
2846
if (explicitSpace)
2838
- return explicitImpliedShape ;
2847
+ return {} ;
2839
2848
// Use the precomputed destination shape.
2840
2849
if (!destShape.empty ())
2841
2850
return destShape;
@@ -3203,6 +3212,51 @@ class ArrayExprLowering {
3203
3212
return {indices, loops[0 ]};
3204
3213
}
3205
3214
3215
+ llvm::SmallVector<mlir::Value> genImplicitLoopBounds (
3216
+ const Fortran::evaluate::Expr<Fortran::evaluate::SomeType> &e) {
3217
+ struct Filter : public Fortran ::evaluate::AnyTraverse<
3218
+ Filter, std::optional<llvm::SmallVector<mlir::Value>>> {
3219
+ using Base = Fortran::evaluate::AnyTraverse<
3220
+ Filter, std::optional<llvm::SmallVector<mlir::Value>>>;
3221
+ using Base::operator ();
3222
+
3223
+ Filter (const llvm::SmallVector<mlir::Value> init)
3224
+ : Base(*this ), bounds(init) {}
3225
+
3226
+ std::optional<llvm::SmallVector<mlir::Value>>
3227
+ operator ()(const Fortran::evaluate::ArrayRef &ref) {
3228
+ if ((ref.base ().IsSymbol () || ref.base ().Rank () == 0 ) &&
3229
+ ref.Rank () > 0 && !ref.subscript ().empty ()) {
3230
+ assert (ref.subscript ().size () == bounds.size ());
3231
+ llvm::SmallVector<mlir::Value> result;
3232
+ auto bdIter = bounds.begin ();
3233
+ for (auto ss : ref.subscript ()) {
3234
+ mlir::Value bound = *bdIter++;
3235
+ std::visit (Fortran::common::visitors{
3236
+ [&](const Fortran::evaluate::Triplet &triple) {
3237
+ result.push_back (bound);
3238
+ },
3239
+ [&](const auto &intExpr) {
3240
+ if (intExpr.value ().Rank () > 0 )
3241
+ result.push_back (bound);
3242
+ }},
3243
+ ss.u );
3244
+ }
3245
+ return {result};
3246
+ }
3247
+ return {};
3248
+ }
3249
+
3250
+ llvm::SmallVector<mlir::Value> bounds;
3251
+ };
3252
+
3253
+ auto originalShape = getShape (converter.genType (e));
3254
+ Filter filter (originalShape);
3255
+ if (auto res = filter (e))
3256
+ return *res;
3257
+ return originalShape;
3258
+ }
3259
+
3206
3260
void genMasks () {
3207
3261
auto loc = getLoc ();
3208
3262
// Lower explicit mask expressions, if any.
@@ -3219,8 +3273,8 @@ class ArrayExprLowering {
3219
3273
for (const auto *e : masks->getExprs ())
3220
3274
if (e && !masks->isLowered (e)) {
3221
3275
auto extents = genExplicitExtents ();
3222
- extents. append (explicitImpliedShape. rbegin (),
3223
- explicitImpliedShape .rend ());
3276
+ auto loopBounds = genImplicitLoopBounds (*e);
3277
+ extents. append (loopBounds. rbegin (), loopBounds .rend ());
3224
3278
// Allocate a temporary to cache the mask results.
3225
3279
auto tmpShape = builder.consShape (loc, extents);
3226
3280
auto tmp = createAndLoadSomeArrayTemp (
@@ -3234,7 +3288,8 @@ class ArrayExprLowering {
3234
3288
// Evaluate like any other nested array expression.
3235
3289
ArrayExprLowering ael{converter, masks->stmtContext (), symMap,
3236
3290
ConstituentSemantics::ProjectedCopyInCopyOut};
3237
- ael.lowerArrayAssignment (tmp, *e, indices, explicitImpliedShape);
3291
+ ael.lowerArrayAssignment (tmp, *e, indices,
3292
+ explicitImpliedLoopBounds.getValue ()());
3238
3293
masks->bind (e, tmp.memref (), tmpShape);
3239
3294
builder.setInsertionPointAfter (loop0);
3240
3295
builder.create <fir::ArrayMergeStoreOp>(loc, tmp, loop0.getResult (0 ),
@@ -3293,14 +3348,18 @@ class ArrayExprLowering {
3293
3348
llvm::SmallVector<fir::DoLoopOp> loops;
3294
3349
llvm::SmallVector<mlir::Value> explicitOffsets;
3295
3350
// FORALL loops are outermost.
3296
- if (explicitSpace)
3351
+ if (explicitSpace) {
3297
3352
genExplicitIterSpace (loops, explicitOffsets, innerArg);
3353
+ if (explicitImpliedLoopBounds.hasValue ())
3354
+ loopUppers = explicitImpliedLoopBounds.getValue ()();
3355
+ }
3298
3356
3299
3357
// Now handle the implicit loops.
3300
3358
const auto loopFirst = loops.size ();
3301
3359
const auto loopDepth = loopUppers.size ();
3302
3360
llvm::SmallVector<mlir::Value> ivars;
3303
3361
if (loopDepth > 0 ) {
3362
+ auto *startBlock = builder.getBlock ();
3304
3363
for (auto i : llvm::enumerate (llvm::reverse (loopUppers))) {
3305
3364
if (i.index () > 0 ) {
3306
3365
assert (!loops.empty ());
@@ -3315,8 +3374,11 @@ class ArrayExprLowering {
3315
3374
}
3316
3375
// Add the fir.result for all loops except the innermost one. We must also
3317
3376
// terminate the innermost explicit bounds loop here as well.
3318
- for (std::remove_const_t <decltype (loopFirst)> i =
3319
- loopFirst ? loopFirst - 1 : 0 ;
3377
+ if (loopFirst > 0 ) {
3378
+ builder.setInsertionPointToEnd (startBlock);
3379
+ builder.create <fir::ResultOp>(loc, loops[loopFirst].getResult (0 ));
3380
+ }
3381
+ for (std::remove_const_t <decltype (loopFirst)> i = loopFirst;
3320
3382
i + 1 < loopFirst + loopDepth; ++i) {
3321
3383
builder.setInsertionPointToEnd (loops[i].getBody ());
3322
3384
builder.create <fir::ResultOp>(loc, loops[i + 1 ].getResult (0 ));
@@ -3369,17 +3431,18 @@ class ArrayExprLowering {
3369
3431
// structure is produced.
3370
3432
auto maskExprs = masks->getExprs ();
3371
3433
const auto size = maskExprs.size () - 1 ;
3372
- for (std::remove_const_t <decltype (size)> i = 0 ; i < size; ++i) {
3373
- auto ifOp = builder.create <fir::IfOp>(
3374
- loc, mlir::TypeRange{innerArg.getType ()},
3375
- fir::getBase (
3376
- genCond (masks->getBindingWithShape (maskExprs[i]), iters)),
3377
- /* withElseRegion=*/ true );
3378
- builder.create <fir::ResultOp>(loc, ifOp.getResult (0 ));
3379
- builder.setInsertionPointToStart (&ifOp.thenRegion ().front ());
3380
- builder.create <fir::ResultOp>(loc, innerArg);
3381
- builder.setInsertionPointToStart (&ifOp.elseRegion ().front ());
3382
- }
3434
+ for (std::remove_const_t <decltype (size)> i = 0 ; i < size; ++i)
3435
+ if (maskExprs[i]) {
3436
+ auto ifOp = builder.create <fir::IfOp>(
3437
+ loc, mlir::TypeRange{innerArg.getType ()},
3438
+ fir::getBase (
3439
+ genCond (masks->getBindingWithShape (maskExprs[i]), iters)),
3440
+ /* withElseRegion=*/ true );
3441
+ builder.create <fir::ResultOp>(loc, ifOp.getResult (0 ));
3442
+ builder.setInsertionPointToStart (&ifOp.thenRegion ().front ());
3443
+ builder.create <fir::ResultOp>(loc, innerArg);
3444
+ builder.setInsertionPointToStart (&ifOp.elseRegion ().front ());
3445
+ }
3383
3446
3384
3447
// The last condition is either non-negated or unconditionally negated.
3385
3448
if (maskExprs[size]) {
@@ -4268,11 +4331,8 @@ class ArrayExprLowering {
4268
4331
template <typename A>
4269
4332
std::pair<CC, mlir::Type> raiseRankedBase (const A &x) {
4270
4333
auto result = raiseBase (x);
4271
- if (isProjectedCopyInCopyOut ()) {
4272
- auto optShape = Fortran::evaluate::GetShape (x);
4273
- assert (optShape.has_value ());
4274
- convertFEShape (*optShape, explicitImpliedShape);
4275
- }
4334
+ if (isProjectedCopyInCopyOut ())
4335
+ explicitImpliedLoopBounds = [=]() { return getShape (x); };
4276
4336
return result;
4277
4337
}
4278
4338
template <typename A>
@@ -4303,11 +4363,8 @@ class ArrayExprLowering {
4303
4363
std::pair<CC, mlir::Type> raiseRankedComponent (llvm::Optional<CC> cc,
4304
4364
const A &x, mlir::Type inTy) {
4305
4365
auto result = raiseComponent (cc, x, inTy, false );
4306
- if (isProjectedCopyInCopyOut ()) {
4307
- auto optShape = Fortran::evaluate::GetShape (x);
4308
- assert (optShape.has_value ());
4309
- convertFEShape (*optShape, explicitImpliedShape);
4310
- }
4366
+ if (isProjectedCopyInCopyOut ())
4367
+ explicitImpliedLoopBounds = [=]() { return getShape (x); };
4311
4368
return result;
4312
4369
}
4313
4370
@@ -4376,8 +4433,7 @@ class ArrayExprLowering {
4376
4433
auto &sym = base.GetFirstSymbol ();
4377
4434
if (x.Rank () > 0 || accessUsesControlVariable ()) {
4378
4435
auto [fopt2, ty2] = raiseBase (sym);
4379
- return RaiseRT{fopt2, fir::unwrapSequenceType (ty2), false ,
4380
- x.Rank () > 0 };
4436
+ return RaiseRT{fopt2, ty2, false , x.Rank () > 0 };
4381
4437
}
4382
4438
return RaiseRT{llvm::None, mlir::Type{}, false , false };
4383
4439
}
@@ -4394,11 +4450,54 @@ class ArrayExprLowering {
4394
4450
}(),
4395
4451
x);
4396
4452
}
4453
+ static mlir::Type unwrapBoxEleTy (mlir::Type ty) {
4454
+ if (auto boxTy = ty.dyn_cast <fir::BoxType>()) {
4455
+ ty = boxTy.getEleTy ();
4456
+ if (auto refTy = fir::dyn_cast_ptrEleTy (ty))
4457
+ ty = refTy;
4458
+ }
4459
+ return ty;
4460
+ }
4461
+ llvm::SmallVector<mlir::Value> getShape (mlir::Type ty) {
4462
+ llvm::SmallVector<mlir::Value> result;
4463
+ ty = unwrapBoxEleTy (ty);
4464
+ auto loc = getLoc ();
4465
+ auto idxTy = builder.getIndexType ();
4466
+ for (auto extent : ty.cast <fir::SequenceType>().getShape ()) {
4467
+ auto v = extent == fir::SequenceType::getUnknownExtent ()
4468
+ ? builder.create <fir::UndefOp>(loc, idxTy).getResult ()
4469
+ : builder.createIntegerConstant (loc, idxTy, extent);
4470
+ result.push_back (v);
4471
+ }
4472
+ return result;
4473
+ }
4474
+ llvm::SmallVector<mlir::Value>
4475
+ getShape (const Fortran::semantics::SymbolRef &x) {
4476
+ if (x.get ().Rank () == 0 )
4477
+ return {};
4478
+ return getFrontEndShape (x);
4479
+ }
4480
+ template <typename A>
4481
+ llvm::SmallVector<mlir::Value> getShape (const A &x) {
4482
+ if (x.Rank () == 0 )
4483
+ return {};
4484
+ return getFrontEndShape (x);
4485
+ }
4486
+ template <typename A>
4487
+ llvm::SmallVector<mlir::Value> getFrontEndShape (const A &x) {
4488
+ if (auto optShape = Fortran::evaluate::GetShape (x)) {
4489
+ llvm::SmallVector<mlir::Value> result;
4490
+ convertFEShape (*optShape, result);
4491
+ if (!result.empty ())
4492
+ return result;
4493
+ }
4494
+ return {};
4495
+ }
4397
4496
RaiseRT raiseSubscript (const RaiseRT &tup,
4398
4497
const Fortran::evaluate::ArrayRef &x) {
4399
4498
auto fopt = std::get<llvm::Optional<CC>>(tup);
4400
4499
if (fopt.hasValue ()) {
4401
- auto ty = std::get<mlir::Type>(tup);
4500
+ auto arrTy = std::get<mlir::Type>(tup);
4402
4501
auto prevRanked = std::get<2 >(tup);
4403
4502
auto ranked = std::get<3 >(tup);
4404
4503
auto lambda = fopt.getValue ();
@@ -4412,30 +4511,64 @@ class ArrayExprLowering {
4412
4511
// from the explicit space, then those dimensions should not be
4413
4512
// considered as contributing to the implied part of the iteration
4414
4513
// space.
4415
- if (explicitImpliedShape.empty ()) {
4416
- assert (destination && " destination must be set" );
4417
- auto feShape = getShape (destination);
4514
+ if (!explicitImpliedLoopBounds.hasValue ()) {
4418
4515
if (subs.empty ()) {
4419
- explicitImpliedShape. assign (feShape) ;
4516
+ explicitImpliedLoopBounds = [=]() { return getShape (x); } ;
4420
4517
} else {
4421
- unsigned ii = 0 ;
4518
+ auto desShape = getShape (x) ;
4422
4519
unsigned vi = 0 ;
4423
- vectorCoor.resize (feShape .size ());
4520
+ vectorCoor.resize (desShape .size ());
4424
4521
// Filter out subscripts that are scalar expressions. If it is a
4425
- // scalar expression it is either loop-invariant or a function of
4426
- // the explicit loop control variables.
4427
- for (const auto &ss : subs)
4522
+ // scalar expression it is either loop-invariant or a function
4523
+ // of the explicit loop control variables.
4524
+ for (const auto &ss : subs) {
4428
4525
if (auto *intExpr = std::get_if<
4429
- Fortran::evaluate::IndirectSubscriptIntegerExpr>(&ss.u )) {
4430
- if (intExpr->value ().Rank () > 0 ) {
4431
- explicitImpliedShape.push_back (feShape[ii++]);
4526
+ Fortran::evaluate::IndirectSubscriptIntegerExpr>(&ss.u ))
4527
+ if (intExpr->value ().Rank () > 0 )
4432
4528
vectorCoor[vi++] = genarr (intExpr->value ());
4433
- }
4434
- } else {
4435
- // This is a triple which may be using an explicit control
4436
- // variable.
4437
- explicitImpliedShape.push_back (feShape[ii++]);
4438
- }
4529
+ }
4530
+ explicitImpliedLoopBounds = [=]() {
4531
+ llvm::SmallVector<mlir::Value> result;
4532
+ unsigned ii = 0 ;
4533
+ for (const auto &ss : subs)
4534
+ std::visit (
4535
+ Fortran::common::visitors{
4536
+ [&](const Fortran::evaluate::
4537
+ IndirectSubscriptIntegerExpr &intExpr) {
4538
+ if (intExpr.value ().Rank () > 0 )
4539
+ result.push_back (builder.createConvert (
4540
+ loc, idxTy, desShape[ii++]));
4541
+ },
4542
+ [&](const Fortran::evaluate::Triplet &triple) {
4543
+ // This is a triple which may be using an
4544
+ // explicit control variable.
4545
+ auto ou = triple.upper ();
4546
+ auto up = builder.createConvert (
4547
+ loc, idxTy,
4548
+ ou.has_value () ? fir::getBase (asScalar (*ou))
4549
+ : desShape[ii]);
4550
+ auto ol = triple.lower ();
4551
+ auto lo =
4552
+ ol.has_value ()
4553
+ ? builder.createConvert (
4554
+ loc, idxTy, fir::getBase (asScalar (*ol)))
4555
+ : builder.createIntegerConstant (loc, idxTy,
4556
+ 1 );
4557
+ auto step = builder.createConvert (
4558
+ loc, idxTy,
4559
+ fir::getBase (asScalar (triple.stride ())));
4560
+ auto diff = builder.create <mlir::SubIOp>(loc, up, lo);
4561
+ auto sum =
4562
+ builder.create <mlir::AddIOp>(loc, diff, step);
4563
+ mlir::Value count =
4564
+ builder.create <mlir::SignedDivIOp>(loc, sum,
4565
+ step);
4566
+ result.push_back (count);
4567
+ ii++;
4568
+ }},
4569
+ ss.u );
4570
+ return result;
4571
+ };
4439
4572
}
4440
4573
}
4441
4574
}
@@ -4502,6 +4635,7 @@ class ArrayExprLowering {
4502
4635
}
4503
4636
return newIters;
4504
4637
};
4638
+ auto ty = fir::unwrapSequenceType (unwrapBoxEleTy (arrTy));
4505
4639
return RaiseRT{[=](IterSpace iters) { return lambda (pc (iters)); }, ty,
4506
4640
prevRanked, ranked};
4507
4641
}
@@ -4588,7 +4722,9 @@ class ArrayExprLowering {
4588
4722
4589
4723
/// Apply the access path \p indices to the type \p ty and return the adjusted
/// array element type of the result. Asserts that the path applies to the
/// type, i.e. that fir::applyPathToType() returns a non-null type.
static mlir::Type adjustedArraySubtype(mlir::Type ty,
                                       mlir::ValueRange indices) {
  mlir::Type subTy = fir::applyPathToType(ty, indices);
  assert(subTy && "indices failed to apply to type");
  return adjustedArrayElementType(subTy);
}
4593
4729
4594
4730
// / Build an ExtendedValue from a fir.array<?x...?xT> without actually
@@ -5469,7 +5605,7 @@ class ArrayExprLowering {
5469
5605
// / Even in an explicitly defined iteration space, one can have an
5470
5606
// / assignment with rank > 0 and thus an implied shape on a component in the
5471
5607
// / path.
5472
- llvm::SmallVector<mlir::Value> explicitImpliedShape ;
5608
+ llvm::Optional<LBC> explicitImpliedLoopBounds ;
5473
5609
Fortran::lower::ImplicitIterSpace *masks = nullptr ;
5474
5610
ConstituentSemantics semant = ConstituentSemantics::RefTransparent;
5475
5611
bool inSlice = false ;
0 commit comments