-
Notifications
You must be signed in to change notification settings - Fork 15k
[LV] Bundle partial reductions inside VPExpressionRecipe #147302
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
8abd55d
a623911
5f9f9a5
de9afb5
e5c610f
37cf515
fbea834
875dfa3
1e3b04c
fa09d29
f4fa801
1168c24
d2a50d9
deb2b06
b08d207
0785325
1956e8e
62490e4
026935f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1001,13 +1001,25 @@ InstructionCost TargetTransformInfo::getShuffleCost( | |
|
|
||
| TargetTransformInfo::PartialReductionExtendKind | ||
| TargetTransformInfo::getPartialReductionExtendKind(Instruction *I) { | ||
| if (isa<SExtInst>(I)) | ||
| return PR_SignExtend; | ||
| if (isa<ZExtInst>(I)) | ||
| return PR_ZeroExtend; | ||
| if (auto *Cast = dyn_cast<CastInst>(I)) | ||
| return getPartialReductionExtendKind(Cast->getOpcode()); | ||
| return PR_None; | ||
| } | ||
|
|
||
| TargetTransformInfo::PartialReductionExtendKind | ||
| TargetTransformInfo::getPartialReductionExtendKind( | ||
| Instruction::CastOps CastOpc) { | ||
| switch (CastOpc) { | ||
| case Instruction::CastOps::ZExt: | ||
| return PR_ZeroExtend; | ||
| case Instruction::CastOps::SExt: | ||
| return PR_SignExtend; | ||
| default: | ||
| return PR_None; | ||
| } | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think you might need llvm_unreachable() at the bottom, as I think some bots will complain otherwise. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Done, thanks. |
||
| llvm_unreachable("Unhandled cast opcode"); | ||
| } | ||
|
|
||
| TTI::CastContextHint | ||
| TargetTransformInfo::getCastContextHint(const Instruction *I) { | ||
| if (!I) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2712,7 +2712,8 @@ class LLVM_ABI_FOR_TEST VPReductionRecipe : public VPRecipeWithIRFlags { | |
|
|
||
| static inline bool classof(const VPRecipeBase *R) { | ||
| return R->getVPDefID() == VPRecipeBase::VPReductionSC || | ||
| R->getVPDefID() == VPRecipeBase::VPReductionEVLSC; | ||
| R->getVPDefID() == VPRecipeBase::VPReductionEVLSC || | ||
| R->getVPDefID() == VPRecipeBase::VPPartialReductionSC; | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I guess this was missed before and is only now tested? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yeah, that's right. |
||
| } | ||
|
|
||
| static inline bool classof(const VPUser *U) { | ||
|
|
@@ -2783,7 +2784,10 @@ class VPPartialReductionRecipe : public VPReductionRecipe { | |
| Opcode(Opcode), VFScaleFactor(ScaleFactor) { | ||
| [[maybe_unused]] auto *AccumulatorRecipe = | ||
| getChainOp()->getDefiningRecipe(); | ||
| assert((isa<VPReductionPHIRecipe>(AccumulatorRecipe) || | ||
| // When cloning as part of a VPExpressionRecipe the chain op could have | ||
| // been replaced by a temporary VPValue, so it doesn't have a defining | ||
| // recipe. | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. For this case, can we assert that it is a live-in without an underlying value? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why do we need to assert that? |
||
| assert((!AccumulatorRecipe || | ||
| isa<VPReductionPHIRecipe>(AccumulatorRecipe) || | ||
| isa<VPPartialReductionRecipe>(AccumulatorRecipe)) && | ||
| "Unexpected operand order for partial reduction recipe"); | ||
| } | ||
|
|
@@ -3093,6 +3097,11 @@ class VPExpressionRecipe : public VPSingleDefRecipe { | |
| /// removed before codegen. | ||
| void decompose(); | ||
|
|
||
| unsigned getVFScaleFactor() const { | ||
| auto *PR = dyn_cast<VPPartialReductionRecipe>(ExpressionRecipes.back()); | ||
| return PR ? PR->getVFScaleFactor() : 1; | ||
| } | ||
|
|
||
| /// Method for generating code, must not be called as this recipe is abstract. | ||
| void execute(VPTransformState &State) override { | ||
| llvm_unreachable("recipe must be removed before execute"); | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -168,6 +168,7 @@ bool VPRecipeBase::mayHaveSideEffects() const { | |
| return cast<VPWidenIntrinsicRecipe>(this)->mayHaveSideEffects(); | ||
| case VPBlendSC: | ||
| case VPReductionEVLSC: | ||
| case VPPartialReductionSC: | ||
| case VPReductionSC: | ||
| case VPScalarIVStepsSC: | ||
| case VPVectorPointerSC: | ||
|
|
@@ -300,14 +301,23 @@ InstructionCost | |
| VPPartialReductionRecipe::computeCost(ElementCount VF, | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Looking at the change you made to this function made me realise that this entire function can be simplified to: […] Because there is no need to re-analyse all the expressions again, all this information should have already been expressed by a VPExpression and its corresponding cost model. (Note that the cost-model for AArch64 will return […]) That being said, using VPExpressions for partial reductions doesn't currently support predicated vector loops, because that introduces another […]. It would also be good to add a FIXME to say that the complicated cost-model code here should be removed after fully migrating to VPExpressions. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Done, thank you. |
||
| VPCostContext &Ctx) const { | ||
| std::optional<unsigned> Opcode; | ||
| VPValue *Op = getOperand(0); | ||
| VPRecipeBase *OpR = Op->getDefiningRecipe(); | ||
|
|
||
| // If the partial reduction is predicated, a select will be operand 0 | ||
| if (match(getOperand(1), m_Select(m_VPValue(), m_VPValue(Op), m_VPValue()))) { | ||
| OpR = Op->getDefiningRecipe(); | ||
| VPValue *Op = getVecOp(); | ||
| uint64_t MulConst; | ||
| // If the partial reduction is predicated, a select will be operand 1. | ||
| // If it isn't predicated and the mul isn't operating on a constant, then it | ||
| // should have been turned into a VPExpressionRecipe. | ||
| // FIXME: Replace the entire function with this once all partial reduction | ||
| // variants are bundled into VPExpressionRecipe. | ||
| if (!match(Op, m_Select(m_VPValue(), m_VPValue(Op), m_VPValue())) && | ||
| !match(Op, m_Mul(m_VPValue(), m_ConstantInt(MulConst)))) { | ||
| auto *PhiType = Ctx.Types.inferScalarType(getChainOp()); | ||
| auto *InputType = Ctx.Types.inferScalarType(getVecOp()); | ||
| return Ctx.TTI.getPartialReductionCost(getOpcode(), InputType, InputType, | ||
| PhiType, VF, TTI::PR_None, | ||
| TTI::PR_None, {}, Ctx.CostKind); | ||
| } | ||
|
|
||
| VPRecipeBase *OpR = Op->getDefiningRecipe(); | ||
| Type *InputTypeA = nullptr, *InputTypeB = nullptr; | ||
| TTI::PartialReductionExtendKind ExtAType = TTI::PR_None, | ||
| ExtBType = TTI::PR_None; | ||
|
|
@@ -2856,11 +2866,19 @@ InstructionCost VPExpressionRecipe::computeCost(ElementCount VF, | |
| cast<VPReductionRecipe>(ExpressionRecipes.back())->getRecurrenceKind()); | ||
| switch (ExpressionType) { | ||
| case ExpressionTypes::ExtendedReduction: { | ||
| return Ctx.TTI.getExtendedReductionCost( | ||
| Opcode, | ||
| cast<VPWidenCastRecipe>(ExpressionRecipes.front())->getOpcode() == | ||
| Instruction::ZExt, | ||
| RedTy, SrcVecTy, std::nullopt, Ctx.CostKind); | ||
| unsigned Opcode = RecurrenceDescriptor::getOpcode( | ||
| cast<VPReductionRecipe>(ExpressionRecipes[1])->getRecurrenceKind()); | ||
| auto *ExtR = cast<VPWidenCastRecipe>(ExpressionRecipes[0]); | ||
| return isa<VPPartialReductionRecipe>(ExpressionRecipes.back()) | ||
| ? Ctx.TTI.getPartialReductionCost( | ||
| Opcode, Ctx.Types.inferScalarType(getOperand(0)), nullptr, | ||
| RedTy, VF, | ||
| TargetTransformInfo::getPartialReductionExtendKind( | ||
| ExtR->getOpcode()), | ||
| TargetTransformInfo::PR_None, std::nullopt, Ctx.CostKind) | ||
| : Ctx.TTI.getExtendedReductionCost( | ||
| Opcode, ExtR->getOpcode() == Instruction::ZExt, RedTy, | ||
| SrcVecTy, std::nullopt, Ctx.CostKind); | ||
| } | ||
| case ExpressionTypes::MulAccReduction: | ||
| return Ctx.TTI.getMulAccReductionCost(false, Opcode, RedTy, SrcVecTy, | ||
|
|
@@ -2871,6 +2889,19 @@ InstructionCost VPExpressionRecipe::computeCost(ElementCount VF, | |
| Opcode = Instruction::Sub; | ||
| [[fallthrough]]; | ||
| case ExpressionTypes::ExtMulAccReduction: { | ||
| if (isa<VPPartialReductionRecipe>(ExpressionRecipes.back())) { | ||
gbossu marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| auto *Ext0R = cast<VPWidenCastRecipe>(ExpressionRecipes[0]); | ||
| auto *Ext1R = cast<VPWidenCastRecipe>(ExpressionRecipes[1]); | ||
|
Comment on lines
+2893
to
+2894
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Does this work as expected for all tests on current main? I think at least in some cases one of the operands may be a constant live-in. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yeah, the matching function in […] |
||
| auto *Mul = cast<VPWidenRecipe>(ExpressionRecipes[2]); | ||
| return Ctx.TTI.getPartialReductionCost( | ||
| Opcode, Ctx.Types.inferScalarType(getOperand(0)), | ||
| Ctx.Types.inferScalarType(getOperand(1)), RedTy, VF, | ||
| TargetTransformInfo::getPartialReductionExtendKind( | ||
| Ext0R->getOpcode()), | ||
| TargetTransformInfo::getPartialReductionExtendKind( | ||
| Ext1R->getOpcode()), | ||
| Mul->getOpcode(), Ctx.CostKind); | ||
| } | ||
| return Ctx.TTI.getMulAccReductionCost( | ||
| cast<VPWidenCastRecipe>(ExpressionRecipes.front())->getOpcode() == | ||
| Instruction::ZExt, | ||
|
|
@@ -2910,12 +2941,13 @@ void VPExpressionRecipe::print(raw_ostream &O, const Twine &Indent, | |
| O << " = "; | ||
| auto *Red = cast<VPReductionRecipe>(ExpressionRecipes.back()); | ||
| unsigned Opcode = RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind()); | ||
| bool IsPartialReduction = isa<VPPartialReductionRecipe>(Red); | ||
|
|
||
| switch (ExpressionType) { | ||
| case ExpressionTypes::ExtendedReduction: { | ||
| getOperand(1)->printAsOperand(O, SlotTracker); | ||
| O << " +"; | ||
| O << " reduce." << Instruction::getOpcodeName(Opcode) << " ("; | ||
| O << " + " << (IsPartialReduction ? "partial." : "") << "reduce."; | ||
| O << Instruction::getOpcodeName(Opcode) << " ("; | ||
| getOperand(0)->printAsOperand(O, SlotTracker); | ||
| Red->printFlags(O); | ||
|
|
||
|
|
@@ -2931,8 +2963,8 @@ void VPExpressionRecipe::print(raw_ostream &O, const Twine &Indent, | |
| } | ||
| case ExpressionTypes::ExtNegatedMulAccReduction: { | ||
| getOperand(getNumOperands() - 1)->printAsOperand(O, SlotTracker); | ||
| O << " + reduce." | ||
| << Instruction::getOpcodeName( | ||
| O << " + " << (IsPartialReduction ? "partial." : "") << "reduce."; | ||
| O << Instruction::getOpcodeName( | ||
| RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind())) | ||
| << " (sub (0, mul"; | ||
| auto *Mul = cast<VPWidenRecipe>(ExpressionRecipes[2]); | ||
|
|
@@ -2956,9 +2988,8 @@ void VPExpressionRecipe::print(raw_ostream &O, const Twine &Indent, | |
| case ExpressionTypes::MulAccReduction: | ||
| case ExpressionTypes::ExtMulAccReduction: { | ||
| getOperand(getNumOperands() - 1)->printAsOperand(O, SlotTracker); | ||
| O << " + "; | ||
| O << "reduce." | ||
| << Instruction::getOpcodeName( | ||
| O << " + " << (IsPartialReduction ? "partial." : "") << "reduce."; | ||
| O << Instruction::getOpcodeName( | ||
| RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind())) | ||
| << " ("; | ||
| O << "mul"; | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
would be good to add a brief comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.