diff --git a/llvm/include/llvm/Transforms/Utils/LoopUtils.h b/llvm/include/llvm/Transforms/Utils/LoopUtils.h index 8f4c0c88336ac..1818ee03d2ec8 100644 --- a/llvm/include/llvm/Transforms/Utils/LoopUtils.h +++ b/llvm/include/llvm/Transforms/Utils/LoopUtils.h @@ -427,11 +427,6 @@ Value *createAnyOfReduction(IRBuilderBase &B, Value *Src, Value *createFindLastIVReduction(IRBuilderBase &B, Value *Src, const RecurrenceDescriptor &Desc); -/// Create a generic reduction using a recurrence descriptor \p Desc -/// Fast-math-flags are propagated using the RecurrenceDescriptor. -Value *createReduction(IRBuilderBase &B, const RecurrenceDescriptor &Desc, - Value *Src, PHINode *OrigPhi = nullptr); - /// Create an ordered reduction intrinsic using the given recurrence /// descriptor \p Desc. Value *createOrderedReduction(IRBuilderBase &B, diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp index 84c08556f8a25..185af8631454a 100644 --- a/llvm/lib/Transforms/Utils/LoopUtils.cpp +++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp @@ -1336,7 +1336,8 @@ Value *llvm::createSimpleReduction(VectorBuilder &VBuilder, Value *Src, const RecurrenceDescriptor &Desc) { RecurKind Kind = Desc.getRecurrenceKind(); assert(!RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind) && - "AnyOf reduction is not supported."); + !RecurrenceDescriptor::isFindLastIVRecurrenceKind(Kind) && + "AnyOf or FindLastIV reductions are not supported."); Intrinsic::ID Id = getReductionIntrinsicID(Kind); auto *SrcTy = cast(Src->getType()); Type *SrcEltTy = SrcTy->getElementType(); @@ -1345,24 +1346,6 @@ Value *llvm::createSimpleReduction(VectorBuilder &VBuilder, Value *Src, return VBuilder.createSimpleReduction(Id, SrcTy, Ops); } -Value *llvm::createReduction(IRBuilderBase &B, - const RecurrenceDescriptor &Desc, Value *Src, - PHINode *OrigPhi) { - // TODO: Support in-order reductions based on the recurrence descriptor. - // All ops in the reduction inherit fast-math-flags from the recurrence - // descriptor. - IRBuilderBase::FastMathFlagGuard FMFGuard(B); - B.setFastMathFlags(Desc.getFastMathFlags()); - - RecurKind RK = Desc.getRecurrenceKind(); - if (RecurrenceDescriptor::isAnyOfRecurrenceKind(RK)) - return createAnyOfReduction(B, Src, Desc, OrigPhi); - if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK)) - return createFindLastIVReduction(B, Src, Desc); - - return createSimpleReduction(B, Src, RK); -} - Value *llvm::createOrderedReduction(IRBuilderBase &B, const RecurrenceDescriptor &Desc, Value *Src, Value *Start) { diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp index 6e396eda6aac6..d00e2a6e84908 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp @@ -666,8 +666,21 @@ Value *VPInstruction::generate(VPTransformState &State) { RecurrenceDescriptor::isAnyOfRecurrenceKind(RK) || RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK)) && !PhiR->isInLoop()) { - ReducedPartRdx = - createReduction(Builder, RdxDesc, ReducedPartRdx, OrigPhi); + // TODO: Support in-order reductions based on the recurrence descriptor. + // All ops in the reduction inherit fast-math-flags from the recurrence + // descriptor. + IRBuilderBase::FastMathFlagGuard FMFG(Builder); + Builder.setFastMathFlags(RdxDesc.getFastMathFlags()); + + if (RecurrenceDescriptor::isAnyOfRecurrenceKind(RK)) + ReducedPartRdx = + createAnyOfReduction(Builder, ReducedPartRdx, RdxDesc, OrigPhi); + else if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK)) + ReducedPartRdx = + createFindLastIVReduction(Builder, ReducedPartRdx, RdxDesc); + else + ReducedPartRdx = createSimpleReduction(Builder, ReducedPartRdx, RK); + // If the reduction can be performed in a smaller type, we need to extend // the reduction to the wider type before we branch to the original loop. if (PhiTy != RdxDesc.getRecurrenceType()) @@ -2297,7 +2310,7 @@ void VPReductionRecipe::execute(VPTransformState &State) { NextInChain = NewRed; } else { PrevInChain = State.get(getChainOp(), /*IsScalar*/ true); - NewRed = createReduction(State.Builder, RdxDesc, NewVecOp); + NewRed = createSimpleReduction(State.Builder, NewVecOp, Kind); if (RecurrenceDescriptor::isMinMaxRecurrenceKind(Kind)) NextInChain = createMinMaxOp(State.Builder, RdxDesc.getRecurrenceKind(), NewRed, PrevInChain);