@@ -52,6 +52,8 @@ bool RecurrenceDescriptor::isIntegerRecurrenceKind(RecurKind Kind) {
5252 case RecurKind::UMin:
5353 case RecurKind::IAnyOf:
5454 case RecurKind::FAnyOf:
55+ case RecurKind::IFindLastIV:
56+ case RecurKind::FFindLastIV:
5557 return true ;
5658 }
5759 return false ;
@@ -373,7 +375,7 @@ bool RecurrenceDescriptor::AddReductionVar(
373375 // type-promoted).
374376 if (Cur != Start) {
375377 ReduxDesc =
376- isRecurrenceInstr (TheLoop, Phi, Cur, Kind, ReduxDesc, FuncFMF);
378+ isRecurrenceInstr (TheLoop, Phi, Cur, Kind, ReduxDesc, FuncFMF, SE );
377379 ExactFPMathInst = ExactFPMathInst == nullptr
378380 ? ReduxDesc.getExactFPMathInst ()
379381 : ExactFPMathInst;
@@ -660,6 +662,87 @@ RecurrenceDescriptor::isAnyOfPattern(Loop *Loop, PHINode *OrigPhi,
660662 : RecurKind::FAnyOf);
661663}
662664
665+ // We are looking for loops that do something like this:
666+ // int r = 0;
667+ // for (int i = 0; i < n; i++) {
668+ // if (src[i] > 3)
669+ // r = i;
670+ // }
671+ // The reduction value (r) is derived from either the values of an increasing
672+ // induction variable (i) sequence, or from the start value (0).
673+ // The LLVM IR generated for such loops would be as follows:
674+ // for.body:
675+ // %r = phi i32 [ %spec.select, %for.body ], [ 0, %entry ]
676+ // %i = phi i32 [ %inc, %for.body ], [ 0, %entry ]
677+ // ...
678+ // %cmp = icmp sgt i32 %5, 3
679+ // %spec.select = select i1 %cmp, i32 %i, i32 %r
680+ // %inc = add nsw i32 %i, 1
681+ // ...
682+ // Since 'i' is an increasing induction variable, the reduction value after the
683+ // loop will be the maximum value of 'i' that the condition (src[i] > 3) is
684+ // satisfied, or the start value (0 in the example above). When the start value
685+ // of the increasing induction variable 'i' is greater than the minimum value of
686+ // the data type, we can use the minimum value of the data type as a sentinel
687+ // value to replace the start value. This allows us to perform a single
688+ // reduction max operation to obtain the final reduction result.
689+ // TODO: It is possible to solve the case where the start value is the minimum
690+ // value of the data type or a non-constant value by using mask and multiple
691+ // reduction operations.
692+ RecurrenceDescriptor::InstDesc
693+ RecurrenceDescriptor::isFindLastIVPattern (Loop *Loop, PHINode *OrigPhi,
694+ Instruction *I, ScalarEvolution *SE) {
695+ // Only match select with single use cmp condition.
696+ // TODO: Only handle single use for now.
697+ CmpInst::Predicate Pred;
698+ if (!match (I, m_Select (m_OneUse (m_Cmp (Pred, m_Value (), m_Value ())), m_Value (),
699+ m_Value ())))
700+ return InstDesc (false , I);
701+
702+ SelectInst *SI = cast<SelectInst>(I);
703+ Value *NonRdxPhi = nullptr ;
704+
705+ if (OrigPhi == dyn_cast<PHINode>(SI->getTrueValue ()))
706+ NonRdxPhi = SI->getFalseValue ();
707+ else if (OrigPhi == dyn_cast<PHINode>(SI->getFalseValue ()))
708+ NonRdxPhi = SI->getTrueValue ();
709+ else
710+ return InstDesc (false , I);
711+
712+ auto IsIncreasingLoopInduction = [&SE, &Loop](Value *V) {
713+ auto *Phi = dyn_cast<PHINode>(V);
714+ if (!Phi)
715+ return false ;
716+
717+ if (!SE)
718+ return false ;
719+
720+ InductionDescriptor ID;
721+ if (!InductionDescriptor::isInductionPHI (Phi, Loop, SE, ID))
722+ return false ;
723+
724+ const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(SE->getSCEV (Phi));
725+ if (!AR->hasNoSignedWrap ())
726+ return false ;
727+
728+ ConstantInt *IVStartValue = dyn_cast<ConstantInt>(ID.getStartValue ());
729+ if (!IVStartValue || IVStartValue->isMinSignedValue ())
730+ return false ;
731+
732+ const SCEV *Step = ID.getStep ();
733+ return SE->isKnownPositive (Step);
734+ };
735+
736+ // We are looking for selects of the form:
737+ // select(cmp(), phi, loop_induction) or
738+ // select(cmp(), loop_induction, phi)
739+ if (!IsIncreasingLoopInduction (NonRdxPhi))
740+ return InstDesc (false , I);
741+
742+ return InstDesc (I, isa<ICmpInst>(I->getOperand (0 )) ? RecurKind::IFindLastIV
743+ : RecurKind::FFindLastIV);
744+ }
745+
663746RecurrenceDescriptor::InstDesc
664747RecurrenceDescriptor::isMinMaxPattern (Instruction *I, RecurKind Kind,
665748 const InstDesc &Prev) {
@@ -763,10 +846,9 @@ RecurrenceDescriptor::isConditionalRdxPattern(RecurKind Kind, Instruction *I) {
763846 return InstDesc (true , SI);
764847}
765848
766- RecurrenceDescriptor::InstDesc
767- RecurrenceDescriptor::isRecurrenceInstr (Loop *L, PHINode *OrigPhi,
768- Instruction *I, RecurKind Kind,
769- InstDesc &Prev, FastMathFlags FuncFMF) {
849+ RecurrenceDescriptor::InstDesc RecurrenceDescriptor::isRecurrenceInstr (
850+ Loop *L, PHINode *OrigPhi, Instruction *I, RecurKind Kind, InstDesc &Prev,
851+ FastMathFlags FuncFMF, ScalarEvolution *SE) {
770852 assert (Prev.getRecKind () == RecurKind::None || Prev.getRecKind () == Kind);
771853 switch (I->getOpcode ()) {
772854 default :
@@ -796,6 +878,8 @@ RecurrenceDescriptor::isRecurrenceInstr(Loop *L, PHINode *OrigPhi,
796878 if (Kind == RecurKind::FAdd || Kind == RecurKind::FMul ||
797879 Kind == RecurKind::Add || Kind == RecurKind::Mul)
798880 return isConditionalRdxPattern (Kind, I);
881+ if (isFindLastIVRecurrenceKind (Kind))
882+ return isFindLastIVPattern (L, OrigPhi, I, SE);
799883 [[fallthrough]];
800884 case Instruction::FCmp:
801885 case Instruction::ICmp:
@@ -900,6 +984,11 @@ bool RecurrenceDescriptor::isReductionPHI(PHINode *Phi, Loop *TheLoop,
900984 << *Phi << " \n " );
901985 return true ;
902986 }
987+ if (AddReductionVar (Phi, RecurKind::IFindLastIV, TheLoop, FMF, RedDes, DB, AC,
988+ DT, SE)) {
989+ LLVM_DEBUG (dbgs () << " Found a FindLastIV reduction PHI." << *Phi << " \n " );
990+ return true ;
991+ }
903992 if (AddReductionVar (Phi, RecurKind::FMul, TheLoop, FMF, RedDes, DB, AC, DT,
904993 SE)) {
905994 LLVM_DEBUG (dbgs () << " Found an FMult reduction PHI." << *Phi << " \n " );
@@ -1089,6 +1178,9 @@ Value *RecurrenceDescriptor::getRecurrenceIdentity(RecurKind K, Type *Tp,
10891178 case RecurKind::FAnyOf:
10901179 return getRecurrenceStartValue ();
10911180 break ;
1181+ case RecurKind::IFindLastIV:
1182+ case RecurKind::FFindLastIV:
1183+ return getRecurrenceIdentity (RecurKind::SMax, Tp, FMF);
10921184 default :
10931185 llvm_unreachable (" Unknown recurrence kind" );
10941186 }
@@ -1116,12 +1208,14 @@ unsigned RecurrenceDescriptor::getOpcode(RecurKind Kind) {
11161208 case RecurKind::UMax:
11171209 case RecurKind::UMin:
11181210 case RecurKind::IAnyOf:
1211+ case RecurKind::IFindLastIV:
11191212 return Instruction::ICmp;
11201213 case RecurKind::FMax:
11211214 case RecurKind::FMin:
11221215 case RecurKind::FMaximum:
11231216 case RecurKind::FMinimum:
11241217 case RecurKind::FAnyOf:
1218+ case RecurKind::FFindLastIV:
11251219 return Instruction::FCmp;
11261220 default :
11271221 llvm_unreachable (" Unknown recurrence operation" );
0 commit comments