@@ -702,8 +702,53 @@ static mlir::Value generateReductionLoop(fir::FirOpBuilder &builder,
702702 return reduction;
703703}
704704
705- // / Given a reduction operation with an elemental mask, attempt to generate a
706- // / do-loop to perform the operation inline.
705+ auto makeMinMaxInitValGenerator (bool isMax) {
706+ return [isMax](fir::FirOpBuilder builder, mlir::Location loc,
707+ mlir::Type elementType) -> mlir::Value {
708+ if (auto ty = mlir::dyn_cast<mlir::FloatType>(elementType)) {
709+ const llvm::fltSemantics &sem = ty.getFloatSemantics ();
710+ llvm::APFloat limit = llvm::APFloat::getInf (sem, /* Negative=*/ isMax);
711+ return builder.createRealConstant (loc, elementType, limit);
712+ }
713+ unsigned bits = elementType.getIntOrFloatBitWidth ();
714+ int64_t limitInt =
715+ isMax ? llvm::APInt::getSignedMinValue (bits).getSExtValue ()
716+ : llvm::APInt::getSignedMaxValue (bits).getSExtValue ();
717+ return builder.createIntegerConstant (loc, elementType, limitInt);
718+ };
719+ }
720+
721+ mlir::Value generateMinMaxComparison (fir::FirOpBuilder builder,
722+ mlir::Location loc, mlir::Value elem,
723+ mlir::Value reduction, bool isMax) {
724+ if (mlir::isa<mlir::FloatType>(reduction.getType ())) {
725+ // For FP reductions we want the first smallest value to be used, that
726+ // is not NaN. A OGL/OLT condition will usually work for this unless all
727+ // the values are Nan or Inf. This follows the same logic as
728+ // NumericCompare for Minloc/Maxlox in extrema.cpp.
729+ mlir::Value cmp = builder.create <mlir::arith::CmpFOp>(
730+ loc,
731+ isMax ? mlir::arith::CmpFPredicate::OGT
732+ : mlir::arith::CmpFPredicate::OLT,
733+ elem, reduction);
734+ mlir::Value cmpNan = builder.create <mlir::arith::CmpFOp>(
735+ loc, mlir::arith::CmpFPredicate::UNE, reduction, reduction);
736+ mlir::Value cmpNan2 = builder.create <mlir::arith::CmpFOp>(
737+ loc, mlir::arith::CmpFPredicate::OEQ, elem, elem);
738+ cmpNan = builder.create <mlir::arith::AndIOp>(loc, cmpNan, cmpNan2);
739+ return builder.create <mlir::arith::OrIOp>(loc, cmp, cmpNan);
740+ } else if (mlir::isa<mlir::IntegerType>(reduction.getType ())) {
741+ return builder.create <mlir::arith::CmpIOp>(
742+ loc,
743+ isMax ? mlir::arith::CmpIPredicate::sgt
744+ : mlir::arith::CmpIPredicate::slt,
745+ elem, reduction);
746+ }
747+ llvm_unreachable (" unsupported type" );
748+ }
749+
750+ // / Given a reduction operation with an elemental/designate source, attempt to
751+ // / generate a do-loop to perform the operation inline.
707752// / %e = hlfir.elemental %shape unordered
708753// / %r = hlfir.count %e
709754// / =>
@@ -712,17 +757,66 @@ static mlir::Value generateReductionLoop(fir::FirOpBuilder &builder,
712757// / %c = <reduce count> %i
713758// / fir.result %c
714759template <typename Op>
715- class ReductionElementalConversion : public mlir ::OpRewritePattern<Op> {
760+ class ReductionConversion : public mlir ::OpRewritePattern<Op> {
716761public:
717762 using mlir::OpRewritePattern<Op>::OpRewritePattern;
718763
719764 llvm::LogicalResult
720765 matchAndRewrite (Op op, mlir::PatternRewriter &rewriter) const override {
721766 mlir::Location loc = op.getLoc ();
722- hlfir::ElementalOp elemental =
723- op.getMask ().template getDefiningOp <hlfir::ElementalOp>();
724- if (!elemental || op.getDim ())
725- return rewriter.notifyMatchFailure (op, " Did not find valid elemental" );
767+ // Select source and validate its arguments.
768+ mlir::Value source;
769+ bool valid = false ;
770+ if constexpr (std::is_same_v<Op, hlfir::AnyOp> ||
771+ std::is_same_v<Op, hlfir::AllOp> ||
772+ std::is_same_v<Op, hlfir::CountOp>) {
773+ source = op.getMask ();
774+ valid = !op.getDim ();
775+ } else if constexpr (std::is_same_v<Op, hlfir::MaxvalOp> ||
776+ std::is_same_v<Op, hlfir::MinvalOp>) {
777+ source = op.getArray ();
778+ valid = !op.getDim () && !op.getMask ();
779+ } else if constexpr (std::is_same_v<Op, hlfir::MaxlocOp> ||
780+ std::is_same_v<Op, hlfir::MinlocOp>) {
781+ source = op.getArray ();
782+ valid = !op.getDim () && !op.getMask () && !op.getBack ();
783+ }
784+ if (!valid)
785+ return rewriter.notifyMatchFailure (
786+ op, " Currently does not accept optional arguments" );
787+
788+ hlfir::ElementalOp elemental;
789+ hlfir::DesignateOp designate;
790+ mlir::Value shape;
791+ if ((elemental = source.template getDefiningOp <hlfir::ElementalOp>())) {
792+ shape = elemental.getOperand (0 );
793+ } else if ((designate =
794+ source.template getDefiningOp <hlfir::DesignateOp>())) {
795+ shape = designate.getShape ();
796+ } else {
797+ return rewriter.notifyMatchFailure (op, " Did not find valid argument" );
798+ }
799+
800+ auto inlineSource =
801+ [elemental, &designate](
802+ fir::FirOpBuilder builder, mlir::Location loc,
803+ const llvm::SmallVectorImpl<mlir::Value> &indices) -> mlir::Value {
804+ if (elemental) {
805+ // Inline the elemental and get the value from it.
806+ auto yield = inlineElementalOp (loc, builder, elemental, indices);
807+ auto tmp = yield.getElementValue ();
808+ yield->erase ();
809+ return tmp;
810+ }
811+ if (designate) {
812+ // Create a designator over designator, then load the reference.
813+ auto resEntity = hlfir::Entity{designate.getResult ()};
814+ auto tmp = builder.create <hlfir::DesignateOp>(
815+ loc, getVariableElementType (resEntity), designate, indices);
816+ return builder.create <fir::LoadOp>(loc, tmp);
817+ }
818+ llvm_unreachable (" unsupported type" );
819+ };
726820
727821 fir::KindMapping kindMap =
728822 fir::getKindMapping (op->template getParentOfType <mlir::ModuleOp>());
@@ -732,77 +826,88 @@ class ReductionElementalConversion : public mlir::OpRewritePattern<Op> {
732826 GenBodyFn genBodyFn;
733827 if constexpr (std::is_same_v<Op, hlfir::AnyOp>) {
734828 init = builder.createIntegerConstant (loc, builder.getI1Type (), 0 );
735- genBodyFn = [elemental](fir::FirOpBuilder builder, mlir::Location loc,
736- mlir::Value reduction,
737- const llvm::SmallVectorImpl<mlir::Value> &indices)
829+ genBodyFn =
830+ [inlineSource](fir::FirOpBuilder builder, mlir::Location loc,
831+ mlir::Value reduction,
832+ const llvm::SmallVectorImpl<mlir::Value> &indices)
738833 -> mlir::Value {
739- // Inline the elemental and get the condition from it.
740- auto yield = inlineElementalOp (loc, builder, elemental, indices);
741- mlir::Value cond = builder.create <fir::ConvertOp>(
742- loc, builder.getI1Type (), yield.getElementValue ());
743- yield->erase ();
744-
745834 // Conditionally set the reduction variable.
835+ mlir::Value cond = builder.create <fir::ConvertOp>(
836+ loc, builder.getI1Type (), inlineSource (builder, loc, indices));
746837 return builder.create <mlir::arith::OrIOp>(loc, reduction, cond);
747838 };
748839 } else if constexpr (std::is_same_v<Op, hlfir::AllOp>) {
749840 init = builder.createIntegerConstant (loc, builder.getI1Type (), 1 );
750- genBodyFn = [elemental](fir::FirOpBuilder builder, mlir::Location loc,
751- mlir::Value reduction,
752- const llvm::SmallVectorImpl<mlir::Value> &indices)
841+ genBodyFn =
842+ [inlineSource](fir::FirOpBuilder builder, mlir::Location loc,
843+ mlir::Value reduction,
844+ const llvm::SmallVectorImpl<mlir::Value> &indices)
753845 -> mlir::Value {
754- // Inline the elemental and get the condition from it.
755- auto yield = inlineElementalOp (loc, builder, elemental, indices);
756- mlir::Value cond = builder.create <fir::ConvertOp>(
757- loc, builder.getI1Type (), yield.getElementValue ());
758- yield->erase ();
759-
760846 // Conditionally set the reduction variable.
847+ mlir::Value cond = builder.create <fir::ConvertOp>(
848+ loc, builder.getI1Type (), inlineSource (builder, loc, indices));
761849 return builder.create <mlir::arith::AndIOp>(loc, reduction, cond);
762850 };
763851 } else if constexpr (std::is_same_v<Op, hlfir::CountOp>) {
764852 init = builder.createIntegerConstant (loc, op.getType (), 0 );
765- genBodyFn = [elemental](fir::FirOpBuilder builder, mlir::Location loc,
766- mlir::Value reduction,
767- const llvm::SmallVectorImpl<mlir::Value> &indices)
853+ genBodyFn =
854+ [inlineSource](fir::FirOpBuilder builder, mlir::Location loc,
855+ mlir::Value reduction,
856+ const llvm::SmallVectorImpl<mlir::Value> &indices)
768857 -> mlir::Value {
769- // Inline the elemental and get the condition from it.
770- auto yield = inlineElementalOp (loc, builder, elemental, indices);
771- mlir::Value cond = builder.create <fir::ConvertOp>(
772- loc, builder.getI1Type (), yield.getElementValue ());
773- yield->erase ();
774-
775858 // Conditionally add one to the current value
859+ mlir::Value cond = builder.create <fir::ConvertOp>(
860+ loc, builder.getI1Type (), inlineSource (builder, loc, indices));
776861 mlir::Value one =
777862 builder.createIntegerConstant (loc, reduction.getType (), 1 );
778863 mlir::Value add1 =
779864 builder.create <mlir::arith::AddIOp>(loc, reduction, one);
780865 return builder.create <mlir::arith::SelectOp>(loc, cond, add1,
781866 reduction);
782867 };
868+ } else if constexpr (std::is_same_v<Op, hlfir::MaxlocOp> ||
869+ std::is_same_v<Op, hlfir::MinlocOp>) {
870+ // TODO: implement minloc/maxloc conversion.
871+ return rewriter.notifyMatchFailure (
872+ op, " Currently minloc/maxloc is not handled" );
873+ } else if constexpr (std::is_same_v<Op, hlfir::MaxvalOp> ||
874+ std::is_same_v<Op, hlfir::MinvalOp>) {
875+ bool isMax = std::is_same_v<Op, hlfir::MaxvalOp>;
876+ init = makeMinMaxInitValGenerator (isMax)(builder, loc, op.getType ());
877+ genBodyFn = [inlineSource,
878+ isMax](fir::FirOpBuilder builder, mlir::Location loc,
879+ mlir::Value reduction,
880+ const llvm::SmallVectorImpl<mlir::Value> &indices)
881+ -> mlir::Value {
882+ mlir::Value val = inlineSource (builder, loc, indices);
883+ mlir::Value cmp =
884+ generateMinMaxComparison (builder, loc, val, reduction, isMax);
885+ return builder.create <mlir::arith::SelectOp>(loc, cmp, val, reduction);
886+ };
783887 } else {
784- return mlir::failure ( );
888+ llvm_unreachable ( " unsupported type " );
785889 }
786890
787- mlir::Value res = generateReductionLoop (builder, loc, init,
788- elemental. getOperand ( 0 ) , genBodyFn);
891+ mlir::Value res =
892+ generateReductionLoop (builder, loc, init, shape , genBodyFn);
789893 if (res.getType () != op.getType ())
790894 res = builder.create <fir::ConvertOp>(loc, op.getType (), res);
791895
792- // Check if the op was the only user of the elemental (apart from a
793- // destroy), and remove it if so.
794- mlir::Operation::user_range elemUsers = elemental->getUsers ();
795- hlfir::DestroyOp elemDestroy;
796- if (std::distance (elemUsers.begin (), elemUsers.end ()) == 2 ) {
797- elemDestroy = mlir::dyn_cast<hlfir::DestroyOp>(*elemUsers.begin ());
798- if (!elemDestroy)
799- elemDestroy = mlir::dyn_cast<hlfir::DestroyOp>(*++elemUsers.begin ());
896+ // Check if the op was the only user of the source (apart from a destroy),
897+ // and remove it if so.
898+ mlir::Operation *sourceOp = source.getDefiningOp ();
899+ mlir::Operation::user_range srcUsers = sourceOp->getUsers ();
900+ hlfir::DestroyOp srcDestroy;
901+ if (std::distance (srcUsers.begin (), srcUsers.end ()) == 2 ) {
902+ srcDestroy = mlir::dyn_cast<hlfir::DestroyOp>(*srcUsers.begin ());
903+ if (!srcDestroy)
904+ srcDestroy = mlir::dyn_cast<hlfir::DestroyOp>(*++srcUsers.begin ());
800905 }
801906
802907 rewriter.replaceOp (op, res);
803- if (elemDestroy ) {
804- rewriter.eraseOp (elemDestroy );
805- rewriter.eraseOp (elemental );
908+ if (srcDestroy ) {
909+ rewriter.eraseOp (srcDestroy );
910+ rewriter.eraseOp (sourceOp );
806911 }
807912 return mlir::success ();
808913 }
@@ -813,7 +918,7 @@ class ReductionElementalConversion : public mlir::OpRewritePattern<Op> {
813918// %e = hlfir.elemental %shape ({ ... })
814919// %m = hlfir.minloc %array mask %e
815920template <typename Op>
816- class MinMaxlocElementalConversion : public mlir ::OpRewritePattern<Op> {
921+ class ReductionMaskConversion : public mlir ::OpRewritePattern<Op> {
817922public:
818923 using mlir::OpRewritePattern<Op>::OpRewritePattern;
819924
@@ -848,19 +953,7 @@ class MinMaxlocElementalConversion : public mlir::OpRewritePattern<Op> {
848953 loc, fir::SequenceType::get (
849954 rank, hlfir::getFortranElementType (mloc.getType ())));
850955
851- auto init = [isMax](fir::FirOpBuilder builder, mlir::Location loc,
852- mlir::Type elementType) {
853- if (auto ty = mlir::dyn_cast<mlir::FloatType>(elementType)) {
854- const llvm::fltSemantics &sem = ty.getFloatSemantics ();
855- llvm::APFloat limit = llvm::APFloat::getInf (sem, /* Negative=*/ isMax);
856- return builder.createRealConstant (loc, elementType, limit);
857- }
858- unsigned bits = elementType.getIntOrFloatBitWidth ();
859- int64_t limitInt =
860- isMax ? llvm::APInt::getSignedMinValue (bits).getSExtValue ()
861- : llvm::APInt::getSignedMaxValue (bits).getSExtValue ();
862- return builder.createIntegerConstant (loc, elementType, limitInt);
863- };
956+ auto init = makeMinMaxInitValGenerator (isMax);
864957
865958 auto genBodyOp =
866959 [&rank, &resultArr, &elemental, isMax](
@@ -900,33 +993,8 @@ class MinMaxlocElementalConversion : public mlir::OpRewritePattern<Op> {
900993 mlir::Value elem = builder.create <fir::LoadOp>(loc, addr);
901994
902995 // Compare with the max reduction value
903- mlir::Value cmp;
904- if (mlir::isa<mlir::FloatType>(elementType)) {
905- // For FP reductions we want the first smallest value to be used, that
906- // is not NaN. A OGL/OLT condition will usually work for this unless all
907- // the values are Nan or Inf. This follows the same logic as
908- // NumericCompare for Minloc/Maxlox in extrema.cpp.
909- cmp = builder.create <mlir::arith::CmpFOp>(
910- loc,
911- isMax ? mlir::arith::CmpFPredicate::OGT
912- : mlir::arith::CmpFPredicate::OLT,
913- elem, reduction);
914-
915- mlir::Value cmpNan = builder.create <mlir::arith::CmpFOp>(
916- loc, mlir::arith::CmpFPredicate::UNE, reduction, reduction);
917- mlir::Value cmpNan2 = builder.create <mlir::arith::CmpFOp>(
918- loc, mlir::arith::CmpFPredicate::OEQ, elem, elem);
919- cmpNan = builder.create <mlir::arith::AndIOp>(loc, cmpNan, cmpNan2);
920- cmp = builder.create <mlir::arith::OrIOp>(loc, cmp, cmpNan);
921- } else if (mlir::isa<mlir::IntegerType>(elementType)) {
922- cmp = builder.create <mlir::arith::CmpIOp>(
923- loc,
924- isMax ? mlir::arith::CmpIPredicate::sgt
925- : mlir::arith::CmpIPredicate::slt,
926- elem, reduction);
927- } else {
928- llvm_unreachable (" unsupported type" );
929- }
996+ mlir::Value cmp =
997+ generateMinMaxComparison (builder, loc, elem, reduction, isMax);
930998
931999 // The condition used for the loop is isFirst || <the condition above>.
9321000 isFirst = builder.create <fir::ConvertOp>(loc, cmp.getType (), isFirst);
@@ -1055,11 +1123,19 @@ class OptimizedBufferizationPass
10551123 patterns.insert <ElementalAssignBufferization>(context);
10561124 patterns.insert <BroadcastAssignBufferization>(context);
10571125 patterns.insert <VariableAssignBufferization>(context);
1058- patterns.insert <ReductionElementalConversion<hlfir::CountOp>>(context);
1059- patterns.insert <ReductionElementalConversion<hlfir::AnyOp>>(context);
1060- patterns.insert <ReductionElementalConversion<hlfir::AllOp>>(context);
1061- patterns.insert <MinMaxlocElementalConversion<hlfir::MinlocOp>>(context);
1062- patterns.insert <MinMaxlocElementalConversion<hlfir::MaxlocOp>>(context);
1126+ patterns.insert <ReductionConversion<hlfir::CountOp>>(context);
1127+ patterns.insert <ReductionConversion<hlfir::AnyOp>>(context);
1128+ patterns.insert <ReductionConversion<hlfir::AllOp>>(context);
1129+ // TODO: implement basic minloc/maxloc conversion.
1130+ // patterns.insert<ReductionConversion<hlfir::MaxlocOp>>(context);
1131+ // patterns.insert<ReductionConversion<hlfir::MinlocOp>>(context);
1132+ patterns.insert <ReductionConversion<hlfir::MaxvalOp>>(context);
1133+ patterns.insert <ReductionConversion<hlfir::MinvalOp>>(context);
1134+ patterns.insert <ReductionMaskConversion<hlfir::MinlocOp>>(context);
1135+ patterns.insert <ReductionMaskConversion<hlfir::MaxlocOp>>(context);
1136+ // TODO: implement masked minval/maxval conversion.
1137+ // patterns.insert<ReductionMaskConversion<hlfir::MaxvalOp>>(context);
1138+ // patterns.insert<ReductionMaskConversion<hlfir::MinvalOp>>(context);
10631139
10641140 if (mlir::failed (mlir::applyPatternsAndFoldGreedily (
10651141 getOperation (), std::move (patterns), config))) {
0 commit comments