Skip to content

Commit 0bfa171

Browse files
authored
[LV] Create in-loop sub reductions (llvm#147026)
This PR allows the loop vectorizer to handle in-loop sub reductions by forming a normal in-loop add reduction with a negated input. Stacked PRs: (1) llvm#147026 (this PR), (2) llvm#147255, (3) llvm#147302, (4) llvm#147513.
1 parent 111cdaa commit 0bfa171

File tree

10 files changed

+1428
-13
lines changed

10 files changed

+1428
-13
lines changed

llvm/include/llvm/Analysis/IVDescriptors.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ enum class RecurKind {
3535
// clang-format off
3636
None, ///< Not a recurrence.
3737
Add, ///< Sum of integers.
38+
Sub, ///< Subtraction of integers.
39+
AddChainWithSubs, ///< A chain of adds and subs.
3840
Mul, ///< Product of integers.
3941
Or, ///< Bitwise or logical OR of integers.
4042
And, ///< Bitwise or logical AND of integers.

llvm/lib/Analysis/IVDescriptors.cpp

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ bool RecurrenceDescriptor::isIntegerRecurrenceKind(RecurKind Kind) {
4040
switch (Kind) {
4141
default:
4242
break;
43+
case RecurKind::AddChainWithSubs:
44+
case RecurKind::Sub:
4345
case RecurKind::Add:
4446
case RecurKind::Mul:
4547
case RecurKind::Or:
@@ -897,8 +899,11 @@ RecurrenceDescriptor::InstDesc RecurrenceDescriptor::isRecurrenceInstr(
897899
case Instruction::PHI:
898900
return InstDesc(I, Prev.getRecKind(), Prev.getExactFPMathInst());
899901
case Instruction::Sub:
902+
return InstDesc(
903+
Kind == RecurKind::Sub || Kind == RecurKind::AddChainWithSubs, I);
900904
case Instruction::Add:
901-
return InstDesc(Kind == RecurKind::Add, I);
905+
return InstDesc(
906+
Kind == RecurKind::Add || Kind == RecurKind::AddChainWithSubs, I);
902907
case Instruction::Mul:
903908
return InstDesc(Kind == RecurKind::Mul, I);
904909
case Instruction::And:
@@ -917,7 +922,8 @@ RecurrenceDescriptor::InstDesc RecurrenceDescriptor::isRecurrenceInstr(
917922
I->hasAllowReassoc() ? nullptr : I);
918923
case Instruction::Select:
919924
if (Kind == RecurKind::FAdd || Kind == RecurKind::FMul ||
920-
Kind == RecurKind::Add || Kind == RecurKind::Mul)
925+
Kind == RecurKind::Add || Kind == RecurKind::Mul ||
926+
Kind == RecurKind::Sub || Kind == RecurKind::AddChainWithSubs)
921927
return isConditionalRdxPattern(I);
922928
if (isFindIVRecurrenceKind(Kind) && SE)
923929
return isFindIVPattern(Kind, L, OrigPhi, I, *SE);
@@ -1003,6 +1009,17 @@ bool RecurrenceDescriptor::isReductionPHI(PHINode *Phi, Loop *TheLoop,
10031009
LLVM_DEBUG(dbgs() << "Found an ADD reduction PHI." << *Phi << "\n");
10041010
return true;
10051011
}
1012+
if (AddReductionVar(Phi, RecurKind::Sub, TheLoop, FMF, RedDes, DB, AC, DT,
1013+
SE)) {
1014+
LLVM_DEBUG(dbgs() << "Found a SUB reduction PHI." << *Phi << "\n");
1015+
return true;
1016+
}
1017+
if (AddReductionVar(Phi, RecurKind::AddChainWithSubs, TheLoop, FMF, RedDes,
1018+
DB, AC, DT, SE)) {
1019+
LLVM_DEBUG(dbgs() << "Found a chained ADD-SUB reduction PHI." << *Phi
1020+
<< "\n");
1021+
return true;
1022+
}
10061023
if (AddReductionVar(Phi, RecurKind::Mul, TheLoop, FMF, RedDes, DB, AC, DT,
10071024
SE)) {
10081025
LLVM_DEBUG(dbgs() << "Found a MUL reduction PHI." << *Phi << "\n");
@@ -1201,6 +1218,9 @@ bool RecurrenceDescriptor::isFixedOrderRecurrence(PHINode *Phi, Loop *TheLoop,
12011218

12021219
unsigned RecurrenceDescriptor::getOpcode(RecurKind Kind) {
12031220
switch (Kind) {
1221+
case RecurKind::Sub:
1222+
return Instruction::Sub;
1223+
case RecurKind::AddChainWithSubs:
12041224
case RecurKind::Add:
12051225
return Instruction::Add;
12061226
case RecurKind::Mul:
@@ -1288,6 +1308,10 @@ RecurrenceDescriptor::getReductionOpChain(PHINode *Phi, Loop *L) const {
12881308
if (isFMulAddIntrinsic(Cur))
12891309
return true;
12901310

1311+
if (Cur->getOpcode() == Instruction::Sub &&
1312+
Kind == RecurKind::AddChainWithSubs)
1313+
return true;
1314+
12911315
return Cur->getOpcode() == getOpcode();
12921316
};
12931317

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5190,6 +5190,8 @@ bool AArch64TTIImpl::isLegalToVectorizeReduction(
51905190
return false;
51915191

51925192
switch (RdxDesc.getRecurrenceKind()) {
5193+
case RecurKind::Sub:
5194+
case RecurKind::AddChainWithSubs:
51935195
case RecurKind::Add:
51945196
case RecurKind::FAdd:
51955197
case RecurKind::And:

llvm/lib/Transforms/Utils/LoopUtils.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -914,6 +914,8 @@ constexpr Intrinsic::ID llvm::getReductionIntrinsicID(RecurKind RK) {
914914
switch (RK) {
915915
default:
916916
llvm_unreachable("Unexpected recurrence kind");
917+
case RecurKind::AddChainWithSubs:
918+
case RecurKind::Sub:
917919
case RecurKind::Add:
918920
return Intrinsic::vector_reduce_add;
919921
case RecurKind::Mul:
@@ -1301,6 +1303,8 @@ Value *llvm::createSimpleReduction(IRBuilderBase &Builder, Value *Src,
13011303
Builder.getFastMathFlags());
13021304
};
13031305
switch (RdxKind) {
1306+
case RecurKind::AddChainWithSubs:
1307+
case RecurKind::Sub:
13041308
case RecurKind::Add:
13051309
case RecurKind::Mul:
13061310
case RecurKind::And:

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9067,6 +9067,16 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
90679067
CurrentLinkI->getFastMathFlags());
90689068
LinkVPBB->insert(FMulRecipe, CurrentLink->getIterator());
90699069
VecOp = FMulRecipe;
9070+
} else if (PhiR->isInLoop() && Kind == RecurKind::AddChainWithSubs &&
9071+
CurrentLinkI->getOpcode() == Instruction::Sub) {
9072+
Type *PhiTy = PhiR->getUnderlyingValue()->getType();
9073+
auto *Zero = Plan->getOrAddLiveIn(ConstantInt::get(PhiTy, 0));
9074+
VPWidenRecipe *Sub = new VPWidenRecipe(
9075+
Instruction::Sub, {Zero, CurrentLink->getOperand(1)}, {},
9076+
VPIRMetadata(), CurrentLinkI->getDebugLoc());
9077+
Sub->setUnderlyingValue(CurrentLinkI);
9078+
LinkVPBB->insert(Sub, CurrentLink->getIterator());
9079+
VecOp = Sub;
90709080
} else {
90719081
if (RecurrenceDescriptor::isMinMaxRecurrenceKind(Kind)) {
90729082
if (isa<VPWidenRecipe>(CurrentLink)) {

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24713,6 +24713,8 @@ class HorizontalReduction {
2471324713
case RecurKind::FMinimum:
2471424714
// res = vv
2471524715
break;
24716+
case RecurKind::Sub:
24717+
case RecurKind::AddChainWithSubs:
2471624718
case RecurKind::Mul:
2471724719
case RecurKind::FMul:
2471824720
case RecurKind::FMulAdd:
@@ -24852,6 +24854,8 @@ class HorizontalReduction {
2485224854
case RecurKind::FMinimum:
2485324855
// res = vv
2485424856
return VectorizedValue;
24857+
case RecurKind::Sub:
24858+
case RecurKind::AddChainWithSubs:
2485524859
case RecurKind::Mul:
2485624860
case RecurKind::FMul:
2485724861
case RecurKind::FMulAdd:
@@ -24956,6 +24960,8 @@ class HorizontalReduction {
2495624960
auto *Scale = ConstantVector::get(Vals);
2495724961
return Builder.CreateFMul(VectorizedValue, Scale);
2495824962
}
24963+
case RecurKind::Sub:
24964+
case RecurKind::AddChainWithSubs:
2495924965
case RecurKind::Mul:
2496024966
case RecurKind::FMul:
2496124967
case RecurKind::FMulAdd:

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -812,10 +812,18 @@ Value *VPInstruction::generate(VPTransformState &State) {
812812
Value *RdxPart = RdxParts[Part];
813813
if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RK))
814814
ReducedPartRdx = createMinMaxOp(Builder, RK, ReducedPartRdx, RdxPart);
815-
else
816-
ReducedPartRdx = Builder.CreateBinOp(
817-
(Instruction::BinaryOps)RecurrenceDescriptor::getOpcode(RK),
818-
RdxPart, ReducedPartRdx, "bin.rdx");
815+
else {
816+
Instruction::BinaryOps Opcode;
817+
// For sub-recurrences, each UF's reduction variable is already negated,
818+
// so the parts are combined with an add: reduce.add(-acc_uf0 + -acc_uf1)
819+
if (RK == RecurKind::Sub)
820+
Opcode = Instruction::Add;
821+
else
822+
Opcode =
823+
(Instruction::BinaryOps)RecurrenceDescriptor::getOpcode(RK);
824+
ReducedPartRdx =
825+
Builder.CreateBinOp(Opcode, RdxPart, ReducedPartRdx, "bin.rdx");
826+
}
819827
}
820828
}
821829

llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1747,7 +1747,8 @@ void VPlanTransforms::clearReductionWrapFlags(VPlan &Plan) {
17471747
if (!PhiR)
17481748
continue;
17491749
RecurKind RK = PhiR->getRecurrenceKind();
1750-
if (RK != RecurKind::Add && RK != RecurKind::Mul)
1750+
if (RK != RecurKind::Add && RK != RecurKind::Mul && RK != RecurKind::Sub &&
1751+
RK != RecurKind::AddChainWithSubs)
17511752
continue;
17521753

17531754
for (VPUser *U : collectUsersRecursively(PhiR))

0 commit comments

Comments
 (0)