@@ -604,6 +604,20 @@ Value *VPInstruction::generate(VPTransformState &State) {
604604 return Builder.CreateVectorSplat (
605605 State.VF , State.get (getOperand (0 ), /* IsScalar*/ true ), " broadcast" );
606606 }
// ReductionStartVector: materialize the initial value fed into a reduction.
// Operands as used below:
//   0 - the start value, read as a scalar;
//   1 - the scalar that is splatted into every lane (held in `Iden`;
//       presumably the reduction's identity value -- TODO confirm at the
//       site that creates this VPInstruction);
//   2 - a live-in ConstantInt whose zero-extended value divides the VF.
607+ case VPInstruction::ReductionStartVector: {
// Scalar VF: no vector is built, the scalar start value is returned as is.
608+ if (State.VF .isScalar ())
609+ return State.get (getOperand (0 ), true );
// Install this recipe's fast-math flags on the builder while FMFG is alive.
// NOTE(review): the guard is expected to put the builder's previous flags back
// when this case's scope ends -- confirm against IRBuilderBase.
610+ IRBuilderBase::FastMathFlagGuard FMFG (Builder);
611+ Builder.setFastMathFlags (getFastMathFlags ());
612+ // If this start vector is scaled then it should produce a vector with fewer
613+ // elements than the VF.
// The element count used for the splat is State.VF with its coefficient
// divided by the constant carried in operand 2.
614+ ElementCount VF = State.VF .divideCoefficientBy (
615+ cast<ConstantInt>(getOperand (2 )->getLiveInIRValue ())->getZExtValue ());
// Fill all lanes with operand 1, then overwrite lane 0 (i32 index 0) with the
// start value, so the start value appears in exactly one lane of the result.
616+ auto *Iden = Builder.CreateVectorSplat (VF, State.get (getOperand (1 ), true ));
617+ Constant *Zero = Builder.getInt32 (0 );
618+ return Builder.CreateInsertElement (Iden, State.get (getOperand (0 ), true ),
619+ Zero);
620+ }
607621 case VPInstruction::ComputeAnyOfResult: {
608622 // FIXME: The cross-recipe dependency on VPReductionPHIRecipe is temporary
609623 // and will be removed by breaking up the recipe further.
@@ -882,6 +896,7 @@ bool VPInstruction::opcodeMayReadOrWriteFromMemory() const {
882896 case VPInstruction::PtrAdd:
883897 case VPInstruction::WideIVStep:
884898 case VPInstruction::StepVector:
899+ case VPInstruction::ReductionStartVector:
885900 return false ;
886901 default :
887902 return true ;
@@ -912,6 +927,7 @@ bool VPInstruction::onlyFirstLaneUsed(const VPValue *Op) const {
912927 case VPInstruction::CanonicalIVIncrementForPart:
913928 case VPInstruction::BranchOnCount:
914929 case VPInstruction::BranchOnCond:
930+ case VPInstruction::ReductionStartVector:
915931 return true ;
916932 case VPInstruction::PtrAdd:
917933 return Op == getOperand (0 ) || vputils::onlyFirstLaneUsed (this );
@@ -1017,6 +1033,9 @@ void VPInstruction::print(raw_ostream &O, const Twine &Indent,
10171033 case VPInstruction::FirstActiveLane:
10181034 O << " first-active-lane" ;
10191035 break ;
1036+ case VPInstruction::ReductionStartVector:
1037+ O << " reduction-start-vector" ;
1038+ break ;
10201039 default :
10211040 O << Instruction::getOpcodeName (getOpcode ());
10221041 }
@@ -1601,6 +1620,7 @@ bool VPIRFlags::flagsValidForOpcode(unsigned Opcode) const {
16011620 Opcode == Instruction::FDiv || Opcode == Instruction::FRem ||
16021621 Opcode == Instruction::FCmp || Opcode == Instruction::Select ||
16031622 Opcode == VPInstruction::WideIVStep ||
1623+ Opcode == VPInstruction::ReductionStartVector ||
16041624 Opcode == VPInstruction::ComputeReductionResult;
16051625 case OperationType::NonNegOp:
16061626 return Opcode == Instruction::ZExt;
@@ -3831,17 +3851,19 @@ void VPFirstOrderRecurrencePHIRecipe::print(raw_ostream &O, const Twine &Indent,
38313851#endif
38323852
38333853void VPReductionPHIRecipe::execute (VPTransformState &State) {
3834- // If this phi is fed by a scaled reduction then it should output a
3835- // vector with fewer elements than the VF .
3836- ElementCount VF = State. VF . divideCoefficientBy (VFScaleFactor );
3854+ // Reductions do not have to start at zero. They can start with
3855+ // any loop invariant values .
3856+ VPValue *StartVPV = getStartValue ( );
38373857
38383858 // In order to support recurrences we need to be able to vectorize Phi nodes.
38393859 // Phi nodes have cycles, so we need to vectorize them in two stages. This is
38403860 // stage #1: We create a new vector PHI node with no incoming edges. We'll use
38413861 // this value when we vectorize all of the instructions that use the PHI.
3842- auto *ScalarTy = State.TypeAnalysis .inferScalarType (this );
3862+ BasicBlock *VectorPH =
3863+ State.CFG .VPBB2IRBB .at (getParent ()->getCFGPredecessor (0 ));
38433864 bool ScalarPHI = State.VF .isScalar () || IsInLoop;
3844- Type *VecTy = ScalarPHI ? ScalarTy : VectorType::get (ScalarTy, VF);
3865+ Value *StartV = State.get (StartVPV, ScalarPHI);
3866+ Type *VecTy = StartV->getType ();
38453867
38463868 BasicBlock *HeaderBB = State.CFG .PrevBB ;
38473869 assert (State.CurrentParentLoop ->getHeader () == HeaderBB &&
@@ -3850,49 +3872,7 @@ void VPReductionPHIRecipe::execute(VPTransformState &State) {
38503872 Phi->insertBefore (HeaderBB->getFirstInsertionPt ());
38513873 State.set (this , Phi, IsInLoop);
38523874
3853- BasicBlock *VectorPH =
3854- State.CFG .VPBB2IRBB .at (getParent ()->getCFGPredecessor (0 ));
3855- // Create start and identity vector values for the reduction in the preheader.
3856- // TODO: Introduce recipes in VPlan preheader to create initial values.
3857- IRBuilderBase::InsertPointGuard IPBuilder (State.Builder );
3858- State.Builder .SetInsertPoint (VectorPH->getTerminator ());
3859-
3860- // Reductions do not have to start at zero. They can start with
3861- // any loop invariant values.
3862- VPValue *StartVPV = getStartValue ();
3863- RecurKind RK = RdxDesc.getRecurrenceKind ();
3864- if (RecurrenceDescriptor::isMinMaxRecurrenceKind (RK) ||
3865- RecurrenceDescriptor::isAnyOfRecurrenceKind (RK) ||
3866- RecurrenceDescriptor::isFindLastIVRecurrenceKind (RK)) {
3867- // [I|F]FindLastIV will use a sentinel value to initialize the reduction
3868- // phi or the resume value from the main vector loop when vectorizing the
3869- // epilogue loop. In the exit block, ComputeReductionResult will generate
3870- // checks to verify if the reduction result is the sentinel value. If the
3871- // result is the sentinel value, it will be corrected back to the start
3872- // value.
3873- // TODO: The sentinel value is not always necessary. When the start value is
3874- // a constant, and smaller than the start value of the induction variable,
3875- // the start value can be directly used to initialize the reduction phi.
3876- Phi->addIncoming (State.get (StartVPV, ScalarPHI), VectorPH);
3877- return ;
3878- }
3879-
3880- Value *Iden = getRecurrenceIdentity (RK, VecTy->getScalarType (),
3881- RdxDesc.getFastMathFlags ());
3882- unsigned CurrentPart = getUnrollPart (*this );
3883- Value *StartV = StartVPV->getLiveInIRValue ();
3884- if (!ScalarPHI) {
3885- if (CurrentPart == 0 ) {
3886- Iden = State.Builder .CreateVectorSplat (VF, Iden);
3887- Constant *Zero = State.Builder .getInt32 (0 );
3888- StartV = State.Builder .CreateInsertElement (Iden, StartV, Zero);
3889- } else {
3890- Iden = State.Builder .CreateVectorSplat (VF, Iden);
3891- }
3892- }
3893-
3894- Value *StartVal = (CurrentPart == 0 ) ? StartV : Iden;
3895- Phi->addIncoming (StartVal, VectorPH);
3875+ Phi->addIncoming (StartV, VectorPH);
38963876}
38973877
38983878#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
0 commit comments