Skip to content

Commit d2cbd8f

Browse files
committed
[VPlan] Detect and create partial reductions in VPlan. (NFCI)
As a first step, move the existing partial reduction detection logic to VPlan, trying to preserve the existing code structure & behavior as closely as possible. With this, partial reductions are detected and created together in a single step. This allows a follow-up to form partial reductions and bundle them up together when profitable.
1 parent 909c9aa commit d2cbd8f

File tree

6 files changed

+413
-293
lines changed

6 files changed

+413
-293
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 13 additions & 232 deletions
Original file line numberDiff line numberDiff line change
@@ -7985,178 +7985,6 @@ VPReplicateRecipe *VPRecipeBuilder::handleReplication(VPInstruction *VPI,
79857985
return Recipe;
79867986
}
79877987

7988-
/// Find all possible partial reductions in the loop and track all of those that
7989-
/// are valid so recipes can be formed later.
7990-
void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
7991-
// Find all possible partial reductions.
7992-
SmallVector<std::pair<PartialReductionChain, unsigned>>
7993-
PartialReductionChains;
7994-
for (const auto &[Phi, RdxDesc] : Legal->getReductionVars()) {
7995-
getScaledReductions(Phi, RdxDesc.getLoopExitInstr(), Range,
7996-
PartialReductionChains);
7997-
}
7998-
7999-
// A partial reduction is invalid if any of its extends are used by
8000-
// something that isn't another partial reduction. This is because the
8001-
// extends are intended to be lowered along with the reduction itself.
8002-
8003-
// Build up a set of partial reduction ops for efficient use checking.
8004-
SmallPtrSet<User *, 4> PartialReductionOps;
8005-
for (const auto &[PartialRdx, _] : PartialReductionChains)
8006-
PartialReductionOps.insert(PartialRdx.ExtendUser);
8007-
8008-
auto ExtendIsOnlyUsedByPartialReductions =
8009-
[&PartialReductionOps](Instruction *Extend) {
8010-
return all_of(Extend->users(), [&](const User *U) {
8011-
return PartialReductionOps.contains(U);
8012-
});
8013-
};
8014-
8015-
// Check if each use of a chain's two extends is a partial reduction
8016-
// and only add those that don't have non-partial reduction users.
8017-
for (auto Pair : PartialReductionChains) {
8018-
PartialReductionChain Chain = Pair.first;
8019-
if (ExtendIsOnlyUsedByPartialReductions(Chain.ExtendA) &&
8020-
(!Chain.ExtendB || ExtendIsOnlyUsedByPartialReductions(Chain.ExtendB)))
8021-
ScaledReductionMap.try_emplace(Chain.Reduction, Pair.second);
8022-
}
8023-
8024-
// Check that all partial reductions in a chain are only used by other
8025-
// partial reductions with the same scale factor. Otherwise we end up creating
8026-
// users of scaled reductions where the types of the other operands don't
8027-
// match.
8028-
for (const auto &[Chain, Scale] : PartialReductionChains) {
8029-
auto AllUsersPartialRdx = [ScaleVal = Scale, this](const User *U) {
8030-
auto *UI = cast<Instruction>(U);
8031-
if (isa<PHINode>(UI) && UI->getParent() == OrigLoop->getHeader()) {
8032-
return all_of(UI->users(), [ScaleVal, this](const User *U) {
8033-
auto *UI = cast<Instruction>(U);
8034-
return ScaledReductionMap.lookup_or(UI, 0) == ScaleVal;
8035-
});
8036-
}
8037-
return ScaledReductionMap.lookup_or(UI, 0) == ScaleVal ||
8038-
!OrigLoop->contains(UI->getParent());
8039-
};
8040-
if (!all_of(Chain.Reduction->users(), AllUsersPartialRdx))
8041-
ScaledReductionMap.erase(Chain.Reduction);
8042-
}
8043-
}
8044-
8045-
bool VPRecipeBuilder::getScaledReductions(
8046-
Instruction *PHI, Instruction *RdxExitInstr, VFRange &Range,
8047-
SmallVectorImpl<std::pair<PartialReductionChain, unsigned>> &Chains) {
8048-
if (!CM.TheLoop->contains(RdxExitInstr))
8049-
return false;
8050-
8051-
auto *Update = dyn_cast<BinaryOperator>(RdxExitInstr);
8052-
if (!Update)
8053-
return false;
8054-
8055-
Value *Op = Update->getOperand(0);
8056-
Value *PhiOp = Update->getOperand(1);
8057-
if (Op == PHI)
8058-
std::swap(Op, PhiOp);
8059-
8060-
// Try and get a scaled reduction from the first non-phi operand.
8061-
// If one is found, we use the discovered reduction instruction in
8062-
// place of the accumulator for costing.
8063-
if (auto *OpInst = dyn_cast<Instruction>(Op)) {
8064-
if (getScaledReductions(PHI, OpInst, Range, Chains)) {
8065-
PHI = Chains.rbegin()->first.Reduction;
8066-
8067-
Op = Update->getOperand(0);
8068-
PhiOp = Update->getOperand(1);
8069-
if (Op == PHI)
8070-
std::swap(Op, PhiOp);
8071-
}
8072-
}
8073-
if (PhiOp != PHI)
8074-
return false;
8075-
8076-
using namespace llvm::PatternMatch;
8077-
8078-
// If the update is a binary operator, check both of its operands to see if
8079-
// they are extends. Otherwise, see if the update comes directly from an
8080-
// extend.
8081-
Instruction *Exts[2] = {nullptr};
8082-
BinaryOperator *ExtendUser = dyn_cast<BinaryOperator>(Op);
8083-
std::optional<unsigned> BinOpc;
8084-
Type *ExtOpTypes[2] = {nullptr};
8085-
TTI::PartialReductionExtendKind ExtKinds[2] = {TTI::PR_None};
8086-
8087-
auto CollectExtInfo = [this, &Exts, &ExtOpTypes,
8088-
&ExtKinds](SmallVectorImpl<Value *> &Ops) -> bool {
8089-
for (const auto &[I, OpI] : enumerate(Ops)) {
8090-
const APInt *C;
8091-
if (I > 0 && match(OpI, m_APInt(C)) &&
8092-
canConstantBeExtended(C, ExtOpTypes[0], ExtKinds[0])) {
8093-
ExtOpTypes[I] = ExtOpTypes[0];
8094-
ExtKinds[I] = ExtKinds[0];
8095-
continue;
8096-
}
8097-
Value *ExtOp;
8098-
if (!match(OpI, m_ZExtOrSExt(m_Value(ExtOp))))
8099-
return false;
8100-
Exts[I] = cast<Instruction>(OpI);
8101-
8102-
// TODO: We should be able to support live-ins.
8103-
if (!CM.TheLoop->contains(Exts[I]))
8104-
return false;
8105-
8106-
ExtOpTypes[I] = ExtOp->getType();
8107-
ExtKinds[I] = TTI::getPartialReductionExtendKind(Exts[I]);
8108-
}
8109-
return true;
8110-
};
8111-
8112-
if (ExtendUser) {
8113-
if (!ExtendUser->hasOneUse())
8114-
return false;
8115-
8116-
// Use the side-effect of match to replace BinOp only if the pattern is
8117-
// matched, we don't care at this point whether it actually matched.
8118-
match(ExtendUser, m_Neg(m_BinOp(ExtendUser)));
8119-
8120-
SmallVector<Value *> Ops(ExtendUser->operands());
8121-
if (!CollectExtInfo(Ops))
8122-
return false;
8123-
8124-
BinOpc = std::make_optional(ExtendUser->getOpcode());
8125-
} else if (match(Update, m_Add(m_Value(), m_Value()))) {
8126-
// We already know the operands for Update are Op and PhiOp.
8127-
SmallVector<Value *> Ops({Op});
8128-
if (!CollectExtInfo(Ops))
8129-
return false;
8130-
8131-
ExtendUser = Update;
8132-
BinOpc = std::nullopt;
8133-
} else
8134-
return false;
8135-
8136-
PartialReductionChain Chain(RdxExitInstr, Exts[0], Exts[1], ExtendUser);
8137-
8138-
TypeSize PHISize = PHI->getType()->getPrimitiveSizeInBits();
8139-
TypeSize ASize = ExtOpTypes[0]->getPrimitiveSizeInBits();
8140-
if (!PHISize.hasKnownScalarFactor(ASize))
8141-
return false;
8142-
unsigned TargetScaleFactor = PHISize.getKnownScalarFactor(ASize);
8143-
8144-
if (LoopVectorizationPlanner::getDecisionAndClampRange(
8145-
[&](ElementCount VF) {
8146-
InstructionCost Cost = TTI->getPartialReductionCost(
8147-
Update->getOpcode(), ExtOpTypes[0], ExtOpTypes[1],
8148-
PHI->getType(), VF, ExtKinds[0], ExtKinds[1], BinOpc,
8149-
CM.CostKind);
8150-
return Cost.isValid();
8151-
},
8152-
Range)) {
8153-
Chains.emplace_back(Chain, TargetScaleFactor);
8154-
return true;
8155-
}
8156-
8157-
return false;
8158-
}
8159-
81607988
VPRecipeBase *VPRecipeBuilder::tryToCreateWidenRecipe(VPSingleDefRecipe *R,
81617989
VFRange &Range) {
81627990
// First, check for specific widening recipes that deal with inductions, Phi
@@ -8183,12 +8011,11 @@ VPRecipeBase *VPRecipeBuilder::tryToCreateWidenRecipe(VPSingleDefRecipe *R,
81838011
assert(RdxDesc.getRecurrenceStartValue() ==
81848012
Phi->getIncomingValueForBlock(OrigLoop->getLoopPreheader()));
81858013

8186-
// If the PHI is used by a partial reduction, set the scale factor.
8187-
unsigned ScaleFactor =
8188-
getScalingForReduction(RdxDesc.getLoopExitInstr()).value_or(1);
8189-
PhiRecipe = new VPReductionPHIRecipe(
8190-
Phi, RdxDesc.getRecurrenceKind(), *StartV, CM.isInLoopReduction(Phi),
8191-
CM.useOrderedReductions(RdxDesc), ScaleFactor);
8014+
// Always create with scale factor 1. Partial reductions will be created
8015+
// later in createPartialReductions transform.
8016+
PhiRecipe = new VPReductionPHIRecipe(Phi, RdxDesc.getRecurrenceKind(),
8017+
*StartV, CM.isInLoopReduction(Phi),
8018+
CM.useOrderedReductions(RdxDesc));
81928019
} else {
81938020
// TODO: Currently fixed-order recurrences are modeled as chains of
81948021
// first-order recurrences. If there are no users of the intermediate
@@ -8224,9 +8051,6 @@ VPRecipeBase *VPRecipeBuilder::tryToCreateWidenRecipe(VPSingleDefRecipe *R,
82248051
VPI->getOpcode() == Instruction::Store)
82258052
return tryToWidenMemory(VPI, Range);
82268053

8227-
if (std::optional<unsigned> ScaleFactor = getScalingForReduction(Instr))
8228-
return tryToCreatePartialReduction(VPI, ScaleFactor.value());
8229-
82308054
if (!shouldWiden(Instr, Range))
82318055
return nullptr;
82328056

@@ -8247,41 +8071,6 @@ VPRecipeBase *VPRecipeBuilder::tryToCreateWidenRecipe(VPSingleDefRecipe *R,
82478071
return tryToWiden(VPI);
82488072
}
82498073

8250-
VPRecipeBase *
8251-
VPRecipeBuilder::tryToCreatePartialReduction(VPInstruction *Reduction,
8252-
unsigned ScaleFactor) {
8253-
assert(Reduction->getNumOperands() == 2 &&
8254-
"Unexpected number of operands for partial reduction");
8255-
8256-
VPValue *BinOp = Reduction->getOperand(0);
8257-
VPValue *Accumulator = Reduction->getOperand(1);
8258-
if (isa<VPReductionPHIRecipe>(BinOp) || isa<VPPartialReductionRecipe>(BinOp))
8259-
std::swap(BinOp, Accumulator);
8260-
8261-
assert(ScaleFactor ==
8262-
vputils::getVFScaleFactor(Accumulator->getDefiningRecipe()) &&
8263-
"all accumulators in chain must have same scale factor");
8264-
8265-
unsigned ReductionOpcode = Reduction->getOpcode();
8266-
auto *ReductionI = Reduction->getUnderlyingInstr();
8267-
if (ReductionOpcode == Instruction::Sub) {
8268-
auto *const Zero = ConstantInt::get(ReductionI->getType(), 0);
8269-
SmallVector<VPValue *, 2> Ops;
8270-
Ops.push_back(Plan.getOrAddLiveIn(Zero));
8271-
Ops.push_back(BinOp);
8272-
BinOp = new VPWidenRecipe(*ReductionI, Ops, VPIRMetadata(),
8273-
ReductionI->getDebugLoc());
8274-
Builder.insert(BinOp->getDefiningRecipe());
8275-
ReductionOpcode = Instruction::Add;
8276-
}
8277-
8278-
VPValue *Cond = nullptr;
8279-
if (CM.blockNeedsPredicationForAnyReason(ReductionI->getParent()))
8280-
Cond = getBlockInMask(Builder.getInsertBlock());
8281-
return new VPPartialReductionRecipe(ReductionOpcode, Accumulator, BinOp, Cond,
8282-
ScaleFactor, ReductionI);
8283-
}
8284-
82858074
void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF,
82868075
ElementCount MaxVF) {
82878076
if (ElementCount::isKnownGT(MinVF, MaxVF))
@@ -8408,11 +8197,8 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(
84088197
// Construct wide recipes and apply predication for original scalar
84098198
// VPInstructions in the loop.
84108199
// ---------------------------------------------------------------------------
8411-
VPRecipeBuilder RecipeBuilder(*Plan, OrigLoop, TLI, &TTI, Legal, CM, PSE,
8412-
Builder, BlockMaskCache);
8413-
// TODO: Handle partial reductions with EVL tail folding.
8414-
if (!CM.foldTailWithEVL())
8415-
RecipeBuilder.collectScaledReductions(Range);
8200+
VPRecipeBuilder RecipeBuilder(*Plan, OrigLoop, TLI, Legal, CM, PSE, Builder,
8201+
BlockMaskCache);
84168202

84178203
// Scan the body of the loop in a topological order to visit each basic block
84188204
// after having visited its predecessor basic blocks.
@@ -8521,11 +8307,10 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(
85218307
*Plan))
85228308
return nullptr;
85238309

8524-
// Transform recipes to abstract recipes if it is legal and beneficial and
8525-
// clamp the range for better cost estimation.
8526-
// TODO: Enable following transform when the EVL-version of extended-reduction
8527-
// and mulacc-reduction are implemented.
85288310
if (!CM.foldTailWithEVL()) {
8311+
// Create partial reduction recipes for scaled reductions.
8312+
VPlanTransforms::createPartialReductions(*Plan, Range, &TTI, CM.CostKind);
8313+
85298314
VPCostContext CostCtx(CM.TTI, *CM.TLI, *Plan, CM, CM.CostKind,
85308315
*CM.PSE.getSE(), OrigLoop);
85318316
VPlanTransforms::runPass(VPlanTransforms::convertToAbstractRecipes, *Plan,
@@ -8606,8 +8391,8 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlan(VFRange &Range) {
86068391
// Collect mapping of IR header phis to header phi recipes, to be used in
86078392
// addScalarResumePhis.
86088393
DenseMap<VPBasicBlock *, VPValue *> BlockMaskCache;
8609-
VPRecipeBuilder RecipeBuilder(*Plan, OrigLoop, TLI, &TTI, Legal, CM, PSE,
8610-
Builder, BlockMaskCache);
8394+
VPRecipeBuilder RecipeBuilder(*Plan, OrigLoop, TLI, Legal, CM, PSE, Builder,
8395+
BlockMaskCache);
86118396
for (auto &R : Plan->getVectorLoopRegion()->getEntryBasicBlock()->phis()) {
86128397
if (isa<VPCanonicalIVPHIRecipe>(&R))
86138398
continue;
@@ -8957,11 +8742,7 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
89578742
VPBuilder PHBuilder(Plan->getVectorPreheader());
89588743
VPValue *Iden = Plan->getOrAddLiveIn(
89598744
getRecurrenceIdentity(RK, PhiTy, RdxDesc.getFastMathFlags()));
8960-
// If the PHI is used by a partial reduction, set the scale factor.
8961-
unsigned ScaleFactor =
8962-
RecipeBuilder.getScalingForReduction(RdxDesc.getLoopExitInstr())
8963-
.value_or(1);
8964-
auto *ScaleFactorVPV = Plan->getConstantInt(32, ScaleFactor);
8745+
auto *ScaleFactorVPV = Plan->getConstantInt(32, 1);
89658746
VPValue *StartV = PHBuilder.createNaryOp(
89668747
VPInstruction::ReductionStartVector,
89678748
{PhiR->getStartValue(), Iden, ScaleFactorVPV},

0 commit comments

Comments
 (0)