Skip to content

Commit 67262fd

Browse files
klauslerjeanPerier
authored andcommitted
[flang] Fold EOSHIFT
Implement constant folding for the transformational intrinsic function EOSHIFT. Differential Revision: https://reviews.llvm.org/D108941
1 parent 56c6432 commit 67262fd

File tree

4 files changed

+131
-4
lines changed

4 files changed

+131
-4
lines changed

flang/lib/Evaluate/fold-implementation.h

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ template <typename T> class Folder {
6262
Constant<T> *Folding(std::optional<ActualArgument> &);
6363

6464
Expr<T> CSHIFT(FunctionRef<T> &&);
65+
Expr<T> EOSHIFT(FunctionRef<T> &&);
6566
Expr<T> RESHAPE(FunctionRef<T> &&);
6667

6768
private:
@@ -619,6 +620,112 @@ template <typename T> Expr<T> Folder<T>::CSHIFT(FunctionRef<T> &&funcRef) {
619620
return MakeInvalidIntrinsic(std::move(funcRef));
620621
}
621622

623+
template <typename T> Expr<T> Folder<T>::EOSHIFT(FunctionRef<T> &&funcRef) {
624+
auto args{funcRef.arguments()};
625+
CHECK(args.size() == 4);
626+
const auto *array{UnwrapConstantValue<T>(args[0])};
627+
const auto *shiftExpr{UnwrapExpr<Expr<SomeInteger>>(args[1])};
628+
auto dim{GetInt64ArgOr(args[3], 1)};
629+
if (!array || !shiftExpr || !dim) {
630+
return Expr<T>{std::move(funcRef)};
631+
}
632+
// Apply type conversions to the shift= and boundary= arguments.
633+
auto convertedShift{Fold(context_,
634+
ConvertToType<SubscriptInteger>(Expr<SomeInteger>{*shiftExpr}))};
635+
const auto *shift{UnwrapConstantValue<SubscriptInteger>(convertedShift)};
636+
if (!shift) {
637+
return Expr<T>{std::move(funcRef)};
638+
}
639+
const Constant<T> *boundary{nullptr};
640+
std::optional<Expr<SomeType>> convertedBoundary;
641+
if (const auto *boundaryExpr{UnwrapExpr<Expr<SomeType>>(args[2])}) {
642+
convertedBoundary = Fold(context_,
643+
ConvertToType(array->GetType(), Expr<SomeType>{*boundaryExpr}));
644+
boundary = UnwrapExpr<Constant<T>>(convertedBoundary);
645+
if (!boundary) {
646+
return Expr<T>{std::move(funcRef)};
647+
}
648+
}
649+
// Arguments are constant
650+
if (*dim < 1 || *dim > array->Rank()) {
651+
context_.messages().Say(
652+
"Invalid 'dim=' argument (%jd) in EOSHIFT"_err_en_US,
653+
static_cast<std::intmax_t>(*dim));
654+
} else if (shift->Rank() > 0 && shift->Rank() != array->Rank() - 1) {
655+
// message already emitted from intrinsic look-up
656+
} else {
657+
int rank{array->Rank()};
658+
int zbDim{static_cast<int>(*dim) - 1};
659+
bool ok{true};
660+
if (shift->Rank() > 0) {
661+
int k{0};
662+
for (int j{0}; j < rank; ++j) {
663+
if (j != zbDim) {
664+
if (array->shape()[j] != shift->shape()[k]) {
665+
context_.messages().Say(
666+
"Invalid 'shift=' argument in EOSHIFT; extent on dimension %d is %jd but must be %jd"_err_en_US,
667+
k + 1, static_cast<std::intmax_t>(shift->shape()[k]),
668+
static_cast<std::intmax_t>(array->shape()[j]));
669+
ok = false;
670+
}
671+
if (boundary && array->shape()[j] != boundary->shape()[k]) {
672+
context_.messages().Say(
673+
"Invalid 'boundary=' argument in EOSHIFT; extent on dimension %d is %jd but must be %jd"_err_en_US,
674+
k + 1, static_cast<std::intmax_t>(shift->shape()[k]),
675+
static_cast<std::intmax_t>(array->shape()[j]));
676+
ok = false;
677+
}
678+
++k;
679+
}
680+
}
681+
}
682+
if (ok) {
683+
std::vector<Scalar<T>> resultElements;
684+
ConstantSubscripts arrayAt{array->lbounds()};
685+
ConstantSubscript dimLB{arrayAt[zbDim]};
686+
ConstantSubscript dimExtent{array->shape()[zbDim]};
687+
ConstantSubscripts shiftAt{shift->lbounds()};
688+
ConstantSubscripts boundaryAt;
689+
if (boundary) {
690+
boundaryAt = boundary->lbounds();
691+
}
692+
for (auto n{GetSize(array->shape())}; n > 0; n -= dimExtent) {
693+
ConstantSubscript shiftCount{shift->At(shiftAt).ToInt64()};
694+
for (ConstantSubscript j{0}; j < dimExtent; ++j) {
695+
ConstantSubscript zbAt{shiftCount + j};
696+
if (zbAt >= 0 && zbAt < dimExtent) {
697+
arrayAt[zbDim] = dimLB + zbAt;
698+
resultElements.push_back(array->At(arrayAt));
699+
} else if (boundary) {
700+
resultElements.push_back(boundary->At(boundaryAt));
701+
} else if constexpr (T::category == TypeCategory::Integer ||
702+
T::category == TypeCategory::Real ||
703+
T::category == TypeCategory::Complex ||
704+
T::category == TypeCategory::Logical) {
705+
resultElements.emplace_back();
706+
} else if constexpr (T::category == TypeCategory::Character) {
707+
auto len{static_cast<std::size_t>(array->LEN())};
708+
typename Scalar<T>::value_type space{' '};
709+
resultElements.emplace_back(len, space);
710+
} else {
711+
DIE("no derived type boundary");
712+
}
713+
}
714+
arrayAt[zbDim] = dimLB + dimExtent - 1;
715+
array->IncrementSubscripts(arrayAt);
716+
shift->IncrementSubscripts(shiftAt);
717+
if (boundary) {
718+
boundary->IncrementSubscripts(boundaryAt);
719+
}
720+
}
721+
return Expr<T>{PackageConstant<T>(
722+
std::move(resultElements), *array, array->shape())};
723+
}
724+
}
725+
// Invalid, prevent re-folding
726+
return MakeInvalidIntrinsic(std::move(funcRef));
727+
}
728+
622729
template <typename T> Expr<T> Folder<T>::RESHAPE(FunctionRef<T> &&funcRef) {
623730
auto args{funcRef.arguments()};
624731
CHECK(args.size() == 4);
@@ -754,6 +861,8 @@ Expr<T> FoldOperation(FoldingContext &context, FunctionRef<T> &&funcRef) {
754861
const std::string name{intrinsic->name};
755862
if (name == "cshift") {
756863
return Folder<T>{context}.CSHIFT(std::move(funcRef));
864+
} else if (name == "eoshift") {
865+
return Folder<T>{context}.EOSHIFT(std::move(funcRef));
757866
} else if (name == "reshape") {
758867
return Folder<T>{context}.RESHAPE(std::move(funcRef));
759868
}

flang/lib/Evaluate/fold-logical.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ Expr<Type<TypeCategory::Logical, KIND>> FoldIntrinsicFunction(
125125
name == "__builtin_ieee_support_underflow_control") {
126126
return Expr<T>{true};
127127
}
128-
// TODO: btest, dot_product, eoshift, is_iostat_end,
128+
// TODO: btest, dot_product, is_iostat_end,
129129
// is_iostat_eor, lge, lgt, lle, llt, logical, matmul, out_of_range,
130130
// parity, transfer
131131
return Expr<T>{std::move(funcRef)};

flang/lib/Evaluate/intrinsics.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -385,15 +385,17 @@ static const IntrinsicInterface genericIntrinsicFunction[]{
385385
{"eoshift",
386386
{{"array", SameIntrinsic, Rank::array},
387387
{"shift", AnyInt, Rank::dimRemovedOrScalar},
388-
{"boundary", SameIntrinsic, Rank::dimReduced,
388+
{"boundary", SameIntrinsic, Rank::dimRemovedOrScalar,
389389
Optionality::optional},
390390
OptionalDIM},
391391
SameIntrinsic, Rank::conformable,
392392
IntrinsicClass::transformationalFunction},
393393
{"eoshift",
394394
{{"array", SameDerivedType, Rank::array},
395-
{"shift", AnyInt, Rank::dimReduced},
396-
{"boundary", SameDerivedType, Rank::dimReduced}, OptionalDIM},
395+
{"shift", AnyInt, Rank::dimRemovedOrScalar},
396+
// BOUNDARY= is not optional for derived types
397+
{"boundary", SameDerivedType, Rank::dimRemovedOrScalar},
398+
OptionalDIM},
397399
SameDerivedType, Rank::conformable,
398400
IntrinsicClass::transformationalFunction},
399401
{"epsilon", {{"x", SameReal, Rank::anyOrAssumedRank}}, SameReal,

flang/test/Evaluate/folding23.f90

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
! RUN: %S/test_folding.sh %s %t %flang_fc1
2+
! REQUIRES: shell
3+
! Tests folding of EOSHIFT (valid cases)
4+
module m
5+
integer, parameter :: arr(2,3) = reshape([1, 2, 3, 4, 5, 6], shape(arr))
6+
logical, parameter :: test_sanity = all([arr] == [1, 2, 3, 4, 5, 6])
7+
logical, parameter :: test_eoshift_0 = all(eoshift([1, 2, 3], 0) == [1, 2, 3])
8+
logical, parameter :: test_eoshift_1 = all(eoshift([1, 2, 3], 1) == [2, 3, 0])
9+
logical, parameter :: test_eoshift_2 = all(eoshift([1, 2, 3], -1) == [0, 1, 2])
10+
logical, parameter :: test_eoshift_3 = all(eoshift([1., 2., 3.], 1) == [2., 3., 0.])
11+
logical, parameter :: test_eoshift_4 = all(eoshift(['ab', 'cd', 'ef'], -1, 'x') == ['x ', 'ab', 'cd'])
12+
logical, parameter :: test_eoshift_5 = all([eoshift(arr, 1, dim=1)] == [2, 0, 4, 0, 6, 0])
13+
logical, parameter :: test_eoshift_6 = all([eoshift(arr, 1, dim=2)] == [3, 5, 0, 4, 6, 0])
14+
logical, parameter :: test_eoshift_7 = all([eoshift(arr, [1, -1, 0])] == [2, 0, 0, 3, 5, 6])
15+
logical, parameter :: test_eoshift_8 = all([eoshift(arr, [1, -1], dim=2)] == [3, 5, 0, 0, 2, 4])
16+
end module

0 commit comments

Comments
 (0)