Skip to content

Commit abe21a1

Browse files
klauslerjeanPerier
authored andcommitted
[flang] Fold COUNT()
Complete folding of the intrinsic reduction function COUNT() for all cases, including partial reductions with DIM= arguments. Differential Revision: https://reviews.llvm.org/D109911
1 parent 8ac9eac commit abe21a1

File tree

7 files changed

+90
-38
lines changed

7 files changed

+90
-38
lines changed

flang/lib/Evaluate/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ add_flang_library(FortranEvaluate
3434
fold-integer.cpp
3535
fold-logical.cpp
3636
fold-real.cpp
37+
fold-reduction.cpp
3738
formatting.cpp
3839
host.cpp
3940
initial-image.cpp

flang/lib/Evaluate/fold-implementation.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,7 @@ Expr<TR> FoldElementalIntrinsicHelper(FoldingContext &context,
492492
// Build and return constant result
493493
if constexpr (TR::category == TypeCategory::Character) {
494494
auto len{static_cast<ConstantSubscript>(
495-
results.size() ? results[0].length() : 0)};
495+
results.empty() ? 0 : results[0].length())};
496496
return Expr<TR>{Constant<TR>{len, std::move(results), std::move(shape)}};
497497
} else {
498498
return Expr<TR>{Constant<TR>{std::move(results), std::move(shape)}};
@@ -944,7 +944,7 @@ Expr<T> FoldMINorMAX(
944944
if (constantArgs.size() != funcRef.arguments().size()) {
945945
return Expr<T>(std::move(funcRef));
946946
}
947-
CHECK(constantArgs.size() > 0);
947+
CHECK(!constantArgs.empty());
948948
Expr<T> result{std::move(*constantArgs[0])};
949949
for (std::size_t i{1}; i < constantArgs.size(); ++i) {
950950
Extremum<T> extremum{order, result, Expr<T>{std::move(*constantArgs[i])}};
@@ -1075,7 +1075,7 @@ template <typename T> class ArrayConstructorFolder {
10751075
Expr<T> folded{Fold(context_, common::Clone(expr.value()))};
10761076
if (const auto *c{UnwrapConstantValue<T>(folded)}) {
10771077
// Copy elements in Fortran array element order
1078-
if (c->size() > 0) {
1078+
if (!c->empty()) {
10791079
ConstantSubscripts index{c->lbounds()};
10801080
do {
10811081
elements_.emplace_back(c->At(index));
@@ -1156,7 +1156,7 @@ template <typename T>
11561156
std::optional<Expr<T>> AsFlatArrayConstructor(const Expr<T> &expr) {
11571157
if (const auto *c{UnwrapConstantValue<T>(expr)}) {
11581158
ArrayConstructor<T> result{expr};
1159-
if (c->size() > 0) {
1159+
if (!c->empty()) {
11601160
ConstantSubscripts at{c->lbounds()};
11611161
do {
11621162
result.Push(Expr<T>{Constant<T>{c->At(at)}});

flang/lib/Evaluate/fold-integer.cpp

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -174,21 +174,47 @@ Expr<Type<TypeCategory::Integer, KIND>> UBOUND(FoldingContext &context,
174174
return Expr<T>{std::move(funcRef)};
175175
}
176176

177+
// COUNT()
178+
template <typename T>
179+
static Expr<T> FoldCount(FoldingContext &context, FunctionRef<T> &&ref) {
180+
static_assert(T::category == TypeCategory::Integer);
181+
ActualArguments &arg{ref.arguments()};
182+
if (const Constant<LogicalResult> *mask{arg.empty()
183+
? nullptr
184+
: Folder<LogicalResult>{context}.Folding(arg[0])}) {
185+
std::optional<ConstantSubscript> dim;
186+
if (arg.size() > 1 && arg[1]) {
187+
dim = CheckDIM(context, arg[1], mask->Rank());
188+
if (!dim) {
189+
mask = nullptr;
190+
}
191+
}
192+
if (mask) {
193+
auto accumulator{[&](Scalar<T> &element, const ConstantSubscripts &at) {
194+
if (mask->At(at).IsTrue()) {
195+
element = element.AddSigned(Scalar<T>{1}).value;
196+
}
197+
}};
198+
return Expr<T>{DoReduction<T>(*mask, dim, Scalar<T>{}, accumulator)};
199+
}
200+
}
201+
return Expr<T>{std::move(ref)};
202+
}
203+
177204
// for IALL, IANY, & IPARITY
178205
template <typename T>
179206
static Expr<T> FoldBitReduction(FoldingContext &context, FunctionRef<T> &&ref,
180207
Scalar<T> (Scalar<T>::*operation)(const Scalar<T> &) const,
181208
Scalar<T> identity) {
182209
static_assert(T::category == TypeCategory::Integer);
183-
using Element = Scalar<T>;
184210
std::optional<ConstantSubscript> dim;
185211
if (std::optional<Constant<T>> array{
186212
ProcessReductionArgs<T>(context, ref.arguments(), dim, identity,
187213
/*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
188-
auto accumulator{[&](Element &element, const ConstantSubscripts &at) {
214+
auto accumulator{[&](Scalar<T> &element, const ConstantSubscripts &at) {
189215
element = (element.*operation)(array->At(at));
190216
}};
191-
return Expr<T>{DoReduction(*array, dim, identity, accumulator)};
217+
return Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)};
192218
}
193219
return Expr<T>{std::move(ref)};
194220
}
@@ -237,17 +263,7 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
237263
cx->u);
238264
}
239265
} else if (name == "count") {
240-
if (!args[1]) { // TODO: COUNT(x,DIM=d)
241-
if (const auto *constant{UnwrapConstantValue<LogicalResult>(args[0])}) {
242-
std::int64_t result{0};
243-
for (const auto &element : constant->values()) {
244-
if (element.IsTrue()) {
245-
++result;
246-
}
247-
}
248-
return Expr<T>{result};
249-
}
250-
}
266+
return FoldCount<T>(context, std::move(funcRef));
251267
} else if (name == "digits") {
252268
if (const auto *cx{UnwrapExpr<Expr<SomeInteger>>(args[0])}) {
253269
return Expr<T>{std::visit(

flang/lib/Evaluate/fold-logical.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ static Expr<T> FoldAllAny(FoldingContext &context, FunctionRef<T> &&ref,
2626
auto accumulator{[&](Element &element, const ConstantSubscripts &at) {
2727
element = (element.*operation)(array->At(at));
2828
}};
29-
return Expr<T>{DoReduction(*array, dim, identity, accumulator)};
29+
return Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)};
3030
}
3131
return Expr<T>{std::move(ref)};
3232
}

flang/lib/Evaluate/fold-reduction.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
//===-- lib/Evaluate/fold-reduction.cpp -----------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "fold-reduction.h"
10+
11+
namespace Fortran::evaluate {
12+
13+
std::optional<ConstantSubscript> CheckDIM(
14+
FoldingContext &context, std::optional<ActualArgument> &arg, int rank) {
15+
if (arg) {
16+
if (auto *dimConst{Folder<SubscriptInteger>{context}.Folding(arg)}) {
17+
if (auto dimScalar{dimConst->GetScalarValue()}) {
18+
auto dim{dimScalar->ToInt64()};
19+
if (dim >= 1 && dim <= rank) {
20+
return {dim};
21+
} else {
22+
context.messages().Say(
23+
"DIM=%jd is not valid for an array of rank %d"_err_en_US,
24+
static_cast<std::intmax_t>(dim), rank);
25+
}
26+
}
27+
}
28+
}
29+
return std::nullopt;
30+
}
31+
32+
} // namespace Fortran::evaluate

flang/lib/Evaluate/fold-reduction.h

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9-
// TODO: ALL, ANY, COUNT, DOT_PRODUCT, FINDLOC, IALL, IANY, IPARITY,
10-
// NORM2, MAXLOC, MINLOC, PARITY, PRODUCT, SUM
9+
// TODO: DOT_PRODUCT, FINDLOC, NORM2, MAXLOC, MINLOC, PARITY
1110

1211
#ifndef FORTRAN_EVALUATE_FOLD_REDUCTION_H_
1312
#define FORTRAN_EVALUATE_FOLD_REDUCTION_H_
@@ -16,6 +15,10 @@
1615

1716
namespace Fortran::evaluate {
1817

18+
// Folds & validates a DIM= actual argument.
19+
std::optional<ConstantSubscript> CheckDIM(
20+
FoldingContext &, std::optional<ActualArgument> &, int rank);
21+
1922
// Common preprocessing for reduction transformational intrinsic function
2023
// folding. If the intrinsic can have DIM= &/or MASK= arguments, extract
2124
// and check them. If a MASK= is present, apply it to the array data and
@@ -35,18 +38,7 @@ static std::optional<Constant<T>> ProcessReductionArgs(FoldingContext &context,
3538
return std::nullopt;
3639
}
3740
if (dimIndex && arg.size() >= *dimIndex + 1 && arg[*dimIndex]) {
38-
if (auto *dimConst{
39-
Folder<SubscriptInteger>{context}.Folding(arg[*dimIndex])}) {
40-
if (auto dimScalar{dimConst->GetScalarValue()}) {
41-
dim.emplace(dimScalar->ToInt64());
42-
if (*dim < 1 || *dim > folded->Rank()) {
43-
context.messages().Say(
44-
"DIM=%jd is not valid for an array of rank %d"_err_en_US,
45-
static_cast<std::intmax_t>(*dim), folded->Rank());
46-
dim.reset();
47-
}
48-
}
49-
}
41+
dim = CheckDIM(context, arg[*dimIndex], folded->Rank());
5042
if (!dim) {
5143
return std::nullopt;
5244
}
@@ -96,8 +88,8 @@ static std::optional<Constant<T>> ProcessReductionArgs(FoldingContext &context,
9688

9789
// Generalized reduction to an array of one dimension fewer (w/ DIM=)
9890
// or to a scalar (w/o DIM=).
99-
template <typename T, typename ACCUMULATOR>
100-
static Constant<T> DoReduction(const Constant<T> &array,
91+
template <typename T, typename ACCUMULATOR, typename ARRAY>
92+
static Constant<T> DoReduction(const Constant<ARRAY> &array,
10193
std::optional<ConstantSubscript> &dim, const Scalar<T> &identity,
10294
ACCUMULATOR &accumulator) {
10395
ConstantSubscripts at{array.lbounds()};
@@ -154,7 +146,7 @@ static Expr<T> FoldMaxvalMinval(FoldingContext &context, FunctionRef<T> &&ref,
154146
element = array->At(at);
155147
}
156148
}};
157-
return Expr<T>{DoReduction(*array, dim, identity, accumulator)};
149+
return Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)};
158150
}
159151
return Expr<T>{std::move(ref)};
160152
}
@@ -187,7 +179,7 @@ static Expr<T> FoldProduct(
187179
context.messages().Say(
188180
"PRODUCT() of %s data overflowed"_en_US, T::AsFortran());
189181
} else {
190-
return Expr<T>{DoReduction(*array, dim, identity, accumulator)};
182+
return Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)};
191183
}
192184
}
193185
return Expr<T>{std::move(ref)};
@@ -226,7 +218,7 @@ static Expr<T> FoldSum(FoldingContext &context, FunctionRef<T> &&ref) {
226218
context.messages().Say(
227219
"SUM() of %s data overflowed"_en_US, T::AsFortran());
228220
} else {
229-
return Expr<T>{DoReduction(*array, dim, identity, accumulator)};
221+
return Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)};
230222
}
231223
}
232224
return Expr<T>{std::move(ref)};

flang/test/Evaluate/folding29.f90

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
! RUN: %python %S/test_folding.py %s %flang_fc1
2+
! Tests folding of COUNT()
3+
module m
4+
logical, parameter :: arr(3,4) = reshape([(modulo(j, 2) == 1, j = 1, size(arr))], shape(arr))
5+
logical, parameter :: test_1 = count([1, 2, 3, 2, 1] < [(j, j=1, 5)]) == 2
6+
logical, parameter :: test_2 = count(arr) == 6
7+
logical, parameter :: test_3 = all(count(arr, dim=1) == [2, 1, 2, 1])
8+
logical, parameter :: test_4 = all(count(arr, dim=2) == [2, 2, 2])
9+
logical, parameter :: test_5 = count(logical(arr, kind=1)) == 6
10+
logical, parameter :: test_6 = count(logical(arr, kind=2)) == 6
11+
end module

0 commit comments

Comments
 (0)