@@ -2412,10 +2412,17 @@ class ArrayExprLowering {
2412
2412
};
2413
2413
2414
2414
using ExtValue = fir::ExtendedValue;
2415
- using IterSpace = const IterationSpace &; // active iteration space
2416
- using CC = std::function<ExtValue(IterSpace)>; // current continuation
2417
- using PC =
2418
- std::function<IterationSpace(IterSpace)>; // projection continuation
2415
+ /// Active iteration space.
2416
+ using IterSpace = const IterationSpace &;
2417
+ /// Current continuation. Function that will generate IR for a single
2418
+ /// iteration of the pending iterative loop structure.
2419
+ using CC = std::function<ExtValue(IterSpace)>;
2420
+ /// Projection continuation. Function that will project one iteration space
2421
+ /// into another.
2422
+ using PC = std::function<IterationSpace(IterSpace)>;
2423
+ /// Loop bounds continuation. Function that will generate IR to compute loop
2424
+ /// bounds in a future context.
2425
+ using LBC = std::function<llvm::SmallVector<mlir::Value>()>;
2419
2426
using ArrayBaseTy =
2420
2427
std::variant<std::monostate, const Fortran::evaluate::ArrayRef *,
2421
2428
const Fortran::evaluate::DataRef *>;
@@ -2686,13 +2693,15 @@ class ArrayExprLowering {
2686
2693
}
2687
2694
}
2688
2695
2696
+ // / Returns true iff the Ev::Shape is constant.
2689
2697
static bool evalShapeIsConstant (const Fortran::evaluate::Shape &shape) {
2690
2698
for (const auto &s : shape)
2691
2699
if (!s || !Fortran::evaluate::IsConstantExpr (*s))
2692
2700
return false ;
2693
2701
return true ;
2694
2702
}
2695
2703
2704
+ /// Convert an Ev::Shape to IR values.
2696
2705
void convertFEShape (const Fortran::evaluate::Shape &shape,
2697
2706
llvm::SmallVectorImpl<mlir::Value> &result) {
2698
2707
if (evalShapeIsConstant (shape)) {
@@ -2831,7 +2840,7 @@ class ArrayExprLowering {
2831
2840
/// this returns any implicit shape component, if it exists.
2832
2841
llvm::SmallVector<mlir::Value> genIterationShape () {
2833
2842
if (explicitSpace)
2834
- return explicitImpliedShape ;
2843
+ return {} ;
2835
2844
// Use the precomputed destination shape.
2836
2845
if (!destShape.empty ())
2837
2846
return destShape;
@@ -3199,6 +3208,51 @@ class ArrayExprLowering {
3199
3208
return {indices, loops[0 ]};
3200
3209
}
3201
3210
3211
+ llvm::SmallVector<mlir::Value> genImplicitLoopBounds (
3212
+ const Fortran::evaluate::Expr<Fortran::evaluate::SomeType> &e) {
3213
+ struct Filter : public Fortran ::evaluate::AnyTraverse<
3214
+ Filter, std::optional<llvm::SmallVector<mlir::Value>>> {
3215
+ using Base = Fortran::evaluate::AnyTraverse<
3216
+ Filter, std::optional<llvm::SmallVector<mlir::Value>>>;
3217
+ using Base::operator ();
3218
+
3219
+ Filter (const llvm::SmallVector<mlir::Value> init)
3220
+ : Base(*this ), bounds(init) {}
3221
+
3222
+ std::optional<llvm::SmallVector<mlir::Value>>
3223
+ operator ()(const Fortran::evaluate::ArrayRef &ref) {
3224
+ if ((ref.base ().IsSymbol () || ref.base ().Rank () == 0 ) &&
3225
+ ref.Rank () > 0 && !ref.subscript ().empty ()) {
3226
+ assert (ref.subscript ().size () == bounds.size ());
3227
+ llvm::SmallVector<mlir::Value> result;
3228
+ auto bdIter = bounds.begin ();
3229
+ for (auto ss : ref.subscript ()) {
3230
+ mlir::Value bound = *bdIter++;
3231
+ std::visit (Fortran::common::visitors{
3232
+ [&](const Fortran::evaluate::Triplet &triple) {
3233
+ result.push_back (bound);
3234
+ },
3235
+ [&](const auto &intExpr) {
3236
+ if (intExpr.value ().Rank () > 0 )
3237
+ result.push_back (bound);
3238
+ }},
3239
+ ss.u );
3240
+ }
3241
+ return {result};
3242
+ }
3243
+ return {};
3244
+ }
3245
+
3246
+ llvm::SmallVector<mlir::Value> bounds;
3247
+ };
3248
+
3249
+ auto originalShape = getShape (converter.genType (e));
3250
+ Filter filter (originalShape);
3251
+ if (auto res = filter (e))
3252
+ return *res;
3253
+ return originalShape;
3254
+ }
3255
+
3202
3256
void genMasks () {
3203
3257
auto loc = getLoc ();
3204
3258
// Lower explicit mask expressions, if any.
@@ -3215,8 +3269,8 @@ class ArrayExprLowering {
3215
3269
for (const auto *e : masks->getExprs ())
3216
3270
if (e && !masks->isLowered (e)) {
3217
3271
auto extents = genExplicitExtents ();
3218
- extents. append (explicitImpliedShape. rbegin (),
3219
- explicitImpliedShape .rend ());
3272
+ auto loopBounds = genImplicitLoopBounds (*e);
3273
+ extents. append (loopBounds. rbegin (), loopBounds .rend ());
3220
3274
// Allocate a temporary to cache the mask results.
3221
3275
auto tmpShape = builder.consShape (loc, extents);
3222
3276
auto tmp = createAndLoadSomeArrayTemp (
@@ -3230,7 +3284,8 @@ class ArrayExprLowering {
3230
3284
// Evaluate like any other nested array expression.
3231
3285
ArrayExprLowering ael{converter, masks->stmtContext (), symMap,
3232
3286
ConstituentSemantics::ProjectedCopyInCopyOut};
3233
- ael.lowerArrayAssignment (tmp, *e, indices, explicitImpliedShape);
3287
+ ael.lowerArrayAssignment (tmp, *e, indices,
3288
+ explicitImpliedLoopBounds.getValue ()());
3234
3289
masks->bind (e, tmp.memref (), tmpShape);
3235
3290
builder.setInsertionPointAfter (loop0);
3236
3291
builder.create <fir::ArrayMergeStoreOp>(loc, tmp, loop0.getResult (0 ),
@@ -3289,14 +3344,18 @@ class ArrayExprLowering {
3289
3344
llvm::SmallVector<fir::DoLoopOp> loops;
3290
3345
llvm::SmallVector<mlir::Value> explicitOffsets;
3291
3346
// FORALL loops are outermost.
3292
- if (explicitSpace)
3347
+ if (explicitSpace) {
3293
3348
genExplicitIterSpace (loops, explicitOffsets, innerArg);
3349
+ if (explicitImpliedLoopBounds.hasValue ())
3350
+ loopUppers = explicitImpliedLoopBounds.getValue ()();
3351
+ }
3294
3352
3295
3353
// Now handle the implicit loops.
3296
3354
const auto loopFirst = loops.size ();
3297
3355
const auto loopDepth = loopUppers.size ();
3298
3356
llvm::SmallVector<mlir::Value> ivars;
3299
3357
if (loopDepth > 0 ) {
3358
+ auto *startBlock = builder.getBlock ();
3300
3359
for (auto i : llvm::enumerate (llvm::reverse (loopUppers))) {
3301
3360
if (i.index () > 0 ) {
3302
3361
assert (!loops.empty ());
@@ -3311,8 +3370,11 @@ class ArrayExprLowering {
3311
3370
}
3312
3371
// Add the fir.result for all loops except the innermost one. We must also
3313
3372
// terminate the innermost explicit bounds loop here as well.
3314
- for (std::remove_const_t <decltype (loopFirst)> i =
3315
- loopFirst ? loopFirst - 1 : 0 ;
3373
+ if (loopFirst > 0 ) {
3374
+ builder.setInsertionPointToEnd (startBlock);
3375
+ builder.create <fir::ResultOp>(loc, loops[loopFirst].getResult (0 ));
3376
+ }
3377
+ for (std::remove_const_t <decltype (loopFirst)> i = loopFirst;
3316
3378
i + 1 < loopFirst + loopDepth; ++i) {
3317
3379
builder.setInsertionPointToEnd (loops[i].getBody ());
3318
3380
builder.create <fir::ResultOp>(loc, loops[i + 1 ].getResult (0 ));
@@ -3365,17 +3427,18 @@ class ArrayExprLowering {
3365
3427
// structure is produced.
3366
3428
auto maskExprs = masks->getExprs ();
3367
3429
const auto size = maskExprs.size () - 1 ;
3368
- for (std::remove_const_t <decltype (size)> i = 0 ; i < size; ++i) {
3369
- auto ifOp = builder.create <fir::IfOp>(
3370
- loc, mlir::TypeRange{innerArg.getType ()},
3371
- fir::getBase (
3372
- genCond (masks->getBindingWithShape (maskExprs[i]), iters)),
3373
- /* withElseRegion=*/ true );
3374
- builder.create <fir::ResultOp>(loc, ifOp.getResult (0 ));
3375
- builder.setInsertionPointToStart (&ifOp.thenRegion ().front ());
3376
- builder.create <fir::ResultOp>(loc, innerArg);
3377
- builder.setInsertionPointToStart (&ifOp.elseRegion ().front ());
3378
- }
3430
+ for (std::remove_const_t <decltype (size)> i = 0 ; i < size; ++i)
3431
+ if (maskExprs[i]) {
3432
+ auto ifOp = builder.create <fir::IfOp>(
3433
+ loc, mlir::TypeRange{innerArg.getType ()},
3434
+ fir::getBase (
3435
+ genCond (masks->getBindingWithShape (maskExprs[i]), iters)),
3436
+ /* withElseRegion=*/ true );
3437
+ builder.create <fir::ResultOp>(loc, ifOp.getResult (0 ));
3438
+ builder.setInsertionPointToStart (&ifOp.thenRegion ().front ());
3439
+ builder.create <fir::ResultOp>(loc, innerArg);
3440
+ builder.setInsertionPointToStart (&ifOp.elseRegion ().front ());
3441
+ }
3379
3442
3380
3443
// The last condition is either non-negated or unconditionally negated.
3381
3444
if (maskExprs[size]) {
@@ -4285,11 +4348,8 @@ class ArrayExprLowering {
4285
4348
template <typename A>
4286
4349
std::pair<CC, mlir::Type> raiseRankedBase (const A &x) {
4287
4350
auto result = raiseBase (x);
4288
- if (isProjectedCopyInCopyOut ()) {
4289
- auto optShape = Fortran::evaluate::GetShape (x);
4290
- assert (optShape.has_value ());
4291
- convertFEShape (*optShape, explicitImpliedShape);
4292
- }
4351
+ if (isProjectedCopyInCopyOut ())
4352
+ explicitImpliedLoopBounds = [=]() { return getShape (x); };
4293
4353
return result;
4294
4354
}
4295
4355
template <typename A>
@@ -4320,11 +4380,8 @@ class ArrayExprLowering {
4320
4380
std::pair<CC, mlir::Type> raiseRankedComponent (llvm::Optional<CC> cc,
4321
4381
const A &x, mlir::Type inTy) {
4322
4382
auto result = raiseComponent (cc, x, inTy, false );
4323
- if (isProjectedCopyInCopyOut ()) {
4324
- auto optShape = Fortran::evaluate::GetShape (x);
4325
- assert (optShape.has_value ());
4326
- convertFEShape (*optShape, explicitImpliedShape);
4327
- }
4383
+ if (isProjectedCopyInCopyOut ())
4384
+ explicitImpliedLoopBounds = [=]() { return getShape (x); };
4328
4385
return result;
4329
4386
}
4330
4387
@@ -4393,8 +4450,7 @@ class ArrayExprLowering {
4393
4450
auto &sym = base.GetFirstSymbol ();
4394
4451
if (x.Rank () > 0 || accessUsesControlVariable ()) {
4395
4452
auto [fopt2, ty2] = raiseBase (sym);
4396
- return RaiseRT{fopt2, fir::unwrapSequenceType (ty2), false ,
4397
- x.Rank () > 0 };
4453
+ return RaiseRT{fopt2, ty2, false , x.Rank () > 0 };
4398
4454
}
4399
4455
return RaiseRT{llvm::None, mlir::Type{}, false , false };
4400
4456
}
@@ -4411,11 +4467,53 @@ class ArrayExprLowering {
4411
4467
}(),
4412
4468
x);
4413
4469
}
4470
+ static mlir::Type unwrapBoxEleTy (mlir::Type ty) {
4471
+ if (auto boxTy = ty.dyn_cast <fir::BoxType>()) {
4472
+ ty = boxTy.getEleTy ();
4473
+ if (auto refTy = fir::dyn_cast_ptrEleTy (ty))
4474
+ ty = refTy;
4475
+ }
4476
+ return ty;
4477
+ }
4478
+ llvm::SmallVector<mlir::Value> getShape (mlir::Type ty) {
4479
+ llvm::SmallVector<mlir::Value> result;
4480
+ ty = unwrapBoxEleTy (ty);
4481
+ auto loc = getLoc ();
4482
+ auto idxTy = builder.getIndexType ();
4483
+ for (auto extent : ty.cast <fir::SequenceType>().getShape ()) {
4484
+ auto v = extent == fir::SequenceType::getUnknownExtent ()
4485
+ ? builder.create <fir::UndefOp>(loc, idxTy).getResult ()
4486
+ : builder.createIntegerConstant (loc, idxTy, extent);
4487
+ result.push_back (v);
4488
+ }
4489
+ return result;
4490
+ }
4491
+ llvm::SmallVector<mlir::Value>
4492
+ getShape (const Fortran::semantics::SymbolRef &x) {
4493
+ if (x.get ().Rank () == 0 )
4494
+ return {};
4495
+ return getFrontEndShape (x);
4496
+ }
4497
+ template <typename A>
4498
+ llvm::SmallVector<mlir::Value> getShape (const A &x) {
4499
+ if (x.Rank () == 0 )
4500
+ return {};
4501
+ return getFrontEndShape (x);
4502
+ }
4503
+ template <typename A>
4504
+ llvm::SmallVector<mlir::Value> getFrontEndShape (const A &x) {
4505
+ if (auto optShape = Fortran::evaluate::GetShape (x)) {
4506
+ llvm::SmallVector<mlir::Value> result;
4507
+ convertFEShape (*optShape, result);
4508
+ return result;
4509
+ }
4510
+ return {};
4511
+ }
4414
4512
RaiseRT raiseSubscript (const RaiseRT &tup,
4415
4513
const Fortran::evaluate::ArrayRef &x) {
4416
4514
auto fopt = std::get<llvm::Optional<CC>>(tup);
4417
4515
if (fopt.hasValue ()) {
4418
- auto ty = std::get<mlir::Type>(tup);
4516
+ auto arrTy = std::get<mlir::Type>(tup);
4419
4517
auto prevRanked = std::get<2 >(tup);
4420
4518
auto ranked = std::get<3 >(tup);
4421
4519
auto lambda = fopt.getValue ();
@@ -4429,30 +4527,64 @@ class ArrayExprLowering {
4429
4527
// from the explicit space, then those dimensions should not be
4430
4528
// considered as contributing to the implied part of the iteration
4431
4529
// space.
4432
- if (explicitImpliedShape.empty ()) {
4433
- assert (destination && " destination must be set" );
4434
- auto feShape = getShape (destination);
4530
+ if (!explicitImpliedLoopBounds.hasValue ()) {
4435
4531
if (subs.empty ()) {
4436
- explicitImpliedShape. assign (feShape) ;
4532
+ explicitImpliedLoopBounds = [=]() { return getShape (x); } ;
4437
4533
} else {
4438
- unsigned ii = 0 ;
4534
+ auto desShape = getShape (x) ;
4439
4535
unsigned vi = 0 ;
4440
- vectorCoor.resize (feShape .size ());
4536
+ vectorCoor.resize (desShape .size ());
4441
4537
// Filter out subscripts that are scalar expressions. If it is a
4442
- // scalar expression it is either loop-invariant or a function of
4443
- // the explicit loop control variables.
4444
- for (const auto &ss : subs)
4538
+ // scalar expression it is either loop-invariant or a function
4539
+ // of the explicit loop control variables.
4540
+ for (const auto &ss : subs) {
4445
4541
if (auto *intExpr = std::get_if<
4446
- Fortran::evaluate::IndirectSubscriptIntegerExpr>(&ss.u )) {
4447
- if (intExpr->value ().Rank () > 0 ) {
4448
- explicitImpliedShape.push_back (feShape[ii++]);
4542
+ Fortran::evaluate::IndirectSubscriptIntegerExpr>(&ss.u ))
4543
+ if (intExpr->value ().Rank () > 0 )
4449
4544
vectorCoor[vi++] = genarr (intExpr->value ());
4450
- }
4451
- } else {
4452
- // This is a triple which may be using an explicit control
4453
- // variable.
4454
- explicitImpliedShape.push_back (feShape[ii++]);
4455
- }
4545
+ }
4546
+ explicitImpliedLoopBounds = [=]() {
4547
+ llvm::SmallVector<mlir::Value> result;
4548
+ unsigned ii = 0 ;
4549
+ for (const auto &ss : subs)
4550
+ std::visit (
4551
+ Fortran::common::visitors{
4552
+ [&](const Fortran::evaluate::
4553
+ IndirectSubscriptIntegerExpr &intExpr) {
4554
+ if (intExpr.value ().Rank () > 0 )
4555
+ result.push_back (builder.createConvert (
4556
+ loc, idxTy, desShape[ii++]));
4557
+ },
4558
+ [&](const Fortran::evaluate::Triplet &triple) {
4559
+ // This is a triple which may be using an
4560
+ // explicit control variable.
4561
+ auto ou = triple.upper ();
4562
+ auto up = builder.createConvert (
4563
+ loc, idxTy,
4564
+ ou.has_value () ? fir::getBase (asScalar (*ou))
4565
+ : desShape[ii]);
4566
+ auto ol = triple.lower ();
4567
+ auto lo =
4568
+ ol.has_value ()
4569
+ ? builder.createConvert (
4570
+ loc, idxTy, fir::getBase (asScalar (*ol)))
4571
+ : builder.createIntegerConstant (loc, idxTy,
4572
+ 1 );
4573
+ auto step = builder.createConvert (
4574
+ loc, idxTy,
4575
+ fir::getBase (asScalar (triple.stride ())));
4576
+ auto diff = builder.create <mlir::SubIOp>(loc, up, lo);
4577
+ auto sum =
4578
+ builder.create <mlir::AddIOp>(loc, diff, step);
4579
+ mlir::Value count =
4580
+ builder.create <mlir::SignedDivIOp>(loc, sum,
4581
+ step);
4582
+ result.push_back (count);
4583
+ ii++;
4584
+ }},
4585
+ ss.u );
4586
+ return result;
4587
+ };
4456
4588
}
4457
4589
}
4458
4590
}
@@ -4519,6 +4651,7 @@ class ArrayExprLowering {
4519
4651
}
4520
4652
return newIters;
4521
4653
};
4654
+ auto ty = fir::unwrapSequenceType (unwrapBoxEleTy (arrTy));
4522
4655
return RaiseRT{[=](IterSpace iters) { return lambda (pc (iters)); }, ty,
4523
4656
prevRanked, ranked};
4524
4657
}
@@ -4605,7 +4738,9 @@ class ArrayExprLowering {
4605
4738
4606
4739
/// Apply the access path \p indices to \p ty with fir::applyPathToType and
/// return the result of adjustedArrayElementType on the type obtained. The
/// updated version asserts that applying the path produced a type before
/// using it.
static mlir::Type adjustedArraySubtype (mlir::Type ty,
4607
4740
mlir::ValueRange indices) {
4608
- return adjustedArrayElementType (fir::applyPathToType (ty, indices));
4741
+ auto pathTy = fir::applyPathToType (ty, indices);
4742
+ assert (pathTy && " indices failed to apply to type" );
4743
+ return adjustedArrayElementType (pathTy);
4609
4744
}
4610
4745
4611
4746
// / Build an ExtendedValue from a fir.array<?x...?xT> without actually
@@ -5486,7 +5621,7 @@ class ArrayExprLowering {
5486
5621
// / Even in an explicitly defined iteration space, one can have an
5487
5622
// / assignment with rank > 0 and thus an implied shape on a component in the
5488
5623
// / path.
5489
- llvm::SmallVector<mlir::Value> explicitImpliedShape ;
5624
+ llvm::Optional<LBC> explicitImpliedLoopBounds ;
5490
5625
Fortran::lower::ImplicitIterSpace *masks = nullptr ;
5491
5626
ConstituentSemantics semant = ConstituentSemantics::RefTransparent;
5492
5627
bool inSlice = false ;
0 commit comments