@@ -604,6 +604,20 @@ Value *VPInstruction::generate(VPTransformState &State) {
604604 return Builder.CreateVectorSplat (
605605 State.VF , State.get (getOperand (0 ), /* IsScalar*/ true ), " broadcast" );
606606 }
607+ case VPInstruction::ReductionStartVector: {
608+ if (State.VF .isScalar ())
609+ return State.get (getOperand (0 ), true );
610+ IRBuilderBase::FastMathFlagGuard FMFG (Builder);
611+ Builder.setFastMathFlags (getFastMathFlags ());
612+ // If this start vector is scaled then it should produce a vector with fewer
613+ // elements than the VF.
614+ ElementCount VF = State.VF .divideCoefficientBy (
615+ cast<ConstantInt>(getOperand (2 )->getLiveInIRValue ())->getZExtValue ());
616+ auto *Iden = Builder.CreateVectorSplat (VF, State.get (getOperand (1 ), true ));
617+ Constant *Zero = Builder.getInt32 (0 );
618+ return Builder.CreateInsertElement (Iden, State.get (getOperand (0 ), true ),
619+ Zero);
620+ }
607621 case VPInstruction::ComputeAnyOfResult: {
608622 // FIXME: The cross-recipe dependency on VPReductionPHIRecipe is temporary
609623 // and will be removed by breaking up the recipe further.
@@ -899,6 +913,7 @@ bool VPInstruction::opcodeMayReadOrWriteFromMemory() const {
899913 case VPInstruction::PtrAdd:
900914 case VPInstruction::WideIVStep:
901915 case VPInstruction::StepVector:
916+ case VPInstruction::ReductionStartVector:
902917 return false ;
903918 default :
904919 return true ;
@@ -929,6 +944,7 @@ bool VPInstruction::onlyFirstLaneUsed(const VPValue *Op) const {
929944 case VPInstruction::CanonicalIVIncrementForPart:
930945 case VPInstruction::BranchOnCount:
931946 case VPInstruction::BranchOnCond:
947+ case VPInstruction::ReductionStartVector:
932948 return true ;
933949 case VPInstruction::PtrAdd:
934950 return Op == getOperand (0 ) || vputils::onlyFirstLaneUsed (this );
@@ -1034,6 +1050,9 @@ void VPInstruction::print(raw_ostream &O, const Twine &Indent,
10341050 case VPInstruction::FirstActiveLane:
10351051 O << " first-active-lane" ;
10361052 break ;
1053+ case VPInstruction::ReductionStartVector:
1054+ O << " reduction-start-vector" ;
1055+ break ;
10371056 default :
10381057 O << Instruction::getOpcodeName (getOpcode ());
10391058 }
@@ -1617,6 +1636,7 @@ bool VPIRFlags::flagsValidForOpcode(unsigned Opcode) const {
16171636 Opcode == Instruction::FDiv || Opcode == Instruction::FRem ||
16181637 Opcode == Instruction::FCmp || Opcode == Instruction::Select ||
16191638 Opcode == VPInstruction::WideIVStep ||
1639+ Opcode == VPInstruction::ReductionStartVector ||
16201640 Opcode == VPInstruction::ComputeReductionResult;
16211641 case OperationType::NonNegOp:
16221642 return Opcode == Instruction::ZExt;
@@ -3847,17 +3867,19 @@ void VPFirstOrderRecurrencePHIRecipe::print(raw_ostream &O, const Twine &Indent,
38473867#endif
38483868
38493869void VPReductionPHIRecipe::execute (VPTransformState &State) {
3850- // If this phi is fed by a scaled reduction then it should output a
3851- // vector with fewer elements than the VF .
3852- ElementCount VF = State. VF . divideCoefficientBy (VFScaleFactor );
3870+ // Reductions do not have to start at zero. They can start with
3871+ // any loop invariant values .
3872+ VPValue *StartVPV = getStartValue ( );
38533873
38543874 // In order to support recurrences we need to be able to vectorize Phi nodes.
38553875 // Phi nodes have cycles, so we need to vectorize them in two stages. This is
38563876 // stage #1: We create a new vector PHI node with no incoming edges. We'll use
38573877 // this value when we vectorize all of the instructions that use the PHI.
3858- auto *ScalarTy = State.TypeAnalysis .inferScalarType (this );
3878+ BasicBlock *VectorPH =
3879+ State.CFG .VPBB2IRBB .at (getParent ()->getCFGPredecessor (0 ));
38593880 bool ScalarPHI = State.VF .isScalar () || IsInLoop;
3860- Type *VecTy = ScalarPHI ? ScalarTy : VectorType::get (ScalarTy, VF);
3881+ Value *StartV = State.get (StartVPV, ScalarPHI);
3882+ Type *VecTy = StartV->getType ();
38613883
38623884 BasicBlock *HeaderBB = State.CFG .PrevBB ;
38633885 assert (State.CurrentParentLoop ->getHeader () == HeaderBB &&
@@ -3866,49 +3888,7 @@ void VPReductionPHIRecipe::execute(VPTransformState &State) {
38663888 Phi->insertBefore (HeaderBB->getFirstInsertionPt ());
38673889 State.set (this , Phi, IsInLoop);
38683890
3869- BasicBlock *VectorPH =
3870- State.CFG .VPBB2IRBB .at (getParent ()->getCFGPredecessor (0 ));
3871- // Create start and identity vector values for the reduction in the preheader.
3872- // TODO: Introduce recipes in VPlan preheader to create initial values.
3873- IRBuilderBase::InsertPointGuard IPBuilder (State.Builder );
3874- State.Builder .SetInsertPoint (VectorPH->getTerminator ());
3875-
3876- // Reductions do not have to start at zero. They can start with
3877- // any loop invariant values.
3878- VPValue *StartVPV = getStartValue ();
3879- RecurKind RK = RdxDesc.getRecurrenceKind ();
3880- if (RecurrenceDescriptor::isMinMaxRecurrenceKind (RK) ||
3881- RecurrenceDescriptor::isAnyOfRecurrenceKind (RK) ||
3882- RecurrenceDescriptor::isFindLastIVRecurrenceKind (RK)) {
3883- // [I|F]FindLastIV will use a sentinel value to initialize the reduction
3884- // phi or the resume value from the main vector loop when vectorizing the
3885- // epilogue loop. In the exit block, ComputeReductionResult will generate
3886- // checks to verify if the reduction result is the sentinel value. If the
3887- // result is the sentinel value, it will be corrected back to the start
3888- // value.
3889- // TODO: The sentinel value is not always necessary. When the start value is
3890- // a constant, and smaller than the start value of the induction variable,
3891- // the start value can be directly used to initialize the reduction phi.
3892- Phi->addIncoming (State.get (StartVPV, ScalarPHI), VectorPH);
3893- return ;
3894- }
3895-
3896- Value *Iden = getRecurrenceIdentity (RK, VecTy->getScalarType (),
3897- RdxDesc.getFastMathFlags ());
3898- unsigned CurrentPart = getUnrollPart (*this );
3899- Value *StartV = StartVPV->getLiveInIRValue ();
3900- if (!ScalarPHI) {
3901- if (CurrentPart == 0 ) {
3902- Iden = State.Builder .CreateVectorSplat (VF, Iden);
3903- Constant *Zero = State.Builder .getInt32 (0 );
3904- StartV = State.Builder .CreateInsertElement (Iden, StartV, Zero);
3905- } else {
3906- Iden = State.Builder .CreateVectorSplat (VF, Iden);
3907- }
3908- }
3909-
3910- Value *StartVal = (CurrentPart == 0 ) ? StartV : Iden;
3911- Phi->addIncoming (StartVal, VectorPH);
3891+ Phi->addIncoming (StartV, VectorPH);
39123892}
39133893
39143894#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
0 commit comments