diff --git a/llvm/include/llvm/Analysis/IVDescriptors.h b/llvm/include/llvm/Analysis/IVDescriptors.h index 654a5f10cea96..0169b98ada537 100644 --- a/llvm/include/llvm/Analysis/IVDescriptors.h +++ b/llvm/include/llvm/Analysis/IVDescriptors.h @@ -155,14 +155,6 @@ class RecurrenceDescriptor { LLVM_ABI static bool areAllUsesIn(Instruction *I, SmallPtrSetImpl &Set); - /// Returns a struct describing if the instruction is a llvm.(s/u)(min/max), - /// llvm.minnum/maxnum or a Select(ICmp(X, Y), X, Y) pair of instructions - /// corresponding to a min(X, Y) or max(X, Y), matching the recurrence kind \p - /// Kind. \p Prev specifies the description of an already processed select - /// instruction, so its corresponding cmp can be matched to it. - LLVM_ABI static InstDesc isMinMaxPattern(Instruction *I, RecurKind Kind, - const InstDesc &Prev); - /// Returns a struct describing whether the instruction is either a /// Select(ICmp(A, B), X, Y), or /// Select(FCmp(A, B), X, Y) diff --git a/llvm/lib/Analysis/IVDescriptors.cpp b/llvm/lib/Analysis/IVDescriptors.cpp index 9f8ac6e8e2e0b..a460acd283e91 100644 --- a/llvm/lib/Analysis/IVDescriptors.cpp +++ b/llvm/lib/Analysis/IVDescriptors.cpp @@ -214,6 +214,173 @@ static bool checkOrderedReduction(RecurKind Kind, Instruction *ExactFPMathInst, return true; } +// Helper to collect FMF from a value and its associated fcmp in select patterns +static FastMathFlags collectMinMaxFMF(Value *V) { + FastMathFlags FMF = cast(V)->getFastMathFlags(); + if (auto *Sel = dyn_cast(V)) { + // Accept FMF on either fcmp or select of a min/max idiom. + // TODO: This is a hack to work-around the fact that FMF may not be + // assigned/propagated correctly. If that problem is fixed or we + // standardize on fmin/fmax via intrinsics, this can be removed. 
+ if (auto *FCmp = dyn_cast(Sel->getCondition())) + FMF |= FCmp->getFastMathFlags(); + } + return FMF; +} + +static std::optional +hasRequiredFastMathFlags(FPMathOperator *FPOp, RecurKind &RK, + FastMathFlags FuncFMF) { + bool HasRequiredFMF = + (FuncFMF.noNaNs() && FuncFMF.noSignedZeros()) || + (FPOp && FPOp->hasNoNaNs() && FPOp->hasNoSignedZeros()) || + RK == RecurKind::FMinimum || RK == RecurKind::FMaximum || + RK == RecurKind::FMinimumNum || RK == RecurKind::FMaximumNum; + if (!HasRequiredFMF) { + if (RK == RecurKind::FMax && + match(FPOp, m_Intrinsic(m_Value(), m_Value()))) + RK = RecurKind::FMaxNum; + else if (RK == RecurKind::FMin && + match(FPOp, m_Intrinsic(m_Value(), m_Value()))) + RK = RecurKind::FMinNum; + else + return std::nullopt; + } + return {collectMinMaxFMF(FPOp)}; +} + +static std::optional +getMultiUseMinMax(PHINode *Phi, Loop *TheLoop, FastMathFlags FuncFMF, + ScalarEvolution *SE) { + if (Phi->getNumIncomingValues() != 2 || + Phi->getParent() != TheLoop->getHeader()) + return std::nullopt; + + Type *Ty = Phi->getType(); + BasicBlock *Latch = TheLoop->getLoopLatch(); + if ((!Ty->isIntegerTy() && !Ty->isFloatingPointTy()) || !Latch) + return std::nullopt; + + auto Matches = [](Value *V, Value *&A, Value *&B) -> RecurKind { + if (match(V, m_UMin(m_Value(A), m_Value(B)))) + return RecurKind::UMin; + if (match(V, m_UMax(m_Value(A), m_Value(B)))) + return RecurKind::UMax; + if (match(V, m_SMax(m_Value(A), m_Value(B)))) + return RecurKind::SMax; + if (match(V, m_SMin(m_Value(A), m_Value(B)))) + return RecurKind::SMin; + if (match(V, m_OrdOrUnordFMin(m_Value(A), m_Value(B))) || + match(V, m_Intrinsic(m_Value(A), m_Value(B)))) + return RecurKind::FMin; + if (match(V, m_OrdOrUnordFMax(m_Value(A), m_Value(B))) || + match(V, m_Intrinsic(m_Value(A), m_Value(B)))) + return RecurKind::FMax; + if (match(V, m_FMinimum(m_Value(A), m_Value(B)))) + return RecurKind::FMinimum; + if (match(V, m_FMaximum(m_Value(A), m_Value(B)))) + return RecurKind::FMaximum; + if 
(match(V, m_Intrinsic(m_Value(A), m_Value(B)))) + return RecurKind::FMinimumNum; + if (match(V, m_Intrinsic(m_Value(A), m_Value(B)))) + return RecurKind::FMaximumNum; + return RecurKind::None; + }; + + FastMathFlags FMF = FastMathFlags::getFast(); + Value *RdxNext = Phi->getIncomingValueForBlock(Latch); + RecurKind RK = RecurKind::None; + // Identify min/max recurrences by walking the def-use chains upwards, + // starting at RdxNext. + SmallVector WorkList = {RdxNext}; + SmallPtrSet Chain = {Phi}; + while (!WorkList.empty()) { + Value *Cur = WorkList.pop_back_val(); + if (!Chain.insert(Cur).second) + continue; + auto *I = dyn_cast(Cur); + if (!I || !TheLoop->contains(I)) + return std::nullopt; + if (auto *PN = dyn_cast(I)) { + if (PN != Phi) + append_range(WorkList, PN->operands()); + continue; + } + Value *A, *B; + RecurKind CurRK = Matches(Cur, A, B); + if (CurRK == RecurKind::None || (RK != RecurKind::None && CurRK != RK)) + return std::nullopt; + + RK = CurRK; + // For floating point recurrences, check we have the required fast-math + // flags. + if (RecurrenceDescriptor::isFPMinMaxRecurrenceKind(CurRK)) { + if (auto CurFMF = + hasRequiredFastMathFlags(cast(Cur), RK, FuncFMF)) + FMF &= *CurFMF; + else + return std::nullopt; + } + + Chain.insert(I); + if (auto *SI = dyn_cast(I)) + Chain.insert(SI->getCondition()); + + if (A == Phi || B == Phi) + continue; + + // Add operand to worklist if it matches the pattern - exactly one must + // match + Value *X, *Y; + auto *IA = dyn_cast(A); + auto *IB = dyn_cast(B); + bool AMatches = IA && TheLoop->contains(IA) && Matches(A, X, Y) == RK; + bool BMatches = IB && TheLoop->contains(IB) && Matches(B, X, Y) == RK; + if (AMatches == BMatches) // Both or neither match + return std::nullopt; + WorkList.push_back(AMatches ? A : B); + } + + // Check users of RdxNext. It may only have + // * at most one use outside the loop, + // * uses by stores to the same loop-invariant address, + // * uses by the starting recurrence phi.
+ unsigned IncOut = 0; + StoreInst *IntermediateStore = nullptr; + for (Use &U : RdxNext->uses()) { + auto *User = cast(U.getUser()); + if (!TheLoop->contains(User->getParent())) { + if (++IncOut > 1) + return std::nullopt; + } else if (auto *SI = dyn_cast(User)) { + const SCEV *Ptr = SE->getSCEV(SI->getPointerOperand()); + if (U.getOperandNo() == SI->getPointerOperandIndex() || + !SE->isLoopInvariant(Ptr, TheLoop) || + (IntermediateStore && + SE->getSCEV(IntermediateStore->getPointerOperand()) != Ptr)) + return std::nullopt; + // Keep the store that appears last in the block, as it will be the final + // reduction value. + if (!IntermediateStore || IntermediateStore->comesBefore(SI)) + IntermediateStore = SI; + } else if (Phi != User) + return std::nullopt; + } + + // All ops on the chain from Phi to RdxNext must only be used by instructions + // in the chain. + for (Value *Op : Chain) + if (Op != RdxNext && + any_of(Op->users(), [&Chain](User *U) { return !Chain.contains(U); })) + return std::nullopt; + + SmallPtrSet Casts; + return RecurrenceDescriptor( + Phi->getIncomingValueForBlock(TheLoop->getLoopPreheader()), + cast(RdxNext), IntermediateStore, RK, FMF, nullptr, + Phi->getType(), false, false, Casts, -1U); +} + bool RecurrenceDescriptor::AddReductionVar( PHINode *Phi, RecurKind Kind, Loop *TheLoop, FastMathFlags FuncFMF, RecurrenceDescriptor &RedDes, DemandedBits *DB, AssumptionCache *AC, @@ -249,9 +416,8 @@ bool RecurrenceDescriptor::AddReductionVar( // must include the original PHI. bool FoundStartPHI = false; - // To recognize min/max patterns formed by a icmp select sequence, we store - // the number of instruction we saw from the recognized min/max pattern, - // to make sure we only see exactly the two instructions. + // To recognize AnyOf patterns formed by an icmp select sequence, we store + // the number of select instructions we saw to make sure we only see one.
unsigned NumCmpSelectPatternInst = 0; InstDesc ReduxDesc(false, nullptr); @@ -276,8 +442,7 @@ bool RecurrenceDescriptor::AddReductionVar( } else if (RecurrenceType->isIntegerTy()) { if (!isIntegerRecurrenceKind(Kind)) return false; - if (!isMinMaxRecurrenceKind(Kind)) - Start = lookThroughAnd(Phi, RecurrenceType, VisitedInsts, CastInsts); + Start = lookThroughAnd(Phi, RecurrenceType, VisitedInsts, CastInsts); } else { // Pointer min/max may exist, but it is not supported as a reduction op. return false; @@ -384,18 +549,8 @@ bool RecurrenceDescriptor::AddReductionVar( if (!ReduxDesc.isRecurrence()) return false; // FIXME: FMF is allowed on phi, but propagation is not handled correctly. - if (isa(ReduxDesc.getPatternInst()) && !IsAPhi) { - FastMathFlags CurFMF = ReduxDesc.getPatternInst()->getFastMathFlags(); - if (auto *Sel = dyn_cast(ReduxDesc.getPatternInst())) { - // Accept FMF on either fcmp or select of a min/max idiom. - // TODO: This is a hack to work-around the fact that FMF may not be - // assigned/propagated correctly. If that problem is fixed or we - // standardize on fmin/fmax via intrinsics, this can be removed. - if (auto *FCmp = dyn_cast(Sel->getCondition())) - CurFMF |= FCmp->getFastMathFlags(); - } - FMF &= CurFMF; - } + if (isa(ReduxDesc.getPatternInst()) && !IsAPhi) + FMF &= collectMinMaxFMF(ReduxDesc.getPatternInst()); // Update this reduction kind if we matched a new instruction. // TODO: Can we eliminate the need for a 2nd InstDesc by keeping 'Kind' // state accurate while processing the worklist? @@ -412,18 +567,14 @@ bool RecurrenceDescriptor::AddReductionVar( return false; // A reduction operation must only have one use of the reduction value. 
- if (!IsAPhi && !IsASelect && !isMinMaxRecurrenceKind(Kind) && - !isAnyOfRecurrenceKind(Kind) && hasMultipleUsesOf(Cur, VisitedInsts, 1)) + if (!IsAPhi && !IsASelect && !isAnyOfRecurrenceKind(Kind) && + hasMultipleUsesOf(Cur, VisitedInsts, 1)) return false; // All inputs to a PHI node must be a reduction value. if (IsAPhi && Cur != Phi && !areAllUsesIn(Cur, VisitedInsts)) return false; - if (isIntMinMaxRecurrenceKind(Kind) && (isa(Cur) || IsASelect)) - ++NumCmpSelectPatternInst; - if (isFPMinMaxRecurrenceKind(Kind) && (isa(Cur) || IsASelect)) - ++NumCmpSelectPatternInst; if (isAnyOfRecurrenceKind(Kind) && IsASelect) ++NumCmpSelectPatternInst; @@ -470,7 +621,7 @@ bool RecurrenceDescriptor::AddReductionVar( } // Process instructions only once (termination). Each reduction cycle - // value must only be used once, except by phi nodes and min/max + // value must only be used once, except by phi nodes and conditional // reductions which are represented as a cmp followed by a select. InstDesc IgnoredVal(false, nullptr); if (VisitedInsts.insert(UI).second) { @@ -486,12 +637,9 @@ bool RecurrenceDescriptor::AddReductionVar( NonPHIs.push_back(UI); } } else if (!isa(UI) && - ((!isa(UI) && !isa(UI) && - !isa(UI)) || - (!isConditionalRdxPattern(UI).isRecurrence() && + ((!isConditionalRdxPattern(UI).isRecurrence() && !isAnyOfPattern(TheLoop, Phi, UI, IgnoredVal) - .isRecurrence() && - !isMinMaxPattern(UI, Kind, IgnoredVal).isRecurrence()))) + .isRecurrence()))) return false; // Remember that we completed the cycle. @@ -502,13 +650,6 @@ bool RecurrenceDescriptor::AddReductionVar( Worklist.append(NonPHIs.begin(), NonPHIs.end()); } - // This means we have seen one but not the other instruction of the - // pattern or more than just a select and cmp. Zero implies that we saw a - // llvm.min/max intrinsic, which is always OK. 
- if (isMinMaxRecurrenceKind(Kind) && NumCmpSelectPatternInst != 2 && - NumCmpSelectPatternInst != 0) - return false; - if (isAnyOfRecurrenceKind(Kind) && NumCmpSelectPatternInst != 1) return false; @@ -789,55 +930,6 @@ RecurrenceDescriptor::isFindIVPattern(RecurKind Kind, Loop *TheLoop, return InstDesc(false, I); } -RecurrenceDescriptor::InstDesc -RecurrenceDescriptor::isMinMaxPattern(Instruction *I, RecurKind Kind, - const InstDesc &Prev) { - assert((isa(I) || isa(I) || isa(I)) && - "Expected a cmp or select or call instruction"); - if (!isMinMaxRecurrenceKind(Kind)) - return InstDesc(false, I); - - // We must handle the select(cmp()) as a single instruction. Advance to the - // select. - if (match(I, m_OneUse(m_Cmp()))) { - if (auto *Select = dyn_cast(*I->user_begin())) - return InstDesc(Select, Prev.getRecKind()); - } - - // Only match select with single use cmp condition, or a min/max intrinsic. - if (!isa(I) && - !match(I, m_Select(m_OneUse(m_Cmp()), m_Value(), m_Value()))) - return InstDesc(false, I); - - // Look for a min/max pattern. 
- if (match(I, m_UMin(m_Value(), m_Value()))) - return InstDesc(Kind == RecurKind::UMin, I); - if (match(I, m_UMax(m_Value(), m_Value()))) - return InstDesc(Kind == RecurKind::UMax, I); - if (match(I, m_SMax(m_Value(), m_Value()))) - return InstDesc(Kind == RecurKind::SMax, I); - if (match(I, m_SMin(m_Value(), m_Value()))) - return InstDesc(Kind == RecurKind::SMin, I); - if (match(I, m_OrdOrUnordFMin(m_Value(), m_Value()))) - return InstDesc(Kind == RecurKind::FMin, I); - if (match(I, m_OrdOrUnordFMax(m_Value(), m_Value()))) - return InstDesc(Kind == RecurKind::FMax, I); - if (match(I, m_FMinNum(m_Value(), m_Value()))) - return InstDesc(Kind == RecurKind::FMin, I); - if (match(I, m_FMaxNum(m_Value(), m_Value()))) - return InstDesc(Kind == RecurKind::FMax, I); - if (match(I, m_FMinimumNum(m_Value(), m_Value()))) - return InstDesc(Kind == RecurKind::FMinimumNum, I); - if (match(I, m_FMaximumNum(m_Value(), m_Value()))) - return InstDesc(Kind == RecurKind::FMaximumNum, I); - if (match(I, m_FMinimum(m_Value(), m_Value()))) - return InstDesc(Kind == RecurKind::FMinimum, I); - if (match(I, m_FMaximum(m_Value(), m_Value()))) - return InstDesc(Kind == RecurKind::FMaximum, I); - - return InstDesc(false, I); -} - /// Returns true if the select instruction has users in the compare-and-add /// reduction pattern below. The select instruction argument is the last one /// in the sequence. @@ -928,43 +1020,6 @@ RecurrenceDescriptor::InstDesc RecurrenceDescriptor::isRecurrenceInstr( case Instruction::Call: if (isAnyOfRecurrenceKind(Kind)) return isAnyOfPattern(L, OrigPhi, I, Prev); - auto HasRequiredFMF = [&]() { - if (FuncFMF.noNaNs() && FuncFMF.noSignedZeros()) - return true; - if (isa(I) && I->hasNoNaNs() && I->hasNoSignedZeros()) - return true; - // minimum/minnum and maximum/maxnum intrinsics do not require nsz and nnan - // flags since NaN and signed zeroes are propagated in the intrinsic - // implementation. 
- return match(I, m_Intrinsic(m_Value(), m_Value())) || - match(I, m_Intrinsic(m_Value(), m_Value())) || - match(I, - m_Intrinsic(m_Value(), m_Value())) || - match(I, m_Intrinsic(m_Value(), m_Value())); - }; - if (isIntMinMaxRecurrenceKind(Kind)) - return isMinMaxPattern(I, Kind, Prev); - if (isFPMinMaxRecurrenceKind(Kind)) { - InstDesc Res = isMinMaxPattern(I, Kind, Prev); - if (!Res.isRecurrence()) - return InstDesc(false, I); - if (HasRequiredFMF()) - return Res; - // We may be able to vectorize FMax/FMin reductions using maxnum/minnum - // intrinsics with extra checks ensuring the vector loop handles only - // non-NaN inputs. - if (match(I, m_Intrinsic(m_Value(), m_Value()))) { - assert(Kind == RecurKind::FMax && - "unexpected recurrence kind for maxnum"); - return InstDesc(I, RecurKind::FMaxNum); - } - if (match(I, m_Intrinsic(m_Value(), m_Value()))) { - assert(Kind == RecurKind::FMin && - "unexpected recurrence kind for minnum"); - return InstDesc(I, RecurKind::FMinNum); - } - return InstDesc(false, I); - } if (isFMulAddIntrinsic(I)) return InstDesc(Kind == RecurKind::FMulAdd, I, I->hasAllowReassoc() ? nullptr : I); @@ -1035,24 +1090,9 @@ bool RecurrenceDescriptor::isReductionPHI(PHINode *Phi, Loop *TheLoop, LLVM_DEBUG(dbgs() << "Found a XOR reduction PHI." << *Phi << "\n"); return true; } - if (AddReductionVar(Phi, RecurKind::SMax, TheLoop, FMF, RedDes, DB, AC, DT, - SE)) { - LLVM_DEBUG(dbgs() << "Found a SMAX reduction PHI." << *Phi << "\n"); - return true; - } - if (AddReductionVar(Phi, RecurKind::SMin, TheLoop, FMF, RedDes, DB, AC, DT, - SE)) { - LLVM_DEBUG(dbgs() << "Found a SMIN reduction PHI." << *Phi << "\n"); - return true; - } - if (AddReductionVar(Phi, RecurKind::UMax, TheLoop, FMF, RedDes, DB, AC, DT, - SE)) { - LLVM_DEBUG(dbgs() << "Found a UMAX reduction PHI." << *Phi << "\n"); - return true; - } - if (AddReductionVar(Phi, RecurKind::UMin, TheLoop, FMF, RedDes, DB, AC, DT, - SE)) { - LLVM_DEBUG(dbgs() << "Found a UMIN reduction PHI." 
<< *Phi << "\n"); + if (auto RD = getMultiUseMinMax(Phi, TheLoop, FMF, SE)) { + LLVM_DEBUG(dbgs() << "Found a min/max reduction PHI." << *Phi << "\n"); + RedDes = *RD; return true; } if (AddReductionVar(Phi, RecurKind::AnyOf, TheLoop, FMF, RedDes, DB, AC, DT, @@ -1081,43 +1121,11 @@ bool RecurrenceDescriptor::isReductionPHI(PHINode *Phi, Loop *TheLoop, LLVM_DEBUG(dbgs() << "Found an FAdd reduction PHI." << *Phi << "\n"); return true; } - if (AddReductionVar(Phi, RecurKind::FMax, TheLoop, FMF, RedDes, DB, AC, DT, - SE)) { - LLVM_DEBUG(dbgs() << "Found a float MAX reduction PHI." << *Phi << "\n"); - return true; - } - if (AddReductionVar(Phi, RecurKind::FMin, TheLoop, FMF, RedDes, DB, AC, DT, - SE)) { - LLVM_DEBUG(dbgs() << "Found a float MIN reduction PHI." << *Phi << "\n"); - return true; - } if (AddReductionVar(Phi, RecurKind::FMulAdd, TheLoop, FMF, RedDes, DB, AC, DT, SE)) { LLVM_DEBUG(dbgs() << "Found an FMulAdd reduction PHI." << *Phi << "\n"); return true; } - if (AddReductionVar(Phi, RecurKind::FMaximum, TheLoop, FMF, RedDes, DB, AC, DT, - SE)) { - LLVM_DEBUG(dbgs() << "Found a float MAXIMUM reduction PHI." << *Phi << "\n"); - return true; - } - if (AddReductionVar(Phi, RecurKind::FMinimum, TheLoop, FMF, RedDes, DB, AC, DT, - SE)) { - LLVM_DEBUG(dbgs() << "Found a float MINIMUM reduction PHI." << *Phi << "\n"); - return true; - } - if (AddReductionVar(Phi, RecurKind::FMaximumNum, TheLoop, FMF, RedDes, DB, AC, - DT, SE)) { - LLVM_DEBUG(dbgs() << "Found a float MAXIMUMNUM reduction PHI." << *Phi - << "\n"); - return true; - } - if (AddReductionVar(Phi, RecurKind::FMinimumNum, TheLoop, FMF, RedDes, DB, AC, - DT, SE)) { - LLVM_DEBUG(dbgs() << "Found a float MINIMUMNUM reduction PHI." << *Phi - << "\n"); - return true; - } // Not a reduction of known type. return false;