-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[VPlan] Only store RecurKind + FastMathFlags in VPReductionRecipe. NFCI #131300
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
e291008
8fcfcb4
eb5082d
c244065
1e00028
23bcb59
3293ebe
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2239,22 +2239,21 @@ class VPInterleaveRecipe : public VPRecipeBase { | |
| /// a vector operand into a scalar value, and adding the result to a chain. | ||
| /// The Operands are {ChainOp, VecOp, [Condition]}. | ||
| class VPReductionRecipe : public VPRecipeWithIRFlags { | ||
| /// The recurrence decriptor for the reduction in question. | ||
| const RecurrenceDescriptor &RdxDesc; | ||
| /// The recurrence kind for the reduction in question. | ||
| RecurKind RdxKind; | ||
| bool IsOrdered; | ||
| /// Whether the reduction is conditional. | ||
| bool IsConditional = false; | ||
|
|
||
| protected: | ||
| VPReductionRecipe(const unsigned char SC, const RecurrenceDescriptor &R, | ||
| Instruction *I, ArrayRef<VPValue *> Operands, | ||
| VPValue *CondOp, bool IsOrdered, DebugLoc DL) | ||
| : VPRecipeWithIRFlags(SC, Operands, | ||
| isa_and_nonnull<FPMathOperator>(I) | ||
| ? R.getFastMathFlags() | ||
| : FastMathFlags(), | ||
| DL), | ||
| RdxDesc(R), IsOrdered(IsOrdered) { | ||
| VPReductionRecipe(const unsigned char SC, RecurKind RdxKind, | ||
| FastMathFlags FMFs, Instruction *I, | ||
| ArrayRef<VPValue *> Operands, VPValue *CondOp, | ||
| bool IsOrdered, DebugLoc DL) | ||
| : VPRecipeWithIRFlags( | ||
| SC, Operands, | ||
| isa_and_nonnull<FPMathOperator>(I) ? FMFs : FastMathFlags(), DL), | ||
|
||
| RdxKind(RdxKind), IsOrdered(IsOrdered) { | ||
| if (CondOp) { | ||
| IsConditional = true; | ||
| addOperand(CondOp); | ||
|
|
@@ -2263,19 +2262,19 @@ class VPReductionRecipe : public VPRecipeWithIRFlags { | |
| } | ||
|
|
||
| public: | ||
| VPReductionRecipe(const RecurrenceDescriptor &R, Instruction *I, | ||
| VPReductionRecipe(RecurKind RdxKind, FastMathFlags FMFs, Instruction *I, | ||
| VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp, | ||
| bool IsOrdered, DebugLoc DL = {}) | ||
| : VPReductionRecipe(VPDef::VPReductionSC, R, I, | ||
| : VPReductionRecipe(VPRecipeBase::VPReductionSC, RdxKind, FMFs, I, | ||
| ArrayRef<VPValue *>({ChainOp, VecOp}), CondOp, | ||
| IsOrdered, DL) {} | ||
|
|
||
| ~VPReductionRecipe() override = default; | ||
|
|
||
| VPReductionRecipe *clone() override { | ||
| return new VPReductionRecipe(RdxDesc, getUnderlyingInstr(), getChainOp(), | ||
| getVecOp(), getCondOp(), IsOrdered, | ||
| getDebugLoc()); | ||
| return new VPReductionRecipe(RdxKind, getFastMathFlags(), | ||
| getUnderlyingInstr(), getChainOp(), getVecOp(), | ||
| getCondOp(), IsOrdered, getDebugLoc()); | ||
| } | ||
|
|
||
| static inline bool classof(const VPRecipeBase *R) { | ||
|
|
@@ -2301,9 +2300,11 @@ class VPReductionRecipe : public VPRecipeWithIRFlags { | |
| VPSlotTracker &SlotTracker) const override; | ||
| #endif | ||
|
|
||
| /// Return the recurrence decriptor for the in-loop reduction. | ||
| const RecurrenceDescriptor &getRecurrenceDescriptor() const { | ||
| return RdxDesc; | ||
| /// Return the recurrence kind for the in-loop reduction. | ||
| RecurKind getRecurrenceKind() const { return RdxKind; } | ||
| /// Return the opcode for the recurrence for the in-loop reduction. | ||
| unsigned getOpcode() const { | ||
| return RecurrenceDescriptor::getOpcode(RdxKind); | ||
|
||
| } | ||
| /// Return true if the in-loop reduction is ordered. | ||
| bool isOrdered() const { return IsOrdered; }; | ||
|
|
@@ -2328,7 +2329,8 @@ class VPReductionEVLRecipe : public VPReductionRecipe { | |
| VPReductionEVLRecipe(VPReductionRecipe &R, VPValue &EVL, VPValue *CondOp, | ||
| DebugLoc DL = {}) | ||
| : VPReductionRecipe( | ||
| VPDef::VPReductionEVLSC, R.getRecurrenceDescriptor(), | ||
| VPDef::VPReductionEVLSC, R.getRecurrenceKind(), | ||
| R.getFastMathFlags(), | ||
| cast_or_null<Instruction>(R.getUnderlyingValue()), | ||
| ArrayRef<VPValue *>({R.getChainOp(), R.getVecOp(), &EVL}), CondOp, | ||
| R.isOrdered(), DL) {} | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2285,7 +2285,7 @@ void VPBlendRecipe::print(raw_ostream &O, const Twine &Indent, | |
| void VPReductionRecipe::execute(VPTransformState &State) { | ||
| assert(!State.Lane && "Reduction being replicated."); | ||
| Value *PrevInChain = State.get(getChainOp(), /*IsScalar*/ true); | ||
| RecurKind Kind = RdxDesc.getRecurrenceKind(); | ||
| RecurKind Kind = getRecurrenceKind(); | ||
| assert(!RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind) && | ||
| "In-loop AnyOf reductions aren't currently supported"); | ||
| // Propagate the fast-math flags carried by the underlying instruction. | ||
|
|
@@ -2298,8 +2298,7 @@ void VPReductionRecipe::execute(VPTransformState &State) { | |
| VectorType *VecTy = dyn_cast<VectorType>(NewVecOp->getType()); | ||
| Type *ElementTy = VecTy ? VecTy->getElementType() : NewVecOp->getType(); | ||
|
|
||
| Value *Start = | ||
| getRecurrenceIdentity(Kind, ElementTy, RdxDesc.getFastMathFlags()); | ||
| Value *Start = getRecurrenceIdentity(Kind, ElementTy, getFastMathFlags()); | ||
| if (State.VF.isVector()) | ||
| Start = State.Builder.CreateVectorSplat(VecTy->getElementCount(), Start); | ||
|
|
||
|
|
@@ -2311,21 +2310,20 @@ void VPReductionRecipe::execute(VPTransformState &State) { | |
| if (IsOrdered) { | ||
| if (State.VF.isVector()) | ||
| NewRed = | ||
| createOrderedReduction(State.Builder, RdxDesc, NewVecOp, PrevInChain); | ||
| createOrderedReduction(State.Builder, Kind, NewVecOp, PrevInChain); | ||
| else | ||
| NewRed = State.Builder.CreateBinOp( | ||
| (Instruction::BinaryOps)RdxDesc.getOpcode(), PrevInChain, NewVecOp); | ||
| NewRed = State.Builder.CreateBinOp((Instruction::BinaryOps)getOpcode(), | ||
| PrevInChain, NewVecOp); | ||
| PrevInChain = NewRed; | ||
| NextInChain = NewRed; | ||
| } else { | ||
| PrevInChain = State.get(getChainOp(), /*IsScalar*/ true); | ||
| NewRed = createSimpleReduction(State.Builder, NewVecOp, Kind); | ||
| if (RecurrenceDescriptor::isMinMaxRecurrenceKind(Kind)) | ||
| NextInChain = createMinMaxOp(State.Builder, RdxDesc.getRecurrenceKind(), | ||
| NewRed, PrevInChain); | ||
| NextInChain = createMinMaxOp(State.Builder, Kind, NewRed, PrevInChain); | ||
| else | ||
| NextInChain = State.Builder.CreateBinOp( | ||
| (Instruction::BinaryOps)RdxDesc.getOpcode(), NewRed, PrevInChain); | ||
| (Instruction::BinaryOps)getOpcode(), NewRed, PrevInChain); | ||
| } | ||
| State.set(this, NextInChain, /*IsScalar*/ true); | ||
| } | ||
|
|
@@ -2336,10 +2334,9 @@ void VPReductionEVLRecipe::execute(VPTransformState &State) { | |
| auto &Builder = State.Builder; | ||
| // Propagate the fast-math flags carried by the underlying instruction. | ||
| IRBuilderBase::FastMathFlagGuard FMFGuard(Builder); | ||
| const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor(); | ||
| Builder.setFastMathFlags(getFastMathFlags()); | ||
|
|
||
| RecurKind Kind = RdxDesc.getRecurrenceKind(); | ||
| RecurKind Kind = getRecurrenceKind(); | ||
| Value *Prev = State.get(getChainOp(), /*IsScalar*/ true); | ||
| Value *VecOp = State.get(getVecOp()); | ||
| Value *EVL = State.get(getEVL(), VPLane(0)); | ||
|
|
@@ -2356,25 +2353,23 @@ void VPReductionEVLRecipe::execute(VPTransformState &State) { | |
|
|
||
| Value *NewRed; | ||
| if (isOrdered()) { | ||
| NewRed = createOrderedReduction(VBuilder, RdxDesc, VecOp, Prev); | ||
| NewRed = createOrderedReduction(VBuilder, Kind, VecOp, Prev); | ||
| } else { | ||
| NewRed = createSimpleReduction(VBuilder, VecOp, RdxDesc); | ||
| NewRed = createSimpleReduction(VBuilder, VecOp, Kind, getFastMathFlags()); | ||
| if (RecurrenceDescriptor::isMinMaxRecurrenceKind(Kind)) | ||
| NewRed = createMinMaxOp(Builder, Kind, NewRed, Prev); | ||
| else | ||
| NewRed = Builder.CreateBinOp((Instruction::BinaryOps)RdxDesc.getOpcode(), | ||
| NewRed, Prev); | ||
| NewRed = Builder.CreateBinOp((Instruction::BinaryOps)getOpcode(), NewRed, | ||
| Prev); | ||
| } | ||
| State.set(this, NewRed, /*IsScalar*/ true); | ||
| } | ||
|
|
||
| InstructionCost VPReductionRecipe::computeCost(ElementCount VF, | ||
| VPCostContext &Ctx) const { | ||
| RecurKind RdxKind = RdxDesc.getRecurrenceKind(); | ||
| RecurKind RdxKind = getRecurrenceKind(); | ||
| Type *ElementTy = Ctx.Types.inferScalarType(this); | ||
| auto *VectorTy = cast<VectorType>(toVectorTy(ElementTy, VF)); | ||
| unsigned Opcode = RdxDesc.getOpcode(); | ||
| FastMathFlags FMFs = getFastMathFlags(); | ||
|
|
||
| // TODO: Support any-of and in-loop reductions. | ||
| assert( | ||
|
|
@@ -2386,20 +2381,17 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF, | |
| ForceTargetInstructionCost.getNumOccurrences() > 0) && | ||
| "In-loop reduction not implemented in VPlan-based cost model currently."); | ||
|
|
||
| assert(ElementTy->getTypeID() == RdxDesc.getRecurrenceType()->getTypeID() && | ||
| "Inferred type and recurrence type mismatch."); | ||
|
|
||
| // Cost = Reduction cost + BinOp cost | ||
| InstructionCost Cost = | ||
| Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, Ctx.CostKind); | ||
| Ctx.TTI.getArithmeticInstrCost(getOpcode(), ElementTy, Ctx.CostKind); | ||
| if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) { | ||
| Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind); | ||
| return Cost + | ||
| Ctx.TTI.getMinMaxReductionCost(Id, VectorTy, FMFs, Ctx.CostKind); | ||
| return Cost + Ctx.TTI.getMinMaxReductionCost( | ||
| Id, VectorTy, getFastMathFlags(), Ctx.CostKind); | ||
| } | ||
|
|
||
| return Cost + Ctx.TTI.getArithmeticReductionCost(Opcode, VectorTy, FMFs, | ||
| Ctx.CostKind); | ||
| return Cost + Ctx.TTI.getArithmeticReductionCost( | ||
| getOpcode(), VectorTy, getFastMathFlags(), Ctx.CostKind); | ||
| } | ||
|
|
||
| #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) | ||
|
|
@@ -2411,28 +2403,24 @@ void VPReductionRecipe::print(raw_ostream &O, const Twine &Indent, | |
| getChainOp()->printAsOperand(O, SlotTracker); | ||
| O << " +"; | ||
| printFlags(O); | ||
| O << " reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " ("; | ||
| O << " reduce." << Instruction::getOpcodeName(getOpcode()) << " ("; | ||
| getVecOp()->printAsOperand(O, SlotTracker); | ||
| if (isConditional()) { | ||
| O << ", "; | ||
| getCondOp()->printAsOperand(O, SlotTracker); | ||
| } | ||
| O << ")"; | ||
| if (RdxDesc.IntermediateStore) | ||
| O << " (with final reduction value stored in invariant address sank " | ||
| "outside of loop)"; | ||
|
Comment on lines
-2436
to
-2438
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK to drop this, as the store is sunk explicitly. |
||
| } | ||
|
|
||
| void VPReductionEVLRecipe::print(raw_ostream &O, const Twine &Indent, | ||
| VPSlotTracker &SlotTracker) const { | ||
| const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor(); | ||
| O << Indent << "REDUCE "; | ||
| printAsOperand(O, SlotTracker); | ||
| O << " = "; | ||
| getChainOp()->printAsOperand(O, SlotTracker); | ||
| O << " +"; | ||
| printFlags(O); | ||
| O << " vp.reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " ("; | ||
| O << " vp.reduce." << Instruction::getOpcodeName(getOpcode()) << " ("; | ||
| getVecOp()->printAsOperand(O, SlotTracker); | ||
| O << ", "; | ||
| getEVL()->printAsOperand(O, SlotTracker); | ||
|
|
@@ -2441,9 +2429,6 @@ void VPReductionEVLRecipe::print(raw_ostream &O, const Twine &Indent, | |
| getCondOp()->printAsOperand(O, SlotTracker); | ||
| } | ||
| O << ")"; | ||
| if (RdxDesc.IntermediateStore) | ||
| O << " (with final reduction value stored in invariant address sank " | ||
| "outside of loop)"; | ||
| } | ||
| #endif | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.