Skip to content

Commit d5caa62

Browse files
committed
[LV] Use VPReductionRecipe for partial reductions
Partial reductions can easily be represented by the VPReductionRecipe class by setting their scale factor to a value greater than 1. This PR merges the two recipes and gives VPReductionRecipe a VFScaleFactor so that it can choose to generate the partial reduction intrinsic at execute time. This also leads to partial reductions naturally being included in VPMulAccumulateRecipe, which is nice for hiding the cost of the extends and the mul, but it does have the side effect of generating an unnecessary extend for chained partial reduction cases. I don't think this can be avoided cleanly, and the extra extend should be eliminated by DCE anyway.
1 parent 49c6235 commit d5caa62

File tree

15 files changed

+688
-1015
lines changed

15 files changed

+688
-1015
lines changed

llvm/include/llvm/Analysis/TargetTransformInfo.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,8 @@ class TargetTransformInfo {
223223
/// Get the kind of extension that an instruction represents.
224224
LLVM_ABI static PartialReductionExtendKind
225225
getPartialReductionExtendKind(Instruction *I);
226+
LLVM_ABI static PartialReductionExtendKind
227+
getPartialReductionExtendKind(Instruction::CastOps CastOpc);
226228

227229
/// Construct a TTI object using a type implementing the \c Concept
228230
/// API below.

llvm/lib/Analysis/TargetTransformInfo.cpp

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -994,11 +994,22 @@ InstructionCost TargetTransformInfo::getShuffleCost(
994994
}
995995

996996
TargetTransformInfo::PartialReductionExtendKind
997-
TargetTransformInfo::getPartialReductionExtendKind(Instruction *I) {
998-
if (isa<SExtInst>(I))
999-
return PR_SignExtend;
1000-
if (isa<ZExtInst>(I))
997+
TargetTransformInfo::getPartialReductionExtendKind(
998+
Instruction::CastOps CastOpc) {
999+
switch (CastOpc) {
1000+
case Instruction::CastOps::ZExt:
10011001
return PR_ZeroExtend;
1002+
case Instruction::CastOps::SExt:
1003+
return PR_SignExtend;
1004+
default:
1005+
return PR_None;
1006+
}
1007+
}
1008+
1009+
TargetTransformInfo::PartialReductionExtendKind
1010+
TargetTransformInfo::getPartialReductionExtendKind(Instruction *I) {
1011+
if (auto *Cast = dyn_cast<CastInst>(I))
1012+
return getPartialReductionExtendKind(Cast->getOpcode());
10021013
return PR_None;
10031014
}
10041015

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7050,7 +7050,8 @@ static bool planContainsAdditionalSimplifications(VPlan &Plan,
70507050
}
70517051
// The VPlan-based cost model is more accurate for partial reduction and
70527052
// comparing against the legacy cost isn't desirable.
7053-
if (isa<VPPartialReductionRecipe>(&R))
7053+
if (auto *VPR = dyn_cast<VPReductionRecipe>(&R);
7054+
VPR && VPR->isPartialReduction())
70547055
return true;
70557056

70567057
/// If a VPlan transform folded a recipe to one producing a single-scalar,
@@ -8278,11 +8279,15 @@ VPRecipeBase *VPRecipeBuilder::tryToCreateWidenRecipe(VPSingleDefRecipe *R,
82788279
Phi->getIncomingValueForBlock(OrigLoop->getLoopPreheader()));
82798280

82808281
// If the PHI is used by a partial reduction, set the scale factor.
8281-
unsigned ScaleFactor =
8282-
getScalingForReduction(RdxDesc.getLoopExitInstr()).value_or(1);
8283-
PhiRecipe = new VPReductionPHIRecipe(
8284-
Phi, RdxDesc, *StartV, CM.isInLoopReduction(Phi),
8285-
CM.useOrderedReductions(RdxDesc), ScaleFactor);
8282+
bool UseInLoopReduction = CM.isInLoopReduction(Phi);
8283+
bool UseOrderedReductions = CM.useOrderedReductions(RdxDesc);
8284+
auto ScaleFactor = ElementCount::getFixed(
8285+
(UseOrderedReductions || UseInLoopReduction)
8286+
? 0
8287+
: getScalingForReduction(RdxDesc.getLoopExitInstr()).value_or(1));
8288+
PhiRecipe = new VPReductionPHIRecipe(Phi, RdxDesc, *StartV,
8289+
CM.isInLoopReduction(Phi),
8290+
UseOrderedReductions, ScaleFactor);
82868291
} else {
82878292
// TODO: Currently fixed-order recurrences are modeled as chains of
82888293
// first-order recurrences. If there are no users of the intermediate
@@ -8315,7 +8320,8 @@ VPRecipeBase *VPRecipeBuilder::tryToCreateWidenRecipe(VPSingleDefRecipe *R,
83158320
return tryToWidenMemory(Instr, Operands, Range);
83168321

83178322
if (std::optional<unsigned> ScaleFactor = getScalingForReduction(Instr))
8318-
return tryToCreatePartialReduction(Instr, Operands, ScaleFactor.value());
8323+
return tryToCreatePartialReduction(
8324+
Instr, Operands, ElementCount::getFixed(ScaleFactor.value()));
83198325

83208326
if (!shouldWiden(Instr, Range))
83218327
return nullptr;
@@ -8338,15 +8344,16 @@ VPRecipeBase *VPRecipeBuilder::tryToCreateWidenRecipe(VPSingleDefRecipe *R,
83388344
VPRecipeBase *
83398345
VPRecipeBuilder::tryToCreatePartialReduction(Instruction *Reduction,
83408346
ArrayRef<VPValue *> Operands,
8341-
unsigned ScaleFactor) {
8347+
ElementCount ScaleFactor) {
83428348
assert(Operands.size() == 2 &&
83438349
"Unexpected number of operands for partial reduction");
83448350

83458351
VPValue *BinOp = Operands[0];
83468352
VPValue *Accumulator = Operands[1];
83478353
VPRecipeBase *BinOpRecipe = BinOp->getDefiningRecipe();
83488354
if (isa<VPReductionPHIRecipe>(BinOpRecipe) ||
8349-
isa<VPPartialReductionRecipe>(BinOpRecipe))
8355+
(isa<VPReductionRecipe>(BinOpRecipe) &&
8356+
cast<VPReductionRecipe>(BinOpRecipe)->isPartialReduction()))
83508357
std::swap(BinOp, Accumulator);
83518358

83528359
unsigned ReductionOpcode = Reduction->getOpcode();
@@ -8367,12 +8374,10 @@ VPRecipeBuilder::tryToCreatePartialReduction(Instruction *Reduction,
83678374
"Expected an ADD or SUB operation for predicated partial "
83688375
"reductions (because the neutral element in the mask is zero)!");
83698376
Cond = getBlockInMask(Builder.getInsertBlock());
8370-
VPValue *Zero =
8371-
Plan.getOrAddLiveIn(ConstantInt::get(Reduction->getType(), 0));
8372-
BinOp = Builder.createSelect(Cond, BinOp, Zero, Reduction->getDebugLoc());
83738377
}
8374-
return new VPPartialReductionRecipe(ReductionOpcode, Accumulator, BinOp, Cond,
8375-
ScaleFactor, Reduction);
8378+
8379+
return new VPReductionRecipe(RecurKind::Add, FastMathFlags(), Reduction,
8380+
Accumulator, BinOp, Cond, false, ScaleFactor);
83768381
}
83778382

83788383
void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF,
@@ -9139,9 +9144,12 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
91399144
FastMathFlags FMFs = isa<FPMathOperator>(CurrentLinkI)
91409145
? RdxDesc.getFastMathFlags()
91419146
: FastMathFlags();
9147+
bool UseOrderedReductions = CM.useOrderedReductions(RdxDesc);
9148+
ElementCount VFScaleFactor =
9149+
ElementCount::getFixed(!UseOrderedReductions);
91429150
auto *RedRecipe = new VPReductionRecipe(
91439151
Kind, FMFs, CurrentLinkI, PreviousLink, VecOp, CondOp,
9144-
CM.useOrderedReductions(RdxDesc), CurrentLinkI->getDebugLoc());
9152+
UseOrderedReductions, VFScaleFactor, CurrentLinkI->getDebugLoc());
91459153
// Append the recipe to the end of the VPBasicBlock because we need to
91469154
// ensure that it comes after all of its inputs, including CondOp.
91479155
// Delete CurrentLink as it will be invalid if its operand is replaced
@@ -9175,8 +9183,9 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
91759183
// Don't output selects for partial reductions because they have an output
91769184
// with fewer lanes than the VF. So the operands of the select would have
91779185
// different numbers of lanes. Partial reductions mask the input instead.
9186+
auto *RR = dyn_cast<VPReductionRecipe>(OrigExitingVPV->getDefiningRecipe());
91789187
if (!PhiR->isInLoop() && CM.foldTailByMasking() &&
9179-
!isa<VPPartialReductionRecipe>(OrigExitingVPV->getDefiningRecipe())) {
9188+
(!RR || !RR->isPartialReduction())) {
91809189
VPValue *Cond = RecipeBuilder.getBlockInMask(PhiR->getParent());
91819190
std::optional<FastMathFlags> FMFs =
91829191
PhiTy->isFloatingPointTy()

llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ class VPRecipeBuilder {
172172
/// along with binary operation and reduction phi operands.
173173
VPRecipeBase *tryToCreatePartialReduction(Instruction *Reduction,
174174
ArrayRef<VPValue *> Operands,
175-
unsigned ScaleFactor);
175+
ElementCount ScaleFactor);
176176

177177
/// Set the recipe created for given ingredient.
178178
void setRecipe(Instruction *I, VPRecipeBase *R) {

0 commit comments

Comments
 (0)