@@ -270,7 +270,8 @@ template <typename T, int MASK_KIND> class CountAccumulator {
270270
271271public:
272272 CountAccumulator (const Constant<MaskT> &mask) : mask_{mask} {}
273- void operator ()(Scalar<T> &element, const ConstantSubscripts &at) {
273+ void operator ()(
274+ Scalar<T> &element, const ConstantSubscripts &at, bool /* first*/ ) {
274275 if (mask_.At (at).IsTrue ()) {
275276 auto incremented{element.AddSigned (Scalar<T>{1 })};
276277 overflow_ |= incremented.overflow ;
@@ -287,22 +288,20 @@ template <typename T, int MASK_KIND> class CountAccumulator {
287288
288289template <typename T, int maskKind>
289290static Expr<T> FoldCount (FoldingContext &context, FunctionRef<T> &&ref) {
290- using LogicalResult = Type<TypeCategory::Logical, maskKind>;
291+ using KindLogical = Type<TypeCategory::Logical, maskKind>;
291292 static_assert (T::category == TypeCategory::Integer);
292- ActualArguments &arg{ref.arguments ()};
293- if (const Constant<LogicalResult> *mask{arg.empty ()
294- ? nullptr
295- : Folder<LogicalResult>{context}.Folding (arg[0 ])}) {
296- std::optional<int > dim;
297- if (CheckReductionDIM (dim, context, arg, 1 , mask->Rank ())) {
298- CountAccumulator<T, maskKind> accumulator{*mask};
299- Constant<T> result{DoReduction<T>(*mask, dim, Scalar<T>{}, accumulator)};
300- if (accumulator.overflow ()) {
301- context.messages ().Say (
302- " Result of intrinsic function COUNT overflows its result type" _warn_en_US);
303- }
304- return Expr<T>{std::move (result)};
293+ std::optional<int > dim;
294+ if (std::optional<ArrayAndMask<KindLogical>> arrayAndMask{
295+ ProcessReductionArgs<KindLogical>(
296+ context, ref.arguments (), dim, /* ARRAY=*/ 0 , /* DIM=*/ 1 )}) {
297+ CountAccumulator<T, maskKind> accumulator{arrayAndMask->array };
298+ Constant<T> result{DoReduction<T>(arrayAndMask->array , arrayAndMask->mask ,
299+ dim, Scalar<T>{}, accumulator)};
300+ if (accumulator.overflow ()) {
301+ context.messages ().Say (
302+ " Result of intrinsic function COUNT overflows its result type" _warn_en_US);
305303 }
304+ return Expr<T>{std::move (result)};
306305 }
307306 return Expr<T>{std::move (ref)};
308307}
@@ -395,7 +394,7 @@ template <WhichLocation WHICH> class LocationHelper {
395394 for (ConstantSubscript k{0 }; k < dimLength;
396395 ++k, ++at[zbDim], mask && ++maskAt[zbDim]) {
397396 if ((!mask || mask->At (maskAt).IsTrue ()) &&
398- IsHit (array->At (at), value, relation)) {
397+ IsHit (array->At (at), value, relation, back )) {
399398 hit = at[zbDim];
400399 if constexpr (WHICH == WhichLocation::Findloc) {
401400 if (!back) {
@@ -422,7 +421,7 @@ template <WhichLocation WHICH> class LocationHelper {
422421 for (ConstantSubscript j{0 }; j < n; ++j, array->IncrementSubscripts (at),
423422 mask && mask->IncrementSubscripts (maskAt)) {
424423 if ((!mask || mask->At (maskAt).IsTrue ()) &&
425- IsHit (array->At (at), value, relation)) {
424+ IsHit (array->At (at), value, relation, back )) {
426425 resultIndices = at;
427426 if constexpr (WHICH == WhichLocation::Findloc) {
428427 if (!back) {
@@ -444,7 +443,8 @@ template <WhichLocation WHICH> class LocationHelper {
444443 template <typename T>
445444 bool IsHit (typename Constant<T>::Element element,
446445 std::optional<Constant<T>> &value,
447- [[maybe_unused]] RelationalOperator relation) const {
446+ [[maybe_unused]] RelationalOperator relation,
447+ [[maybe_unused]] bool back) const {
448448 std::optional<Expr<LogicalResult>> cmp;
449449 bool result{true };
450450 if (value) {
@@ -455,8 +455,19 @@ template <WhichLocation WHICH> class LocationHelper {
455455 Expr<T>{LogicalOperation<T::kind>{LogicalOperator::Eqv,
456456 Expr<T>{Constant<T>{element}}, Expr<T>{Constant<T>{*value}}}}));
457457 } else { // compare array(at) to value
458- cmp.emplace (PackageRelation (relation, Expr<T>{Constant<T>{element}},
459- Expr<T>{Constant<T>{*value}}));
458+ if constexpr (T::category == TypeCategory::Real &&
459+ (WHICH == WhichLocation::Maxloc ||
460+ WHICH == WhichLocation::Minloc)) {
461+ if (value && value->GetScalarValue ().value ().IsNotANumber () &&
462+ (back || !element.IsNotANumber ())) {
463+ // Replace NaN
464+ cmp.emplace (Constant<LogicalResult>{Scalar<LogicalResult>{true }});
465+ }
466+ }
467+ if (!cmp) {
468+ cmp.emplace (PackageRelation (relation, Expr<T>{Constant<T>{element}},
469+ Expr<T>{Constant<T>{*value}}));
470+ }
460471 }
461472 Expr<LogicalResult> folded{Fold (context_, std::move (*cmp))};
462473 result = GetScalarConstantValue<LogicalResult>(folded).value ().IsTrue ();
@@ -523,11 +534,12 @@ static Expr<T> FoldBitReduction(FoldingContext &context, FunctionRef<T> &&ref,
523534 Scalar<T> identity) {
524535 static_assert (T::category == TypeCategory::Integer);
525536 std::optional<int > dim;
526- if (std::optional<Constant <T>> array {
527- ProcessReductionArgs<T>(context, ref.arguments (), dim, identity,
537+ if (std::optional<ArrayAndMask <T>> arrayAndMask {
538+ ProcessReductionArgs<T>(context, ref.arguments (), dim,
528539 /* ARRAY=*/ 0 , /* DIM=*/ 1 , /* MASK=*/ 2 )}) {
529- OperationAccumulator<T> accumulator{*array, operation};
530- return Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)};
540+ OperationAccumulator<T> accumulator{arrayAndMask->array , operation};
541+ return Expr<T>{DoReduction<T>(
542+ arrayAndMask->array , arrayAndMask->mask , dim, identity, accumulator)};
531543 }
532544 return Expr<T>{std::move (ref)};
533545}
0 commit comments