Skip to content

Commit e31ba91

Browse files
committed
[flang] Take result length into account in ApplyElementwise folding
ApplyElementwise on character operation was always creating a result ArrayConstructor with the length of the left operand. This is not correct for concatenation and SetLength operations. Compute and thread the length to the spot creating the ArrayConstructor so that the length is correct for those character operations. Differential Revision: https://reviews.llvm.org/D108711
1 parent 880871c commit e31ba91

File tree

2 files changed

+60
-14
lines changed

2 files changed

+60
-14
lines changed

flang/lib/Evaluate/fold-implementation.h

Lines changed: 37 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -898,12 +898,24 @@ Expr<RESULT> MapOperation(FoldingContext &context,
898898
context, std::move(result), AsConstantExtents(context, shape));
899899
}
900900

901+
template <typename RESULT, typename A>
902+
ArrayConstructor<RESULT> ArrayConstructorFromMold(
903+
const A &prototype, std::optional<Expr<SubscriptInteger>> &&length) {
904+
if constexpr (RESULT::category == TypeCategory::Character) {
905+
return ArrayConstructor<RESULT>{
906+
std::move(length.value()), ArrayConstructorValues<RESULT>{}};
907+
} else {
908+
return ArrayConstructor<RESULT>{prototype};
909+
}
910+
}
911+
901912
// array * array case
902913
template <typename RESULT, typename LEFT, typename RIGHT>
903914
Expr<RESULT> MapOperation(FoldingContext &context,
904915
std::function<Expr<RESULT>(Expr<LEFT> &&, Expr<RIGHT> &&)> &&f,
905-
const Shape &shape, Expr<LEFT> &&leftValues, Expr<RIGHT> &&rightValues) {
906-
ArrayConstructor<RESULT> result{leftValues};
916+
const Shape &shape, std::optional<Expr<SubscriptInteger>> &&length,
917+
Expr<LEFT> &&leftValues, Expr<RIGHT> &&rightValues) {
918+
auto result{ArrayConstructorFromMold<RESULT>(leftValues, std::move(length))};
907919
auto &leftArrConst{std::get<ArrayConstructor<LEFT>>(leftValues.u)};
908920
if constexpr (common::HasMember<RIGHT, AllIntrinsicCategoryTypes>) {
909921
std::visit(
@@ -942,9 +954,9 @@ Expr<RESULT> MapOperation(FoldingContext &context,
942954
template <typename RESULT, typename LEFT, typename RIGHT>
943955
Expr<RESULT> MapOperation(FoldingContext &context,
944956
std::function<Expr<RESULT>(Expr<LEFT> &&, Expr<RIGHT> &&)> &&f,
945-
const Shape &shape, Expr<LEFT> &&leftValues,
946-
const Expr<RIGHT> &rightScalar) {
947-
ArrayConstructor<RESULT> result{leftValues};
957+
const Shape &shape, std::optional<Expr<SubscriptInteger>> &&length,
958+
Expr<LEFT> &&leftValues, const Expr<RIGHT> &rightScalar) {
959+
auto result{ArrayConstructorFromMold<RESULT>(leftValues, std::move(length))};
948960
auto &leftArrConst{std::get<ArrayConstructor<LEFT>>(leftValues.u)};
949961
for (auto &leftValue : leftArrConst) {
950962
auto &leftScalar{std::get<Expr<LEFT>>(leftValue.u)};
@@ -959,9 +971,9 @@ Expr<RESULT> MapOperation(FoldingContext &context,
959971
template <typename RESULT, typename LEFT, typename RIGHT>
960972
Expr<RESULT> MapOperation(FoldingContext &context,
961973
std::function<Expr<RESULT>(Expr<LEFT> &&, Expr<RIGHT> &&)> &&f,
962-
const Shape &shape, const Expr<LEFT> &leftScalar,
963-
Expr<RIGHT> &&rightValues) {
964-
ArrayConstructor<RESULT> result{leftScalar};
974+
const Shape &shape, std::optional<Expr<SubscriptInteger>> &&length,
975+
const Expr<LEFT> &leftScalar, Expr<RIGHT> &&rightValues) {
976+
auto result{ArrayConstructorFromMold<RESULT>(leftScalar, std::move(length))};
965977
if constexpr (common::HasMember<RIGHT, AllIntrinsicCategoryTypes>) {
966978
std::visit(
967979
[&](auto &&kindExpr) {
@@ -987,6 +999,15 @@ Expr<RESULT> MapOperation(FoldingContext &context,
987999
context, std::move(result), AsConstantExtents(context, shape));
9881000
}
9891001

1002+
template <typename DERIVED, typename RESULT, typename LEFT, typename RIGHT>
1003+
std::optional<Expr<SubscriptInteger>> ComputeResultLength(
1004+
Operation<DERIVED, RESULT, LEFT, RIGHT> &operation) {
1005+
if constexpr (RESULT::category == TypeCategory::Character) {
1006+
return Expr<RESULT>{operation.derived()}.LEN();
1007+
}
1008+
return std::nullopt;
1009+
}
1010+
9901011
// ApplyElementwise() recursively folds the operand expression(s) of an
9911012
// operation, then attempts to apply the operation to the (corresponding)
9921013
// scalar element(s) of those operands. Returns std::nullopt for scalars
@@ -1024,6 +1045,7 @@ auto ApplyElementwise(FoldingContext &context,
10241045
Operation<DERIVED, RESULT, LEFT, RIGHT> &operation,
10251046
std::function<Expr<RESULT>(Expr<LEFT> &&, Expr<RIGHT> &&)> &&f)
10261047
-> std::optional<Expr<RESULT>> {
1048+
auto resultLength{ComputeResultLength(operation)};
10271049
auto &leftExpr{operation.left()};
10281050
leftExpr = Fold(context, std::move(leftExpr));
10291051
auto &rightExpr{operation.right()};
@@ -1038,25 +1060,26 @@ auto ApplyElementwise(FoldingContext &context,
10381060
CheckConformanceFlags::EitherScalarExpandable)
10391061
.value_or(false /*fail if not known now to conform*/)) {
10401062
return MapOperation(context, std::move(f), *leftShape,
1041-
std::move(*left), std::move(*right));
1063+
std::move(resultLength), std::move(*left),
1064+
std::move(*right));
10421065
} else {
10431066
return std::nullopt;
10441067
}
10451068
return MapOperation(context, std::move(f), *leftShape,
1046-
std::move(*left), std::move(*right));
1069+
std::move(resultLength), std::move(*left), std::move(*right));
10471070
}
10481071
}
10491072
} else if (IsExpandableScalar(rightExpr)) {
1050-
return MapOperation(
1051-
context, std::move(f), *leftShape, std::move(*left), rightExpr);
1073+
return MapOperation(context, std::move(f), *leftShape,
1074+
std::move(resultLength), std::move(*left), rightExpr);
10521075
}
10531076
}
10541077
}
10551078
} else if (rightExpr.Rank() > 0 && IsExpandableScalar(leftExpr)) {
10561079
if (std::optional<Shape> shape{GetShape(context, rightExpr)}) {
10571080
if (auto right{AsFlatArrayConstructor(rightExpr)}) {
1058-
return MapOperation(
1059-
context, std::move(f), *shape, leftExpr, std::move(*right));
1081+
return MapOperation(context, std::move(f), *shape,
1082+
std::move(resultLength), leftExpr, std::move(*right));
10601083
}
10611084
}
10621085
}

flang/test/Evaluate/folding22.f90

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
! RUN: %S/test_folding.sh %s %t %flang_fc1
2+
! REQUIRES: shell
3+
4+
! Test character concatenation folding
5+
6+
logical, parameter :: test_scalar_scalar = ('ab' // 'cde').eq.('abcde')
7+
8+
character(2), parameter :: scalar_array(2) = ['1','2'] // 'a'
9+
logical, parameter :: test_scalar_array = all(scalar_array.eq.(['1a', '2a']))
10+
11+
character(2), parameter :: array_scalar(2) = '1' // ['a', 'b']
12+
logical, parameter :: test_array_scalar = all(array_scalar.eq.(['1a', '1b']))
13+
14+
character(2), parameter :: array_array(2) = ['1','2'] // ['a', 'b']
15+
logical, parameter :: test_array_array = all(array_array.eq.(['1a', '2b']))
16+
17+
18+
character(1), parameter :: input(2) = ['x', 'y']
19+
character(*), parameter :: zero_sized(*) = input(2:1:1) // 'abcde'
20+
logical, parameter :: test_zero_sized = len(zero_sized).eq.6
21+
22+
end
23+

0 commit comments

Comments
 (0)