Skip to content

Commit 38afb0f

Browse files
authored
Merge pull request #1016 from flang-compiler/jpr-cherry-pick-fma1-fix
Cherry-pick [flang] Take result length into account in ApplyElementwise folding
2 parents 880871c + e31ba91 commit 38afb0f

File tree

2 files changed

+60
-14
lines changed

2 files changed

+60
-14
lines changed

flang/lib/Evaluate/fold-implementation.h

Lines changed: 37 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -898,12 +898,24 @@ Expr<RESULT> MapOperation(FoldingContext &context,
898898
context, std::move(result), AsConstantExtents(context, shape));
899899
}
900900

901+
template <typename RESULT, typename A>
902+
ArrayConstructor<RESULT> ArrayConstructorFromMold(
903+
const A &prototype, std::optional<Expr<SubscriptInteger>> &&length) {
904+
if constexpr (RESULT::category == TypeCategory::Character) {
905+
return ArrayConstructor<RESULT>{
906+
std::move(length.value()), ArrayConstructorValues<RESULT>{}};
907+
} else {
908+
return ArrayConstructor<RESULT>{prototype};
909+
}
910+
}
911+
901912
// array * array case
902913
template <typename RESULT, typename LEFT, typename RIGHT>
903914
Expr<RESULT> MapOperation(FoldingContext &context,
904915
std::function<Expr<RESULT>(Expr<LEFT> &&, Expr<RIGHT> &&)> &&f,
905-
const Shape &shape, Expr<LEFT> &&leftValues, Expr<RIGHT> &&rightValues) {
906-
ArrayConstructor<RESULT> result{leftValues};
916+
const Shape &shape, std::optional<Expr<SubscriptInteger>> &&length,
917+
Expr<LEFT> &&leftValues, Expr<RIGHT> &&rightValues) {
918+
auto result{ArrayConstructorFromMold<RESULT>(leftValues, std::move(length))};
907919
auto &leftArrConst{std::get<ArrayConstructor<LEFT>>(leftValues.u)};
908920
if constexpr (common::HasMember<RIGHT, AllIntrinsicCategoryTypes>) {
909921
std::visit(
@@ -942,9 +954,9 @@ Expr<RESULT> MapOperation(FoldingContext &context,
942954
template <typename RESULT, typename LEFT, typename RIGHT>
943955
Expr<RESULT> MapOperation(FoldingContext &context,
944956
std::function<Expr<RESULT>(Expr<LEFT> &&, Expr<RIGHT> &&)> &&f,
945-
const Shape &shape, Expr<LEFT> &&leftValues,
946-
const Expr<RIGHT> &rightScalar) {
947-
ArrayConstructor<RESULT> result{leftValues};
957+
const Shape &shape, std::optional<Expr<SubscriptInteger>> &&length,
958+
Expr<LEFT> &&leftValues, const Expr<RIGHT> &rightScalar) {
959+
auto result{ArrayConstructorFromMold<RESULT>(leftValues, std::move(length))};
948960
auto &leftArrConst{std::get<ArrayConstructor<LEFT>>(leftValues.u)};
949961
for (auto &leftValue : leftArrConst) {
950962
auto &leftScalar{std::get<Expr<LEFT>>(leftValue.u)};
@@ -959,9 +971,9 @@ Expr<RESULT> MapOperation(FoldingContext &context,
959971
template <typename RESULT, typename LEFT, typename RIGHT>
960972
Expr<RESULT> MapOperation(FoldingContext &context,
961973
std::function<Expr<RESULT>(Expr<LEFT> &&, Expr<RIGHT> &&)> &&f,
962-
const Shape &shape, const Expr<LEFT> &leftScalar,
963-
Expr<RIGHT> &&rightValues) {
964-
ArrayConstructor<RESULT> result{leftScalar};
974+
const Shape &shape, std::optional<Expr<SubscriptInteger>> &&length,
975+
const Expr<LEFT> &leftScalar, Expr<RIGHT> &&rightValues) {
976+
auto result{ArrayConstructorFromMold<RESULT>(leftScalar, std::move(length))};
965977
if constexpr (common::HasMember<RIGHT, AllIntrinsicCategoryTypes>) {
966978
std::visit(
967979
[&](auto &&kindExpr) {
@@ -987,6 +999,15 @@ Expr<RESULT> MapOperation(FoldingContext &context,
987999
context, std::move(result), AsConstantExtents(context, shape));
9881000
}
9891001

1002+
template <typename DERIVED, typename RESULT, typename LEFT, typename RIGHT>
1003+
std::optional<Expr<SubscriptInteger>> ComputeResultLength(
1004+
Operation<DERIVED, RESULT, LEFT, RIGHT> &operation) {
1005+
if constexpr (RESULT::category == TypeCategory::Character) {
1006+
return Expr<RESULT>{operation.derived()}.LEN();
1007+
}
1008+
return std::nullopt;
1009+
}
1010+
9901011
// ApplyElementwise() recursively folds the operand expression(s) of an
9911012
// operation, then attempts to apply the operation to the (corresponding)
9921013
// scalar element(s) of those operands. Returns std::nullopt for scalars
@@ -1024,6 +1045,7 @@ auto ApplyElementwise(FoldingContext &context,
10241045
Operation<DERIVED, RESULT, LEFT, RIGHT> &operation,
10251046
std::function<Expr<RESULT>(Expr<LEFT> &&, Expr<RIGHT> &&)> &&f)
10261047
-> std::optional<Expr<RESULT>> {
1048+
auto resultLength{ComputeResultLength(operation)};
10271049
auto &leftExpr{operation.left()};
10281050
leftExpr = Fold(context, std::move(leftExpr));
10291051
auto &rightExpr{operation.right()};
@@ -1038,25 +1060,26 @@ auto ApplyElementwise(FoldingContext &context,
10381060
CheckConformanceFlags::EitherScalarExpandable)
10391061
.value_or(false /*fail if not known now to conform*/)) {
10401062
return MapOperation(context, std::move(f), *leftShape,
1041-
std::move(*left), std::move(*right));
1063+
std::move(resultLength), std::move(*left),
1064+
std::move(*right));
10421065
} else {
10431066
return std::nullopt;
10441067
}
10451068
return MapOperation(context, std::move(f), *leftShape,
1046-
std::move(*left), std::move(*right));
1069+
std::move(resultLength), std::move(*left), std::move(*right));
10471070
}
10481071
}
10491072
} else if (IsExpandableScalar(rightExpr)) {
1050-
return MapOperation(
1051-
context, std::move(f), *leftShape, std::move(*left), rightExpr);
1073+
return MapOperation(context, std::move(f), *leftShape,
1074+
std::move(resultLength), std::move(*left), rightExpr);
10521075
}
10531076
}
10541077
}
10551078
} else if (rightExpr.Rank() > 0 && IsExpandableScalar(leftExpr)) {
10561079
if (std::optional<Shape> shape{GetShape(context, rightExpr)}) {
10571080
if (auto right{AsFlatArrayConstructor(rightExpr)}) {
1058-
return MapOperation(
1059-
context, std::move(f), *shape, leftExpr, std::move(*right));
1081+
return MapOperation(context, std::move(f), *shape,
1082+
std::move(resultLength), leftExpr, std::move(*right));
10601083
}
10611084
}
10621085
}

flang/test/Evaluate/folding22.f90

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
! RUN: %S/test_folding.sh %s %t %flang_fc1
2+
! REQUIRES: shell
3+
4+
! Test character concatenation folding
5+
6+
logical, parameter :: test_scalar_scalar = ('ab' // 'cde').eq.('abcde')
7+
8+
character(2), parameter :: scalar_array(2) = ['1','2'] // 'a'
9+
logical, parameter :: test_scalar_array = all(scalar_array.eq.(['1a', '2a']))
10+
11+
character(2), parameter :: array_scalar(2) = '1' // ['a', 'b']
12+
logical, parameter :: test_array_scalar = all(array_scalar.eq.(['1a', '1b']))
13+
14+
character(2), parameter :: array_array(2) = ['1','2'] // ['a', 'b']
15+
logical, parameter :: test_array_array = all(array_array.eq.(['1a', '2b']))
16+
17+
18+
character(1), parameter :: input(2) = ['x', 'y']
19+
character(*), parameter :: zero_sized(*) = input(2:1:1) // 'abcde'
20+
logical, parameter :: test_zero_sized = len(zero_sized).eq.6
21+
22+
end
23+

0 commit comments

Comments
 (0)