@@ -604,6 +604,20 @@ Value *VPInstruction::generate(VPTransformState &State) {
604604 return Builder.CreateVectorSplat (
605605 State.VF , State.get (getOperand (0 ), /* IsScalar*/ true ), " broadcast" );
606606 }
607+ case VPInstruction::ReductionStartVector: {
608+ if (State.VF .isScalar ())
609+ return State.get (getOperand (0 ), true );
610+ IRBuilderBase::FastMathFlagGuard FMFG (Builder);
611+ Builder.setFastMathFlags (getFastMathFlags ());
612+ // If this start vector is scaled then it should produce a vector with fewer
613+ // elements than the VF.
614+ ElementCount VF = State.VF .divideCoefficientBy (
615+ cast<ConstantInt>(getOperand (2 )->getLiveInIRValue ())->getZExtValue ());
616+ auto *Iden = Builder.CreateVectorSplat (VF, State.get (getOperand (1 ), true ));
617+ Constant *Zero = Builder.getInt32 (0 );
618+ return Builder.CreateInsertElement (Iden, State.get (getOperand (0 ), true ),
619+ Zero);
620+ }
607621 case VPInstruction::ComputeAnyOfResult: {
608622 // FIXME: The cross-recipe dependency on VPReductionPHIRecipe is temporary
609623 // and will be removed by breaking up the recipe further.
@@ -882,6 +896,7 @@ bool VPInstruction::opcodeMayReadOrWriteFromMemory() const {
882896 case VPInstruction::PtrAdd:
883897 case VPInstruction::WideIVStep:
884898 case VPInstruction::StepVector:
899+ case VPInstruction::ReductionStartVector:
885900 return false ;
886901 default :
887902 return true ;
@@ -912,6 +927,7 @@ bool VPInstruction::onlyFirstLaneUsed(const VPValue *Op) const {
912927 case VPInstruction::CanonicalIVIncrementForPart:
913928 case VPInstruction::BranchOnCount:
914929 case VPInstruction::BranchOnCond:
930+ case VPInstruction::ReductionStartVector:
915931 return true ;
916932 case VPInstruction::PtrAdd:
917933 return Op == getOperand (0 ) || vputils::onlyFirstLaneUsed (this );
@@ -1017,6 +1033,9 @@ void VPInstruction::print(raw_ostream &O, const Twine &Indent,
10171033 case VPInstruction::FirstActiveLane:
10181034 O << " first-active-lane" ;
10191035 break ;
1036+ case VPInstruction::ReductionStartVector:
1037+ O << " reduction-start-vector" ;
1038+ break ;
10201039 default :
10211040 O << Instruction::getOpcodeName (getOpcode ());
10221041 }
@@ -1608,6 +1627,7 @@ bool VPIRFlags::flagsValidForOpcode(unsigned Opcode) const {
16081627 Opcode == Instruction::FDiv || Opcode == Instruction::FRem ||
16091628 Opcode == Instruction::FCmp || Opcode == Instruction::Select ||
16101629 Opcode == VPInstruction::WideIVStep ||
1630+ Opcode == VPInstruction::ReductionStartVector ||
16111631 Opcode == VPInstruction::ComputeReductionResult;
16121632 case OperationType::NonNegOp:
16131633 return Opcode == Instruction::ZExt;
@@ -3838,17 +3858,19 @@ void VPFirstOrderRecurrencePHIRecipe::print(raw_ostream &O, const Twine &Indent,
38383858#endif
38393859
38403860void VPReductionPHIRecipe::execute (VPTransformState &State) {
3841- // If this phi is fed by a scaled reduction then it should output a
3842- // vector with fewer elements than the VF .
3843- ElementCount VF = State. VF . divideCoefficientBy (VFScaleFactor );
3861+ // Reductions do not have to start at zero. They can start with
3862+ // any loop invariant values .
3863+ VPValue *StartVPV = getStartValue ( );
38443864
38453865 // In order to support recurrences we need to be able to vectorize Phi nodes.
38463866 // Phi nodes have cycles, so we need to vectorize them in two stages. This is
38473867 // stage #1: We create a new vector PHI node with no incoming edges. We'll use
38483868 // this value when we vectorize all of the instructions that use the PHI.
3849- auto *ScalarTy = State.TypeAnalysis .inferScalarType (this );
3869+ BasicBlock *VectorPH =
3870+ State.CFG .VPBB2IRBB .at (getParent ()->getCFGPredecessor (0 ));
38503871 bool ScalarPHI = State.VF .isScalar () || IsInLoop;
3851- Type *VecTy = ScalarPHI ? ScalarTy : VectorType::get (ScalarTy, VF);
3872+ Value *StartV = State.get (StartVPV, ScalarPHI);
3873+ Type *VecTy = StartV->getType ();
38523874
38533875 BasicBlock *HeaderBB = State.CFG .PrevBB ;
38543876 assert (State.CurrentParentLoop ->getHeader () == HeaderBB &&
@@ -3857,49 +3879,7 @@ void VPReductionPHIRecipe::execute(VPTransformState &State) {
38573879 Phi->insertBefore (HeaderBB->getFirstInsertionPt ());
38583880 State.set (this , Phi, IsInLoop);
38593881
3860- BasicBlock *VectorPH =
3861- State.CFG .VPBB2IRBB .at (getParent ()->getCFGPredecessor (0 ));
3862- // Create start and identity vector values for the reduction in the preheader.
3863- // TODO: Introduce recipes in VPlan preheader to create initial values.
3864- IRBuilderBase::InsertPointGuard IPBuilder (State.Builder );
3865- State.Builder .SetInsertPoint (VectorPH->getTerminator ());
3866-
3867- // Reductions do not have to start at zero. They can start with
3868- // any loop invariant values.
3869- VPValue *StartVPV = getStartValue ();
3870- RecurKind RK = RdxDesc.getRecurrenceKind ();
3871- if (RecurrenceDescriptor::isMinMaxRecurrenceKind (RK) ||
3872- RecurrenceDescriptor::isAnyOfRecurrenceKind (RK) ||
3873- RecurrenceDescriptor::isFindLastIVRecurrenceKind (RK)) {
3874- // [I|F]FindLastIV will use a sentinel value to initialize the reduction
3875- // phi or the resume value from the main vector loop when vectorizing the
3876- // epilogue loop. In the exit block, ComputeReductionResult will generate
3877- // checks to verify if the reduction result is the sentinel value. If the
3878- // result is the sentinel value, it will be corrected back to the start
3879- // value.
3880- // TODO: The sentinel value is not always necessary. When the start value is
3881- // a constant, and smaller than the start value of the induction variable,
3882- // the start value can be directly used to initialize the reduction phi.
3883- Phi->addIncoming (State.get (StartVPV, ScalarPHI), VectorPH);
3884- return ;
3885- }
3886-
3887- Value *Iden = getRecurrenceIdentity (RK, VecTy->getScalarType (),
3888- RdxDesc.getFastMathFlags ());
3889- unsigned CurrentPart = getUnrollPart (*this );
3890- Value *StartV = StartVPV->getLiveInIRValue ();
3891- if (!ScalarPHI) {
3892- if (CurrentPart == 0 ) {
3893- Iden = State.Builder .CreateVectorSplat (VF, Iden);
3894- Constant *Zero = State.Builder .getInt32 (0 );
3895- StartV = State.Builder .CreateInsertElement (Iden, StartV, Zero);
3896- } else {
3897- Iden = State.Builder .CreateVectorSplat (VF, Iden);
3898- }
3899- }
3900-
3901- Value *StartVal = (CurrentPart == 0 ) ? StartV : Iden;
3902- Phi->addIncoming (StartVal, VectorPH);
3882+ Phi->addIncoming (StartV, VectorPH);
39033883}
39043884
39053885#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
0 commit comments