Skip to content

Commit 396698f

Browse files
authored
Revert "Revert "[LoopVectorizer] Add support for chaining partial reductions …"
This reverts commit 0e21383.
1 parent 0e21383 commit 396698f

File tree

4 files changed

+1073
-25
lines changed

4 files changed

+1073
-25
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 42 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -8682,12 +8682,12 @@ VPReplicateRecipe *VPRecipeBuilder::handleReplication(Instruction *I,
86828682
/// are valid so recipes can be formed later.
86838683
void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
86848684
// Find all possible partial reductions.
8685-
SmallVector<std::pair<PartialReductionChain, unsigned>, 1>
8685+
SmallVector<std::pair<PartialReductionChain, unsigned>>
86868686
PartialReductionChains;
8687-
for (const auto &[Phi, RdxDesc] : Legal->getReductionVars())
8688-
if (std::optional<std::pair<PartialReductionChain, unsigned>> Pair =
8689-
getScaledReduction(Phi, RdxDesc, Range))
8690-
PartialReductionChains.push_back(*Pair);
8687+
for (const auto &[Phi, RdxDesc] : Legal->getReductionVars()) {
8688+
if (auto SR = getScaledReduction(Phi, RdxDesc.getLoopExitInstr(), Range))
8689+
PartialReductionChains.append(*SR);
8690+
}
86918691

86928692
// A partial reduction is invalid if any of its extends are used by
86938693
// something that isn't another partial reduction. This is because the
@@ -8715,26 +8715,44 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
87158715
}
87168716
}
87178717

8718-
std::optional<std::pair<PartialReductionChain, unsigned>>
8719-
VPRecipeBuilder::getScaledReduction(PHINode *PHI,
8720-
const RecurrenceDescriptor &Rdx,
8718+
std::optional<SmallVector<std::pair<PartialReductionChain, unsigned>>>
8719+
VPRecipeBuilder::getScaledReduction(Instruction *PHI, Instruction *RdxExitInstr,
87218720
VFRange &Range) {
8721+
8722+
if (!CM.TheLoop->contains(RdxExitInstr))
8723+
return std::nullopt;
8724+
87228725
// TODO: Allow scaling reductions when predicating. The select at
87238726
// the end of the loop chooses between the phi value and most recent
87248727
// reduction result, both of which have different VFs to the active lane
87258728
// mask when scaling.
8726-
if (CM.blockNeedsPredicationForAnyReason(Rdx.getLoopExitInstr()->getParent()))
8729+
if (CM.blockNeedsPredicationForAnyReason(RdxExitInstr->getParent()))
87278730
return std::nullopt;
87288731

8729-
auto *Update = dyn_cast<BinaryOperator>(Rdx.getLoopExitInstr());
8732+
auto *Update = dyn_cast<BinaryOperator>(RdxExitInstr);
87308733
if (!Update)
87318734
return std::nullopt;
87328735

87338736
Value *Op = Update->getOperand(0);
87348737
Value *PhiOp = Update->getOperand(1);
8735-
if (Op == PHI) {
8736-
Op = Update->getOperand(1);
8737-
PhiOp = Update->getOperand(0);
8738+
if (Op == PHI)
8739+
std::swap(Op, PhiOp);
8740+
8741+
SmallVector<std::pair<PartialReductionChain, unsigned>> Chains;
8742+
8743+
// Try and get a scaled reduction from the first non-phi operand.
8744+
// If one is found, we use the discovered reduction instruction in
8745+
// place of the accumulator for costing.
8746+
if (auto *OpInst = dyn_cast<Instruction>(Op)) {
8747+
if (auto SR0 = getScaledReduction(PHI, OpInst, Range)) {
8748+
Chains.append(*SR0);
8749+
PHI = SR0->rbegin()->first.Reduction;
8750+
8751+
Op = Update->getOperand(0);
8752+
PhiOp = Update->getOperand(1);
8753+
if (Op == PHI)
8754+
std::swap(Op, PhiOp);
8755+
}
87388756
}
87398757
if (PhiOp != PHI)
87408758
return std::nullopt;
@@ -8757,7 +8775,7 @@ VPRecipeBuilder::getScaledReduction(PHINode *PHI,
87578775
TTI::PartialReductionExtendKind OpBExtend =
87588776
TargetTransformInfo::getPartialReductionExtendKind(ExtB);
87598777

8760-
PartialReductionChain Chain(Rdx.getLoopExitInstr(), ExtA, ExtB, BinOp);
8778+
PartialReductionChain Chain(RdxExitInstr, ExtA, ExtB, BinOp);
87618779

87628780
unsigned TargetScaleFactor =
87638781
PHI->getType()->getPrimitiveSizeInBits().getKnownScalarFactor(
@@ -8772,9 +8790,9 @@ VPRecipeBuilder::getScaledReduction(PHINode *PHI,
87728790
return Cost.isValid();
87738791
},
87748792
Range))
8775-
return std::make_pair(Chain, TargetScaleFactor);
8793+
Chains.push_back(std::make_pair(Chain, TargetScaleFactor));
87768794

8777-
return std::nullopt;
8795+
return Chains;
87788796
}
87798797

87808798
VPRecipeBase *
@@ -8869,12 +8887,14 @@ VPRecipeBuilder::tryToCreatePartialReduction(Instruction *Reduction,
88698887
"Unexpected number of operands for partial reduction");
88708888

88718889
VPValue *BinOp = Operands[0];
8872-
VPValue *Phi = Operands[1];
8873-
if (isa<VPReductionPHIRecipe>(BinOp->getDefiningRecipe()))
8874-
std::swap(BinOp, Phi);
8875-
8876-
return new VPPartialReductionRecipe(Reduction->getOpcode(), BinOp, Phi,
8877-
Reduction);
8890+
VPValue *Accumulator = Operands[1];
8891+
VPRecipeBase *BinOpRecipe = BinOp->getDefiningRecipe();
8892+
if (isa<VPReductionPHIRecipe>(BinOpRecipe) ||
8893+
isa<VPPartialReductionRecipe>(BinOpRecipe))
8894+
std::swap(BinOp, Accumulator);
8895+
8896+
return new VPPartialReductionRecipe(Reduction->getOpcode(), BinOp,
8897+
Accumulator, Reduction);
88788898
}
88798899

88808900
void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF,

llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,8 +142,8 @@ class VPRecipeBuilder {
142142
/// Returns null if no scaled reduction was found, otherwise a pair with a
143143
/// struct containing reduction information and the scaling factor between the
144144
/// number of elements in the input and output.
145-
std::optional<std::pair<PartialReductionChain, unsigned>>
146-
getScaledReduction(PHINode *PHI, const RecurrenceDescriptor &Rdx,
145+
std::optional<SmallVector<std::pair<PartialReductionChain, unsigned>>>
146+
getScaledReduction(Instruction *PHI, Instruction *RdxExitInstr,
147147
VFRange &Range);
148148

149149
public:

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2455,7 +2455,10 @@ class VPPartialReductionRecipe : public VPSingleDefRecipe {
24552455
: VPSingleDefRecipe(VPDef::VPPartialReductionSC,
24562456
ArrayRef<VPValue *>({Op0, Op1}), ReductionInst),
24572457
Opcode(Opcode) {
2458-
assert(isa<VPReductionPHIRecipe>(getOperand(1)->getDefiningRecipe()) &&
2458+
[[maybe_unused]] auto *AccumulatorRecipe =
2459+
getOperand(1)->getDefiningRecipe();
2460+
assert((isa<VPReductionPHIRecipe>(AccumulatorRecipe) ||
2461+
isa<VPPartialReductionRecipe>(AccumulatorRecipe)) &&
24592462
"Unexpected operand order for partial reduction recipe");
24602463
}
24612464
~VPPartialReductionRecipe() override = default;

0 commit comments

Comments
 (0)