@@ -666,8 +666,17 @@ Value *VPInstruction::generate(VPTransformState &State) {
666666 RecurrenceDescriptor::isAnyOfRecurrenceKind (RK) ||
667667 RecurrenceDescriptor::isFindLastIVRecurrenceKind (RK)) &&
668668 !PhiR->isInLoop ()) {
669- ReducedPartRdx =
670- createReduction (Builder, RdxDesc, ReducedPartRdx, OrigPhi);
669+ IRBuilderBase::FastMathFlagGuard FMFG (Builder);
670+ Builder.setFastMathFlags (RdxDesc.getFastMathFlags ());
671+ if (RecurrenceDescriptor::isAnyOfRecurrenceKind (RK))
672+ ReducedPartRdx =
673+ createAnyOfReduction (Builder, ReducedPartRdx, RdxDesc, OrigPhi);
674+ else if (RecurrenceDescriptor::isFindLastIVRecurrenceKind (RK))
675+ ReducedPartRdx =
676+ createFindLastIVReduction (Builder, ReducedPartRdx, RdxDesc);
677+ else
678+ ReducedPartRdx = createSimpleReduction (Builder, ReducedPartRdx, RK);
679+
671680 // If the reduction can be performed in a smaller type, we need to extend
672681 // the reduction to the wider type before we branch to the original loop.
673682 if (PhiTy != RdxDesc.getRecurrenceType ())
@@ -2263,21 +2272,21 @@ void VPBlendRecipe::print(raw_ostream &O, const Twine &Indent,
22632272void VPReductionRecipe::execute (VPTransformState &State) {
22642273 assert (!State.Lane && " Reduction being replicated." );
22652274 Value *PrevInChain = State.get (getChainOp (), /* IsScalar*/ true );
2266- RecurKind Kind = RdxDesc. getRecurrenceKind ();
2275+ RecurKind Kind = getRecurrenceKind ();
22672276 assert (!RecurrenceDescriptor::isAnyOfRecurrenceKind (Kind) &&
22682277 " In-loop AnyOf reductions aren't currently supported" );
2278+
22692279 // Propagate the fast-math flags carried by the underlying instruction.
22702280 IRBuilderBase::FastMathFlagGuard FMFGuard (State.Builder );
2271- State.Builder .setFastMathFlags (RdxDesc. getFastMathFlags ());
2281+ State.Builder .setFastMathFlags (getFastMathFlags ());
22722282 State.setDebugLocFrom (getDebugLoc ());
22732283 Value *NewVecOp = State.get (getVecOp ());
22742284 if (VPValue *Cond = getCondOp ()) {
22752285 Value *NewCond = State.get (Cond, State.VF .isScalar ());
22762286 VectorType *VecTy = dyn_cast<VectorType>(NewVecOp->getType ());
22772287 Type *ElementTy = VecTy ? VecTy->getElementType () : NewVecOp->getType ();
22782288
2279- Value *Start =
2280- getRecurrenceIdentity (Kind, ElementTy, RdxDesc.getFastMathFlags ());
2289+ Value *Start = getRecurrenceIdentity (Kind, ElementTy, getFastMathFlags ());
22812290 if (State.VF .isVector ())
22822291 Start = State.Builder .CreateVectorSplat (VecTy->getElementCount (), Start);
22832292
@@ -2289,21 +2298,20 @@ void VPReductionRecipe::execute(VPTransformState &State) {
22892298 if (IsOrdered) {
22902299 if (State.VF .isVector ())
22912300 NewRed =
2292- createOrderedReduction (State.Builder , RdxDesc , NewVecOp, PrevInChain);
2301+ createOrderedReduction (State.Builder , Kind , NewVecOp, PrevInChain);
22932302 else
2294- NewRed = State.Builder .CreateBinOp (
2295- (Instruction::BinaryOps)RdxDesc. getOpcode (), PrevInChain, NewVecOp);
2303+ NewRed = State.Builder .CreateBinOp ((Instruction::BinaryOps) getOpcode (),
2304+ PrevInChain, NewVecOp);
22962305 PrevInChain = NewRed;
22972306 NextInChain = NewRed;
22982307 } else {
22992308 PrevInChain = State.get (getChainOp (), /* IsScalar*/ true );
2300- NewRed = createReduction (State.Builder , RdxDesc, NewVecOp );
2309+ NewRed = createSimpleReduction (State.Builder , NewVecOp, Kind );
23012310 if (RecurrenceDescriptor::isMinMaxRecurrenceKind (Kind))
2302- NextInChain = createMinMaxOp (State.Builder , RdxDesc.getRecurrenceKind (),
2303- NewRed, PrevInChain);
2311+ NextInChain = createMinMaxOp (State.Builder , Kind, NewRed, PrevInChain);
23042312 else
23052313 NextInChain = State.Builder .CreateBinOp (
2306- (Instruction::BinaryOps)RdxDesc. getOpcode (), NewRed, PrevInChain);
2314+ (Instruction::BinaryOps)getOpcode (), NewRed, PrevInChain);
23072315 }
23082316 State.set (this , NextInChain, /* IsScalar*/ true );
23092317}
@@ -2314,10 +2322,9 @@ void VPReductionEVLRecipe::execute(VPTransformState &State) {
23142322 auto &Builder = State.Builder ;
23152323 // Propagate the fast-math flags carried by the underlying instruction.
23162324 IRBuilderBase::FastMathFlagGuard FMFGuard (Builder);
2317- const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor ();
2318- Builder.setFastMathFlags (RdxDesc.getFastMathFlags ());
2325+ Builder.setFastMathFlags (getFastMathFlags ());
23192326
2320- RecurKind Kind = RdxDesc. getRecurrenceKind ();
2327+ RecurKind Kind = getRecurrenceKind ();
23212328 Value *Prev = State.get (getChainOp (), /* IsScalar*/ true );
23222329 Value *VecOp = State.get (getVecOp ());
23232330 Value *EVL = State.get (getEVL (), VPLane (0 ));
@@ -2334,24 +2341,23 @@ void VPReductionEVLRecipe::execute(VPTransformState &State) {
23342341
23352342 Value *NewRed;
23362343 if (isOrdered ()) {
2337- NewRed = createOrderedReduction (VBuilder, RdxDesc , VecOp, Prev);
2344+ NewRed = createOrderedReduction (VBuilder, Kind , VecOp, Prev);
23382345 } else {
2339- NewRed = createSimpleReduction (VBuilder, VecOp, RdxDesc );
2346+ NewRed = createSimpleReduction (VBuilder, VecOp, Kind, getFastMathFlags () );
23402347 if (RecurrenceDescriptor::isMinMaxRecurrenceKind (Kind))
23412348 NewRed = createMinMaxOp (Builder, Kind, NewRed, Prev);
23422349 else
2343- NewRed = Builder.CreateBinOp ((Instruction::BinaryOps)RdxDesc. getOpcode (),
2344- NewRed, Prev);
2350+ NewRed = Builder.CreateBinOp ((Instruction::BinaryOps)getOpcode (), NewRed ,
2351+ Prev);
23452352 }
23462353 State.set (this , NewRed, /* IsScalar*/ true );
23472354}
23482355
23492356InstructionCost VPReductionRecipe::computeCost (ElementCount VF,
23502357 VPCostContext &Ctx) const {
2351- RecurKind RdxKind = RdxDesc. getRecurrenceKind ();
2358+ RecurKind RdxKind = getRecurrenceKind ();
23522359 Type *ElementTy = Ctx.Types .inferScalarType (this );
23532360 auto *VectorTy = cast<VectorType>(toVectorTy (ElementTy, VF));
2354- unsigned Opcode = RdxDesc.getOpcode ();
23552361
23562362 // TODO: Support any-of and in-loop reductions.
23572363 assert (
@@ -2363,20 +2369,17 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
23632369 ForceTargetInstructionCost.getNumOccurrences () > 0 ) &&
23642370 " In-loop reduction not implemented in VPlan-based cost model currently." );
23652371
2366- assert (ElementTy->getTypeID () == RdxDesc.getRecurrenceType ()->getTypeID () &&
2367- " Inferred type and recurrence type mismatch." );
2368-
23692372 // Cost = Reduction cost + BinOp cost
23702373 InstructionCost Cost =
2371- Ctx.TTI .getArithmeticInstrCost (Opcode , ElementTy, Ctx.CostKind );
2374+ Ctx.TTI .getArithmeticInstrCost (getOpcode () , ElementTy, Ctx.CostKind );
23722375 if (RecurrenceDescriptor::isMinMaxRecurrenceKind (RdxKind)) {
23732376 Intrinsic::ID Id = getMinMaxReductionIntrinsicOp (RdxKind);
23742377 return Cost + Ctx.TTI .getMinMaxReductionCost (
2375- Id, VectorTy, RdxDesc. getFastMathFlags (), Ctx.CostKind );
2378+ Id, VectorTy, getFastMathFlags (), Ctx.CostKind );
23762379 }
23772380
23782381 return Cost + Ctx.TTI .getArithmeticReductionCost (
2379- Opcode , VectorTy, RdxDesc. getFastMathFlags (), Ctx.CostKind );
2382+ getOpcode () , VectorTy, getFastMathFlags (), Ctx.CostKind );
23802383}
23812384
23822385#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
@@ -2389,29 +2392,31 @@ void VPReductionRecipe::print(raw_ostream &O, const Twine &Indent,
23892392 O << " +" ;
23902393 if (isa<FPMathOperator>(getUnderlyingInstr ()))
23912394 O << getUnderlyingInstr ()->getFastMathFlags ();
2392- O << " reduce." << Instruction::getOpcodeName (RdxDesc.getOpcode ()) << " (" ;
2395+ O << " reduce."
2396+ << Instruction::getOpcodeName (
2397+ RecurrenceDescriptor::getOpcode (getRecurrenceKind ()))
2398+ << " (" ;
23932399 getVecOp ()->printAsOperand (O, SlotTracker);
23942400 if (isConditional ()) {
23952401 O << " , " ;
23962402 getCondOp ()->printAsOperand (O, SlotTracker);
23972403 }
23982404 O << " )" ;
2399- if (RdxDesc.IntermediateStore )
2400- O << " (with final reduction value stored in invariant address sank "
2401- " outside of loop)" ;
24022405}
24032406
24042407void VPReductionEVLRecipe::print (raw_ostream &O, const Twine &Indent,
24052408 VPSlotTracker &SlotTracker) const {
2406- const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor ();
2409+ RecurKind Kind = getRecurrenceKind ();
24072410 O << Indent << " REDUCE " ;
24082411 printAsOperand (O, SlotTracker);
24092412 O << " = " ;
24102413 getChainOp ()->printAsOperand (O, SlotTracker);
24112414 O << " +" ;
24122415 if (isa<FPMathOperator>(getUnderlyingInstr ()))
24132416 O << getUnderlyingInstr ()->getFastMathFlags ();
2414- O << " vp.reduce." << Instruction::getOpcodeName (RdxDesc.getOpcode ()) << " (" ;
2417+ O << " vp.reduce."
2418+ << Instruction::getOpcodeName (RecurrenceDescriptor::getOpcode (Kind))
2419+ << " (" ;
24152420 getVecOp ()->printAsOperand (O, SlotTracker);
24162421 O << " , " ;
24172422 getEVL ()->printAsOperand (O, SlotTracker);
@@ -2420,9 +2425,6 @@ void VPReductionEVLRecipe::print(raw_ostream &O, const Twine &Indent,
24202425 getCondOp ()->printAsOperand (O, SlotTracker);
24212426 }
24222427 O << " )" ;
2423- if (RdxDesc.IntermediateStore )
2424- O << " (with final reduction value stored in invariant address sank "
2425- " outside of loop)" ;
24262428}
24272429#endif
24282430
0 commit comments