Skip to content

Commit 7380b75

Browse files
committed
Simplified generated code for fast-math 'nnan'.
1 parent 144fd50 commit 7380b75

File tree

5 files changed

+209
-27
lines changed

5 files changed

+209
-27
lines changed

flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp

Lines changed: 47 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -469,7 +469,7 @@ class MinMaxlocAsElementalConverter : public ReductionAsElementalConverter {
469469
// * 1 boolean indicating whether it is the first time
470470
// the mask is true.
471471
//
472-
// If precomputeFirst() returns true, then the boolean loop-carried
472+
// If useIsFirst() returns false, then the boolean loop-carried
473473
// value is not used.
474474
static constexpr unsigned maxNumReductions = Fortran::common::maxRank + 2;
475475
static constexpr bool isMax = std::is_same_v<T, hlfir::MaxlocOp>;
@@ -523,7 +523,7 @@ class MinMaxlocAsElementalConverter : public ReductionAsElementalConverter {
523523

524524
void
525525
checkReductions(const llvm::SmallVectorImpl<mlir::Value> &reductions) const {
526-
if (precomputeFirst())
526+
if (!useIsFirst())
527527
assert(reductions.size() == getNumCoors() + 1 &&
528528
"invalid number of reductions for MINLOC/MAXLOC");
529529
else
@@ -540,15 +540,24 @@ class MinMaxlocAsElementalConverter : public ReductionAsElementalConverter {
540540
mlir::Value
541541
getIsFirst(const llvm::SmallVectorImpl<mlir::Value> &reductions) const {
542542
checkReductions(reductions);
543-
assert(!precomputeFirst() && "IsFirst predicate must not be used");
543+
assert(useIsFirst() && "IsFirst predicate must not be used");
544544
return reductions[getNumCoors() + 1];
545545
}
546546

547-
// Return true iff the reductions can be initialized
548-
// by reading the first element of the array (or its section).
549-
// If it returns false, then we use an auxiliary boolean
550-
// to identify the very first reduction update.
551-
bool precomputeFirst() const { return !getMask(); }
547+
// Return true iff the input can contain NaNs, and they should be
548+
// honored, such that all-NaNs input must produce the location
549+
// of the first unmasked NaN.
550+
bool honorNans() const {
551+
return !static_cast<bool>(getFastMath() & mlir::arith::FastMathFlags::nnan);
552+
}
553+
554+
// Return true iff we have to use the loop-carried IsFirst predicate.
555+
// If there is no mask, we can initialize the reductions using
556+
// the first elements of the input.
557+
// If NaNs are not honored, we can initialize the starting MIN/MAX
558+
// value to +/-LARGEST; the coordinates are guaranteed to be updated
559+
// properly for non-empty input without NaNs.
560+
bool useIsFirst() const { return getMask() && honorNans(); }
552561
};
553562

554563
template <typename T>
@@ -557,9 +566,10 @@ MinMaxlocAsElementalConverter<T>::genReductionInitValues(
557566
mlir::ValueRange oneBasedIndices,
558567
const llvm::SmallVectorImpl<mlir::Value> &extents) {
559568
fir::IfOp ifOp;
560-
if (precomputeFirst()) {
569+
if (!useIsFirst() && honorNans()) {
561570
// Check if we can load the value of the first element in the array
562571
// or its section (for partial reduction).
572+
assert(!getMask() && "cannot fetch first element when mask is present");
563573
assert(extents.size() == getNumCoors() &&
564574
"wrong number of extents for MINLOC/MAXLOC reduction");
565575
mlir::Value isNotEmpty = genIsNotEmptyArrayExtents(loc, builder, extents);
@@ -600,7 +610,7 @@ MinMaxlocAsElementalConverter<T>::genReductionInitValues(
600610
builder.create<fir::ResultOp>(loc, result);
601611
builder.setInsertionPointAfter(ifOp);
602612
result = ifOp.getResults();
603-
} else {
613+
} else if (useIsFirst()) {
604614
// Initial value for isFirst predicate. It is switched to false,
605615
// when the reduction update dynamically happens inside the reduction
606616
// loop.
@@ -621,7 +631,7 @@ MinMaxlocAsElementalConverter<T>::reduceOneElement(
621631
hlfir::loadElementAt(loc, builder, array, oneBasedIndices);
622632
mlir::Value cmp = genMinMaxComparison<isMax>(loc, builder, elementValue,
623633
getCurrentMinMax(currentValue));
624-
if (!precomputeFirst()) {
634+
if (useIsFirst()) {
625635
// If isFirst is true, then do the reduction update regardless
626636
// of the FP comparison.
627637
cmp =
@@ -652,7 +662,7 @@ MinMaxlocAsElementalConverter<T>::reduceOneElement(
652662
loc, cmp, elementValue, getCurrentMinMax(currentValue));
653663
newIndices.push_back(newMinMax);
654664

655-
if (!precomputeFirst()) {
665+
if (useIsFirst()) {
656666
mlir::Value newIsFirst = builder.createBool(loc, false);
657667
newIndices.push_back(newIsFirst);
658668
}
@@ -746,7 +756,7 @@ class MinMaxvalAsElementalConverter
746756
//
747757
// The boolean flag is used to replace the initial value
748758
// with the first input element even if it is NaN.
749-
// If precomputeFirst() returns true, then the boolean loop-carried
759+
// If useIsFirst() returns false, then the boolean loop-carried
750760
// value is not used.
751761
static constexpr bool isMax = std::is_same_v<T, hlfir::MaxvalOp>;
752762
using Base = NumericReductionAsElementalConverterBase<T>;
@@ -781,13 +791,13 @@ class MinMaxvalAsElementalConverter
781791
mlir::Value currentMinMax = getCurrentMinMax(currentValue);
782792
mlir::Value cmp =
783793
genMinMaxComparison<isMax>(loc, builder, elementValue, currentMinMax);
784-
if (!precomputeFirst())
794+
if (useIsFirst())
785795
cmp = builder.create<mlir::arith::OrIOp>(loc, cmp,
786796
getIsFirst(currentValue));
787797
mlir::Value newMinMax = builder.create<mlir::arith::SelectOp>(
788798
loc, cmp, elementValue, currentMinMax);
789799
result.push_back(newMinMax);
790-
if (!precomputeFirst())
800+
if (useIsFirst())
791801
result.push_back(builder.createBool(loc, false));
792802
return result;
793803
}
@@ -813,17 +823,25 @@ class MinMaxvalAsElementalConverter
813823
mlir::Value
814824
getIsFirst(const llvm::SmallVectorImpl<mlir::Value> &reductions) const {
815825
this->checkReductions(reductions);
816-
assert(!precomputeFirst() && "IsFirst predicate must not be used");
826+
assert(useIsFirst() && "IsFirst predicate must not be used");
817827
return reductions[1];
818828
}
819829

820-
// Return true iff the reductions can be initialized
821-
// by reading the first element of the array (or its section).
822-
// If it returns false, then we use an auxiliary boolean
823-
// to identify the very first reduction update.
824-
bool precomputeFirst() const { return !this->getMask(); }
830+
// Return true iff the input can contain NaNs, and they should be
831+
// honored, such that all-NaNs input must produce NaN result.
832+
bool honorNans() const {
833+
return !static_cast<bool>(this->getFastMath() &
834+
mlir::arith::FastMathFlags::nnan);
835+
}
836+
837+
// Return true iff we have to use the loop-carried IsFirst predicate.
838+
// If there is no mask, we can initialize the reductions using
839+
// the first elements of the input.
840+
// If NaNs are not honored, we can initialize the starting MIN/MAX
841+
// value to +/-LARGEST.
842+
bool useIsFirst() const { return this->getMask() && honorNans(); }
825843

826-
std::size_t getNumReductions() const { return precomputeFirst() ? 1 : 2; }
844+
std::size_t getNumReductions() const { return useIsFirst() ? 2 : 1; }
827845
};
828846

829847
template <typename T>
@@ -836,12 +854,14 @@ MinMaxvalAsElementalConverter<T>::genReductionInitValues(
836854
mlir::Location loc = this->loc;
837855

838856
fir::IfOp ifOp;
839-
if (precomputeFirst()) {
857+
if (!useIsFirst() && honorNans()) {
840858
// Check if we can load the value of the first element in the array
841859
// or its section (for partial reduction).
842-
assert(extents.size() == this->isTotalReduction()
843-
? this->getSourceRank()
844-
: 1u && "wrong number of extents for MINVAL/MAXVAL reduction");
860+
assert(!this->getMask() &&
861+
"cannot fetch first element when mask is present");
862+
assert(extents.size() ==
863+
(this->isTotalReduction() ? this->getSourceRank() : 1u) &&
864+
"wrong number of extents for MINVAL/MAXVAL reduction");
845865
mlir::Value isNotEmpty = genIsNotEmptyArrayExtents(loc, builder, extents);
846866
llvm::SmallVector<mlir::Value> indices = genFirstElementIndicesForReduction(
847867
loc, builder, this->isTotalReduction(), this->getConstDim(),
@@ -867,7 +887,7 @@ MinMaxvalAsElementalConverter<T>::genReductionInitValues(
867887
builder.create<fir::ResultOp>(loc, result);
868888
builder.setInsertionPointAfter(ifOp);
869889
result = ifOp.getResults();
870-
} else {
890+
} else if (useIsFirst()) {
871891
// Initial value for isFirst predicate. It is switched to false,
872892
// when the reduction update dynamically happens inside the reduction
873893
// loop.

flang/test/HLFIR/simplify-hlfir-intrinsics-maxloc.fir

Lines changed: 50 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -417,6 +417,56 @@ func.func @test_partial_var_nomask(%input: !fir.box<!fir.array<?x?x?xf32>>) -> !
417417
// CHECK: return %[[VAL_11]] : !hlfir.expr<?x?xi32>
418418
// CHECK: }
419419

420+
// Test that 'nnan' allows using -LARGEST value as the reduction init.
421+
func.func @test_total_expr_nnan(%input: !hlfir.expr<?x?x?xf32>) -> !hlfir.expr<3xi32> {
422+
%0 = hlfir.maxloc %input {fastmath = #arith.fastmath<nnan>} : (!hlfir.expr<?x?x?xf32>) -> !hlfir.expr<3xi32>
423+
return %0 : !hlfir.expr<3xi32>
424+
}
425+
// CHECK-LABEL: func.func @test_total_expr_nnan(
426+
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !hlfir.expr<?x?x?xf32>) -> !hlfir.expr<3xi32> {
427+
// CHECK: %[[VAL_1:.*]] = arith.constant false
428+
// CHECK: %[[VAL_2:.*]] = arith.constant 3 : index
429+
// CHECK: %[[VAL_3:.*]] = arith.constant 2 : index
430+
// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
431+
// CHECK: %[[VAL_5:.*]] = arith.constant -3.40282347E+38 : f32
432+
// CHECK: %[[VAL_6:.*]] = arith.constant 0 : i32
433+
// CHECK: %[[VAL_7:.*]] = fir.alloca !fir.array<3xi32>
434+
// CHECK: %[[VAL_8:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?x?x?xf32>) -> !fir.shape<3>
435+
// CHECK: %[[VAL_9:.*]] = hlfir.get_extent %[[VAL_8]] {dim = 0 : index} : (!fir.shape<3>) -> index
436+
// CHECK: %[[VAL_10:.*]] = hlfir.get_extent %[[VAL_8]] {dim = 1 : index} : (!fir.shape<3>) -> index
437+
// CHECK: %[[VAL_11:.*]] = hlfir.get_extent %[[VAL_8]] {dim = 2 : index} : (!fir.shape<3>) -> index
438+
// CHECK: %[[VAL_12:.*]]:4 = fir.do_loop %[[VAL_13:.*]] = %[[VAL_4]] to %[[VAL_11]] step %[[VAL_4]] iter_args(%[[VAL_14:.*]] = %[[VAL_6]], %[[VAL_15:.*]] = %[[VAL_6]], %[[VAL_16:.*]] = %[[VAL_6]], %[[VAL_17:.*]] = %[[VAL_5]]) -> (i32, i32, i32, f32) {
439+
// CHECK: %[[VAL_18:.*]]:4 = fir.do_loop %[[VAL_19:.*]] = %[[VAL_4]] to %[[VAL_10]] step %[[VAL_4]] iter_args(%[[VAL_20:.*]] = %[[VAL_14]], %[[VAL_21:.*]] = %[[VAL_15]], %[[VAL_22:.*]] = %[[VAL_16]], %[[VAL_23:.*]] = %[[VAL_17]]) -> (i32, i32, i32, f32) {
440+
// CHECK: %[[VAL_24:.*]]:4 = fir.do_loop %[[VAL_25:.*]] = %[[VAL_4]] to %[[VAL_9]] step %[[VAL_4]] iter_args(%[[VAL_26:.*]] = %[[VAL_20]], %[[VAL_27:.*]] = %[[VAL_21]], %[[VAL_28:.*]] = %[[VAL_22]], %[[VAL_29:.*]] = %[[VAL_23]]) -> (i32, i32, i32, f32) {
441+
// CHECK: %[[VAL_30:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_25]], %[[VAL_19]], %[[VAL_13]] : (!hlfir.expr<?x?x?xf32>, index, index, index) -> f32
442+
// CHECK: %[[VAL_31:.*]] = arith.cmpf ogt, %[[VAL_30]], %[[VAL_29]] fastmath<nnan> : f32
443+
// CHECK: %[[VAL_32:.*]] = arith.cmpf une, %[[VAL_29]], %[[VAL_29]] fastmath<nnan> : f32
444+
// CHECK: %[[VAL_33:.*]] = arith.cmpf oeq, %[[VAL_30]], %[[VAL_30]] fastmath<nnan> : f32
445+
// CHECK: %[[VAL_34:.*]] = arith.andi %[[VAL_32]], %[[VAL_33]] : i1
446+
// CHECK: %[[VAL_35:.*]] = arith.ori %[[VAL_31]], %[[VAL_34]] : i1
447+
// CHECK: %[[VAL_36:.*]] = fir.convert %[[VAL_25]] : (index) -> i32
448+
// CHECK: %[[VAL_37:.*]] = arith.select %[[VAL_35]], %[[VAL_36]], %[[VAL_26]] : i32
449+
// CHECK: %[[VAL_38:.*]] = fir.convert %[[VAL_19]] : (index) -> i32
450+
// CHECK: %[[VAL_39:.*]] = arith.select %[[VAL_35]], %[[VAL_38]], %[[VAL_27]] : i32
451+
// CHECK: %[[VAL_40:.*]] = fir.convert %[[VAL_13]] : (index) -> i32
452+
// CHECK: %[[VAL_41:.*]] = arith.select %[[VAL_35]], %[[VAL_40]], %[[VAL_28]] : i32
453+
// CHECK: %[[VAL_42:.*]] = arith.select %[[VAL_35]], %[[VAL_30]], %[[VAL_29]] : f32
454+
// CHECK: fir.result %[[VAL_37]], %[[VAL_39]], %[[VAL_41]], %[[VAL_42]] : i32, i32, i32, f32
455+
// CHECK: }
456+
// CHECK: fir.result %[[VAL_43:.*]]#0, %[[VAL_43]]#1, %[[VAL_43]]#2, %[[VAL_43]]#3 : i32, i32, i32, f32
457+
// CHECK: }
458+
// CHECK: fir.result %[[VAL_44:.*]]#0, %[[VAL_44]]#1, %[[VAL_44]]#2, %[[VAL_44]]#3 : i32, i32, i32, f32
459+
// CHECK: }
460+
// CHECK: %[[VAL_45:.*]] = hlfir.designate %[[VAL_7]] (%[[VAL_4]]) : (!fir.ref<!fir.array<3xi32>>, index) -> !fir.ref<i32>
461+
// CHECK: hlfir.assign %[[VAL_46:.*]]#0 to %[[VAL_45]] : i32, !fir.ref<i32>
462+
// CHECK: %[[VAL_47:.*]] = hlfir.designate %[[VAL_7]] (%[[VAL_3]]) : (!fir.ref<!fir.array<3xi32>>, index) -> !fir.ref<i32>
463+
// CHECK: hlfir.assign %[[VAL_46]]#1 to %[[VAL_47]] : i32, !fir.ref<i32>
464+
// CHECK: %[[VAL_48:.*]] = hlfir.designate %[[VAL_7]] (%[[VAL_2]]) : (!fir.ref<!fir.array<3xi32>>, index) -> !fir.ref<i32>
465+
// CHECK: hlfir.assign %[[VAL_46]]#2 to %[[VAL_48]] : i32, !fir.ref<i32>
466+
// CHECK: %[[VAL_49:.*]] = hlfir.as_expr %[[VAL_7]] move %[[VAL_1]] : (!fir.ref<!fir.array<3xi32>>, i1) -> !hlfir.expr<3xi32>
467+
// CHECK: return %[[VAL_49]] : !hlfir.expr<3xi32>
468+
// CHECK: }
469+
420470
// Character comparisons are not supported yet.
421471
func.func @test_character(%input: !fir.box<!fir.array<?x!fir.char<1>>>) -> !hlfir.expr<1xi32> {
422472
%0 = hlfir.maxloc %input : (!fir.box<!fir.array<?x!fir.char<1>>>) -> !hlfir.expr<1xi32>

flang/test/HLFIR/simplify-hlfir-intrinsics-maxval.fir

Lines changed: 31 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -268,3 +268,34 @@ func.func @test_total_var_nomask(%input: !fir.box<!fir.array<?x?xf16>>) -> f16 {
268268
// CHECK: }
269269
// CHECK: return %[[VAL_14]] : f16
270270
// CHECK: }
271+
272+
// Test that 'nnan' allows using -LARGEST value as the reduction init.
273+
func.func @test_partial_expr_nnan(%input: !hlfir.expr<?x?xf64>) -> !hlfir.expr<?xf64> {
274+
%dim = arith.constant 1 : i32
275+
%0 = hlfir.maxval %input dim %dim {fastmath = #arith.fastmath<nnan>} : (!hlfir.expr<?x?xf64>, i32) -> !hlfir.expr<?xf64>
276+
return %0 : !hlfir.expr<?xf64>
277+
}
278+
// CHECK-LABEL: func.func @test_partial_expr_nnan(
279+
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !hlfir.expr<?x?xf64>) -> !hlfir.expr<?xf64> {
280+
// CHECK: %[[VAL_1:.*]] = arith.constant 1 : index
281+
// CHECK: %[[VAL_2:.*]] = arith.constant -1.7976931348623157E+308 : f64
282+
// CHECK: %[[VAL_3:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?x?xf64>) -> !fir.shape<2>
283+
// CHECK: %[[VAL_4:.*]] = hlfir.get_extent %[[VAL_3]] {dim = 0 : index} : (!fir.shape<2>) -> index
284+
// CHECK: %[[VAL_5:.*]] = hlfir.get_extent %[[VAL_3]] {dim = 1 : index} : (!fir.shape<2>) -> index
285+
// CHECK: %[[VAL_6:.*]] = fir.shape %[[VAL_5]] : (index) -> !fir.shape<1>
286+
// CHECK: %[[VAL_7:.*]] = hlfir.elemental %[[VAL_6]] unordered : (!fir.shape<1>) -> !hlfir.expr<?xf64> {
287+
// CHECK: ^bb0(%[[VAL_8:.*]]: index):
288+
// CHECK: %[[VAL_9:.*]] = fir.do_loop %[[VAL_10:.*]] = %[[VAL_1]] to %[[VAL_4]] step %[[VAL_1]] iter_args(%[[VAL_11:.*]] = %[[VAL_2]]) -> (f64) {
289+
// CHECK: %[[VAL_12:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_10]], %[[VAL_8]] : (!hlfir.expr<?x?xf64>, index, index) -> f64
290+
// CHECK: %[[VAL_13:.*]] = arith.cmpf ogt, %[[VAL_12]], %[[VAL_11]] fastmath<nnan> : f64
291+
// CHECK: %[[VAL_14:.*]] = arith.cmpf une, %[[VAL_11]], %[[VAL_11]] fastmath<nnan> : f64
292+
// CHECK: %[[VAL_15:.*]] = arith.cmpf oeq, %[[VAL_12]], %[[VAL_12]] fastmath<nnan> : f64
293+
// CHECK: %[[VAL_16:.*]] = arith.andi %[[VAL_14]], %[[VAL_15]] : i1
294+
// CHECK: %[[VAL_17:.*]] = arith.ori %[[VAL_13]], %[[VAL_16]] : i1
295+
// CHECK: %[[VAL_18:.*]] = arith.select %[[VAL_17]], %[[VAL_12]], %[[VAL_11]] : f64
296+
// CHECK: fir.result %[[VAL_18]] : f64
297+
// CHECK: }
298+
// CHECK: hlfir.yield_element %[[VAL_9]] : f64
299+
// CHECK: }
300+
// CHECK: return %[[VAL_7]] : !hlfir.expr<?xf64>
301+
// CHECK: }

0 commit comments

Comments
 (0)