@@ -604,6 +604,20 @@ Value *VPInstruction::generate(VPTransformState &State) {
604604 return Builder.CreateVectorSplat (
605605 State.VF , State.get (getOperand (0 ), /* IsScalar*/ true ), " broadcast" );
606606 }
607+ case VPInstruction::ReductionStartVector: {
608+ if (State.VF .isScalar ())
609+ return State.get (getOperand (0 ), true );
610+ IRBuilderBase::FastMathFlagGuard FMFG (Builder);
611+ Builder.setFastMathFlags (getFastMathFlags ());
612+ // If this start vector is scaled then it should produce a vector with fewer
613+ // elements than the VF.
614+ ElementCount VF = State.VF .divideCoefficientBy (
615+ cast<ConstantInt>(getOperand (2 )->getLiveInIRValue ())->getZExtValue ());
616+ auto *Iden = Builder.CreateVectorSplat (VF, State.get (getOperand (1 ), true ));
617+ Constant *Zero = Builder.getInt32 (0 );
618+ return Builder.CreateInsertElement (Iden, State.get (getOperand (0 ), true ),
619+ Zero);
620+ }
607621 case VPInstruction::ComputeAnyOfResult: {
608622 // FIXME: The cross-recipe dependency on VPReductionPHIRecipe is temporary
609623 // and will be removed by breaking up the recipe further.
@@ -900,6 +914,7 @@ bool VPInstruction::opcodeMayReadOrWriteFromMemory() const {
900914 case VPInstruction::PtrAdd:
901915 case VPInstruction::WideIVStep:
902916 case VPInstruction::StepVector:
917+ case VPInstruction::ReductionStartVector:
903918 return false ;
904919 default :
905920 return true ;
@@ -930,6 +945,7 @@ bool VPInstruction::onlyFirstLaneUsed(const VPValue *Op) const {
930945 case VPInstruction::CanonicalIVIncrementForPart:
931946 case VPInstruction::BranchOnCount:
932947 case VPInstruction::BranchOnCond:
948+ case VPInstruction::ReductionStartVector:
933949 return true ;
934950 case VPInstruction::PtrAdd:
935951 return Op == getOperand (0 ) || vputils::onlyFirstLaneUsed (this );
@@ -1035,6 +1051,9 @@ void VPInstruction::print(raw_ostream &O, const Twine &Indent,
10351051 case VPInstruction::FirstActiveLane:
10361052 O << " first-active-lane" ;
10371053 break ;
1054+ case VPInstruction::ReductionStartVector:
1055+ O << " reduction-start-vector" ;
1056+ break ;
10381057 default :
10391058 O << Instruction::getOpcodeName (getOpcode ());
10401059 }
@@ -1618,6 +1637,7 @@ bool VPIRFlags::flagsValidForOpcode(unsigned Opcode) const {
16181637 Opcode == Instruction::FDiv || Opcode == Instruction::FRem ||
16191638 Opcode == Instruction::FCmp || Opcode == Instruction::Select ||
16201639 Opcode == VPInstruction::WideIVStep ||
1640+ Opcode == VPInstruction::ReductionStartVector ||
16211641 Opcode == VPInstruction::ComputeReductionResult;
16221642 case OperationType::NonNegOp:
16231643 return Opcode == Instruction::ZExt;
@@ -3848,17 +3868,19 @@ void VPFirstOrderRecurrencePHIRecipe::print(raw_ostream &O, const Twine &Indent,
38483868#endif
38493869
38503870void VPReductionPHIRecipe::execute (VPTransformState &State) {
3851- // If this phi is fed by a scaled reduction then it should output a
3852- // vector with fewer elements than the VF .
3853- ElementCount VF = State. VF . divideCoefficientBy (VFScaleFactor );
3871+ // Reductions do not have to start at zero. They can start with
3872+ // any loop invariant values .
3873+ VPValue *StartVPV = getStartValue ( );
38543874
38553875 // In order to support recurrences we need to be able to vectorize Phi nodes.
38563876 // Phi nodes have cycles, so we need to vectorize them in two stages. This is
38573877 // stage #1: We create a new vector PHI node with no incoming edges. We'll use
38583878 // this value when we vectorize all of the instructions that use the PHI.
3859- auto *ScalarTy = State.TypeAnalysis .inferScalarType (this );
3879+ BasicBlock *VectorPH =
3880+ State.CFG .VPBB2IRBB .at (getParent ()->getCFGPredecessor (0 ));
38603881 bool ScalarPHI = State.VF .isScalar () || IsInLoop;
3861- Type *VecTy = ScalarPHI ? ScalarTy : VectorType::get (ScalarTy, VF);
3882+ Value *StartV = State.get (StartVPV, ScalarPHI);
3883+ Type *VecTy = StartV->getType ();
38623884
38633885 BasicBlock *HeaderBB = State.CFG .PrevBB ;
38643886 assert (State.CurrentParentLoop ->getHeader () == HeaderBB &&
@@ -3867,49 +3889,7 @@ void VPReductionPHIRecipe::execute(VPTransformState &State) {
38673889 Phi->insertBefore (HeaderBB->getFirstInsertionPt ());
38683890 State.set (this , Phi, IsInLoop);
38693891
3870- BasicBlock *VectorPH =
3871- State.CFG .VPBB2IRBB .at (getParent ()->getCFGPredecessor (0 ));
3872- // Create start and identity vector values for the reduction in the preheader.
3873- // TODO: Introduce recipes in VPlan preheader to create initial values.
3874- IRBuilderBase::InsertPointGuard IPBuilder (State.Builder );
3875- State.Builder .SetInsertPoint (VectorPH->getTerminator ());
3876-
3877- // Reductions do not have to start at zero. They can start with
3878- // any loop invariant values.
3879- VPValue *StartVPV = getStartValue ();
3880- RecurKind RK = RdxDesc.getRecurrenceKind ();
3881- if (RecurrenceDescriptor::isMinMaxRecurrenceKind (RK) ||
3882- RecurrenceDescriptor::isAnyOfRecurrenceKind (RK) ||
3883- RecurrenceDescriptor::isFindLastIVRecurrenceKind (RK)) {
3884- // [I|F]FindLastIV will use a sentinel value to initialize the reduction
3885- // phi or the resume value from the main vector loop when vectorizing the
3886- // epilogue loop. In the exit block, ComputeReductionResult will generate
3887- // checks to verify if the reduction result is the sentinel value. If the
3888- // result is the sentinel value, it will be corrected back to the start
3889- // value.
3890- // TODO: The sentinel value is not always necessary. When the start value is
3891- // a constant, and smaller than the start value of the induction variable,
3892- // the start value can be directly used to initialize the reduction phi.
3893- Phi->addIncoming (State.get (StartVPV, ScalarPHI), VectorPH);
3894- return ;
3895- }
3896-
3897- Value *Iden = getRecurrenceIdentity (RK, VecTy->getScalarType (),
3898- RdxDesc.getFastMathFlags ());
3899- unsigned CurrentPart = getUnrollPart (*this );
3900- Value *StartV = StartVPV->getLiveInIRValue ();
3901- if (!ScalarPHI) {
3902- if (CurrentPart == 0 ) {
3903- Iden = State.Builder .CreateVectorSplat (VF, Iden);
3904- Constant *Zero = State.Builder .getInt32 (0 );
3905- StartV = State.Builder .CreateInsertElement (Iden, StartV, Zero);
3906- } else {
3907- Iden = State.Builder .CreateVectorSplat (VF, Iden);
3908- }
3909- }
3910-
3911- Value *StartVal = (CurrentPart == 0 ) ? StartV : Iden;
3912- Phi->addIncoming (StartVal, VectorPH);
3892+ Phi->addIncoming (StartV, VectorPH);
39133893}
39143894
39153895#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
0 commit comments