@@ -2300,7 +2300,7 @@ void VPBlendRecipe::print(raw_ostream &O, const Twine &Indent,
23002300void VPReductionRecipe::execute (VPTransformState &State) {
23012301 assert (!State.Lane && " Reduction being replicated." );
23022302 Value *PrevInChain = State.get (getChainOp (), /* IsScalar*/ true );
2303- RecurKind Kind = RdxDesc. getRecurrenceKind ();
2303+ RecurKind Kind = getRecurrenceKind ();
23042304 assert (!RecurrenceDescriptor::isAnyOfRecurrenceKind (Kind) &&
23052305 " In-loop AnyOf reductions aren't currently supported" );
23062306 // Propagate the fast-math flags carried by the underlying instruction.
@@ -2313,8 +2313,7 @@ void VPReductionRecipe::execute(VPTransformState &State) {
23132313 VectorType *VecTy = dyn_cast<VectorType>(NewVecOp->getType ());
23142314 Type *ElementTy = VecTy ? VecTy->getElementType () : NewVecOp->getType ();
23152315
2316- Value *Start =
2317- getRecurrenceIdentity (Kind, ElementTy, RdxDesc.getFastMathFlags ());
2316+ Value *Start = getRecurrenceIdentity (Kind, ElementTy, getFastMathFlags ());
23182317 if (State.VF .isVector ())
23192318 Start = State.Builder .CreateVectorSplat (VecTy->getElementCount (), Start);
23202319
@@ -2329,18 +2328,19 @@ void VPReductionRecipe::execute(VPTransformState &State) {
23292328 createOrderedReduction (State.Builder , Kind, NewVecOp, PrevInChain);
23302329 else
23312330 NewRed = State.Builder .CreateBinOp (
2332- (Instruction::BinaryOps)RdxDesc.getOpcode (), PrevInChain, NewVecOp);
2331+ (Instruction::BinaryOps)RecurrenceDescriptor::getOpcode (Kind),
2332+ PrevInChain, NewVecOp);
23332333 PrevInChain = NewRed;
23342334 NextInChain = NewRed;
23352335 } else {
23362336 PrevInChain = State.get (getChainOp (), /* IsScalar*/ true );
23372337 NewRed = createSimpleReduction (State.Builder , NewVecOp, Kind);
23382338 if (RecurrenceDescriptor::isMinMaxRecurrenceKind (Kind))
2339- NextInChain = createMinMaxOp (State.Builder , RdxDesc.getRecurrenceKind (),
2340- NewRed, PrevInChain);
2339+ NextInChain = createMinMaxOp (State.Builder , Kind, NewRed, PrevInChain);
23412340 else
23422341 NextInChain = State.Builder .CreateBinOp (
2343- (Instruction::BinaryOps)RdxDesc.getOpcode (), NewRed, PrevInChain);
2342+ (Instruction::BinaryOps)RecurrenceDescriptor::getOpcode (Kind), NewRed,
2343+ PrevInChain);
23442344 }
23452345 State.set (this , NextInChain, /* IsScalar*/ true );
23462346}
@@ -2351,10 +2351,9 @@ void VPReductionEVLRecipe::execute(VPTransformState &State) {
23512351 auto &Builder = State.Builder ;
23522352 // Propagate the fast-math flags carried by the underlying instruction.
23532353 IRBuilderBase::FastMathFlagGuard FMFGuard (Builder);
2354- const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor ();
23552354 Builder.setFastMathFlags (getFastMathFlags ());
23562355
2357- RecurKind Kind = RdxDesc. getRecurrenceKind ();
2356+ RecurKind Kind = getRecurrenceKind ();
23582357 Value *Prev = State.get (getChainOp (), /* IsScalar*/ true );
23592358 Value *VecOp = State.get (getVecOp ());
23602359 Value *EVL = State.get (getEVL (), VPLane (0 ));
@@ -2377,18 +2376,19 @@ void VPReductionEVLRecipe::execute(VPTransformState &State) {
23772376 if (RecurrenceDescriptor::isMinMaxRecurrenceKind (Kind))
23782377 NewRed = createMinMaxOp (Builder, Kind, NewRed, Prev);
23792378 else
2380- NewRed = Builder.CreateBinOp ((Instruction::BinaryOps)RdxDesc.getOpcode (),
2381- NewRed, Prev);
2379+ NewRed = Builder.CreateBinOp (
2380+ (Instruction::BinaryOps)RecurrenceDescriptor::getOpcode (Kind), NewRed,
2381+ Prev);
23822382 }
23832383 State.set (this , NewRed, /* IsScalar*/ true );
23842384}
23852385
23862386InstructionCost VPReductionRecipe::computeCost (ElementCount VF,
23872387 VPCostContext &Ctx) const {
2388- RecurKind RdxKind = RdxDesc. getRecurrenceKind ();
2388+ RecurKind RdxKind = getRecurrenceKind ();
23892389 Type *ElementTy = Ctx.Types .inferScalarType (this );
23902390 auto *VectorTy = cast<VectorType>(toVectorTy (ElementTy, VF));
2391- unsigned Opcode = RdxDesc. getOpcode ();
2391+ unsigned Opcode = RecurrenceDescriptor:: getOpcode (RdxKind );
23922392 FastMathFlags FMFs = getFastMathFlags ();
23932393
23942394 // TODO: Support any-of and in-loop reductions.
@@ -2401,9 +2401,6 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
24012401 ForceTargetInstructionCost.getNumOccurrences () > 0 ) &&
24022402 " In-loop reduction not implemented in VPlan-based cost model currently." );
24032403
2404- assert (ElementTy->getTypeID () == RdxDesc.getRecurrenceType ()->getTypeID () &&
2405- " Inferred type and recurrence type mismatch." );
2406-
24072404 // Cost = Reduction cost + BinOp cost
24082405 InstructionCost Cost =
24092406 Ctx.TTI .getArithmeticInstrCost (Opcode, ElementTy, Ctx.CostKind );
@@ -2426,28 +2423,30 @@ void VPReductionRecipe::print(raw_ostream &O, const Twine &Indent,
24262423 getChainOp ()->printAsOperand (O, SlotTracker);
24272424 O << " +" ;
24282425 printFlags (O);
2429- O << " reduce." << Instruction::getOpcodeName (RdxDesc.getOpcode ()) << " (" ;
2426+ O << " reduce."
2427+ << Instruction::getOpcodeName (
2428+ RecurrenceDescriptor::getOpcode (getRecurrenceKind ()))
2429+ << " (" ;
24302430 getVecOp ()->printAsOperand (O, SlotTracker);
24312431 if (isConditional ()) {
24322432 O << " , " ;
24332433 getCondOp ()->printAsOperand (O, SlotTracker);
24342434 }
24352435 O << " )" ;
2436- if (RdxDesc.IntermediateStore )
2437- O << " (with final reduction value stored in invariant address sank "
2438- " outside of loop)" ;
24392436}
24402437
24412438void VPReductionEVLRecipe::print (raw_ostream &O, const Twine &Indent,
24422439 VPSlotTracker &SlotTracker) const {
2443- const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor ();
24442440 O << Indent << " REDUCE " ;
24452441 printAsOperand (O, SlotTracker);
24462442 O << " = " ;
24472443 getChainOp ()->printAsOperand (O, SlotTracker);
24482444 O << " +" ;
24492445 printFlags (O);
2450- O << " vp.reduce." << Instruction::getOpcodeName (RdxDesc.getOpcode ()) << " (" ;
2446+ O << " vp.reduce."
2447+ << Instruction::getOpcodeName (
2448+ RecurrenceDescriptor::getOpcode (getRecurrenceKind ()))
2449+ << " (" ;
24512450 getVecOp ()->printAsOperand (O, SlotTracker);
24522451 O << " , " ;
24532452 getEVL ()->printAsOperand (O, SlotTracker);
@@ -2456,9 +2455,6 @@ void VPReductionEVLRecipe::print(raw_ostream &O, const Twine &Indent,
24562455 getCondOp ()->printAsOperand (O, SlotTracker);
24572456 }
24582457 O << " )" ;
2459- if (RdxDesc.IntermediateStore )
2460- O << " (with final reduction value stored in invariant address sank "
2461- " outside of loop)" ;
24622458}
24632459#endif
24642460
0 commit comments