
Commit 0fffb9f

[SLP]Initial FMAD support (#149102)
Added an initial check for potential FMAD conversion in reductions and operand vectorization. Also added a check that the value is an instruction, to fix #152683.
1 parent ef30dd3 commit 0fffb9f

13 files changed (+841, −292 lines)
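To make the change concrete, here is a hedged sketch of the pattern this commit teaches SLP to recognize (hypothetical IR and value names, not taken from the patch): a contractable fmul feeding a contractable fadd is costed as a single llvm.fmuladd call rather than as two separate instructions, and vectorization is skipped when keeping the pair scalar is the cheaper option.

  ; Both instructions must allow contraction ("contract" or "fast").
  %mul = fmul contract float %a, %b
  %add = fadd contract float %mul, %c
  ; The pair is costed as if it were the single intrinsic:
  %fma = call float @llvm.fmuladd.f32(float %a, float %b, float %c)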

llvm/docs/ReleaseNotes.md

Lines changed: 1 addition & 0 deletions
@@ -77,6 +77,7 @@ Changes to Vectorizers
 
 * Added initial support for copyable elements in SLP, which models copyable
   elements as add <element>, 0, i.e. uses identity constants for missing lanes.
+* SLP vectorizer supports initial recognition of FMA/FMAD pattern
 
 Changes to the AArch64 Backend
 ------------------------------

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 158 additions & 6 deletions
@@ -3883,6 +3883,7 @@ class BoUpSLP {
   enum CombinedOpcode {
     NotCombinedOp = -1,
     MinMax = Instruction::OtherOpsEnd + 1,
+    FMulAdd,
  };
  CombinedOpcode CombinedOp = NotCombinedOp;
 
@@ -4033,6 +4034,9 @@ class BoUpSLP {
   /// Returns true if any scalar in the list is a copyable element.
   bool hasCopyableElements() const { return !CopyableElements.empty(); }
 
+  /// Returns the state of the operations.
+  const InstructionsState &getOperations() const { return S; }
+
   /// When ReuseReorderShuffleIndices is empty it just returns position of \p
   /// V within vector of Scalars. Otherwise, try to remap on its reuse index.
   unsigned findLaneForValue(Value *V) const {
@@ -11987,6 +11991,82 @@ void BoUpSLP::reorderGatherNode(TreeEntry &TE) {
   }
 }
 
+static InstructionCost canConvertToFMA(ArrayRef<Value *> VL,
+                                       const InstructionsState &S,
+                                       DominatorTree &DT, const DataLayout &DL,
+                                       TargetTransformInfo &TTI,
+                                       const TargetLibraryInfo &TLI) {
+  assert(all_of(VL,
+                [](Value *V) {
+                  return V->getType()->getScalarType()->isFloatingPointTy();
+                }) &&
+         "Can only convert to FMA for floating point types");
+  assert(S.isAddSubLikeOp() && "Can only convert to FMA for add/sub");
+
+  auto CheckForContractable = [&](ArrayRef<Value *> VL) {
+    FastMathFlags FMF;
+    FMF.set();
+    for (Value *V : VL) {
+      auto *I = dyn_cast<Instruction>(V);
+      if (!I)
+        continue;
+      // TODO: support for copyable elements.
+      Instruction *MatchingI = S.getMatchingMainOpOrAltOp(I);
+      if (S.getMainOp() != MatchingI && S.getAltOp() != MatchingI)
+        continue;
+      if (auto *FPCI = dyn_cast<FPMathOperator>(I))
+        FMF &= FPCI->getFastMathFlags();
+    }
+    return FMF.allowContract();
+  };
+  if (!CheckForContractable(VL))
+    return InstructionCost::getInvalid();
+  // fmul also should be contractable.
+  InstructionsCompatibilityAnalysis Analysis(DT, DL, TTI, TLI);
+  SmallVector<BoUpSLP::ValueList> Operands = Analysis.buildOperands(S, VL);
+
+  InstructionsState OpS = getSameOpcode(Operands.front(), TLI);
+  if (!OpS.valid())
+    return InstructionCost::getInvalid();
+  if (OpS.isAltShuffle() || OpS.getOpcode() != Instruction::FMul)
+    return InstructionCost::getInvalid();
+  if (!CheckForContractable(Operands.front()))
+    return InstructionCost::getInvalid();
+  // Compare the costs.
+  InstructionCost FMulPlusFAddCost = 0;
+  InstructionCost FMACost = 0;
+  constexpr TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+  FastMathFlags FMF;
+  FMF.set();
+  for (Value *V : VL) {
+    auto *I = dyn_cast<Instruction>(V);
+    if (!I)
+      continue;
+    if (auto *FPCI = dyn_cast<FPMathOperator>(I))
+      FMF &= FPCI->getFastMathFlags();
+    FMulPlusFAddCost += TTI.getInstructionCost(I, CostKind);
+  }
+  unsigned NumOps = 0;
+  for (auto [V, Op] : zip(VL, Operands.front())) {
+    auto *I = dyn_cast<Instruction>(Op);
+    if (!I || !I->hasOneUse()) {
+      if (auto *OpI = dyn_cast<Instruction>(V))
+        FMACost += TTI.getInstructionCost(OpI, CostKind);
+      if (I)
+        FMACost += TTI.getInstructionCost(I, CostKind);
+      continue;
+    }
+    ++NumOps;
+    if (auto *FPCI = dyn_cast<FPMathOperator>(I))
+      FMF &= FPCI->getFastMathFlags();
+    FMulPlusFAddCost += TTI.getInstructionCost(I, CostKind);
+  }
+  Type *Ty = VL.front()->getType();
+  IntrinsicCostAttributes ICA(Intrinsic::fmuladd, Ty, {Ty, Ty, Ty}, FMF);
+  FMACost += NumOps * TTI.getIntrinsicInstrCost(ICA, CostKind);
+  return FMACost < FMulPlusFAddCost ? FMACost : InstructionCost::getInvalid();
+}
+
 void BoUpSLP::transformNodes() {
   constexpr TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
   BaseGraphSize = VectorizableTree.size();
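A note on the helper above: canConvertToFMA intersects the fast-math flags of every fadd/fsub in the bundle and of the fmul operands feeding them, and bails out unless the intersection still allows contraction; only then does it compare the summed scalar fmul+fadd cost against NumOps copies of the fmuladd intrinsic cost, charging multi-use fmuls on both sides rather than counting them as fused. A hedged illustration of the gating, with hypothetical IR and value names:

  ; Eligible: contraction allowed on both instructions.
  %m0 = fmul contract float %x, %y
  %s0 = fadd contract float %m0, %z
  ; Not eligible: the fadd lacks the contract flag, so the intersected
  ; FastMathFlags fail allowContract() and getInvalid() is returned.
  %m1 = fmul contract float %x, %y
  %s1 = fadd float %m1, %z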
@@ -12355,6 +12435,25 @@ void BoUpSLP::transformNodes() {
       }
       break;
     }
+    case Instruction::FSub:
+    case Instruction::FAdd: {
+      // Check if possible to convert (a*b)+c to fma.
+      if (E.State != TreeEntry::Vectorize ||
+          !E.getOperations().isAddSubLikeOp())
+        break;
+      if (!canConvertToFMA(E.Scalars, E.getOperations(), *DT, *DL, *TTI, *TLI)
+               .isValid())
+        break;
+      // This node is a fmuladd node.
+      E.CombinedOp = TreeEntry::FMulAdd;
+      TreeEntry *FMulEntry = getOperandEntry(&E, 0);
+      if (FMulEntry->UserTreeIndex &&
+          FMulEntry->State == TreeEntry::Vectorize) {
+        // The FMul node is part of the combined fmuladd node.
+        FMulEntry->State = TreeEntry::CombinedVectorize;
+      }
+      break;
+    }
    default:
      break;
    }
@@ -13587,6 +13686,11 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
     }
     return IntrinsicCost;
   };
+  auto GetFMulAddCost = [&, &TTI = *TTI](const InstructionsState &S,
+                                         Instruction *VI) {
+    InstructionCost Cost = canConvertToFMA(VI, S, *DT, *DL, TTI, *TLI);
+    return Cost;
+  };
   switch (ShuffleOrOp) {
   case Instruction::PHI: {
     // Count reused scalars.
@@ -13927,6 +14031,30 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
     };
     return GetCostDiff(GetScalarCost, GetVectorCost);
   }
+  case TreeEntry::FMulAdd: {
+    auto GetScalarCost = [&](unsigned Idx) {
+      if (isa<PoisonValue>(UniqueValues[Idx]))
+        return InstructionCost(TTI::TCC_Free);
+      return GetFMulAddCost(E->getOperations(),
+                            cast<Instruction>(UniqueValues[Idx]));
+    };
+    auto GetVectorCost = [&, &TTI = *TTI](InstructionCost CommonCost) {
+      FastMathFlags FMF;
+      FMF.set();
+      for (Value *V : E->Scalars) {
+        if (auto *FPCI = dyn_cast<FPMathOperator>(V)) {
+          FMF &= FPCI->getFastMathFlags();
+          if (auto *FPCIOp = dyn_cast<FPMathOperator>(FPCI->getOperand(0)))
+            FMF &= FPCIOp->getFastMathFlags();
+        }
+      }
+      IntrinsicCostAttributes ICA(Intrinsic::fmuladd, VecTy,
+                                  {VecTy, VecTy, VecTy}, FMF);
+      InstructionCost VecCost = TTI.getIntrinsicInstrCost(ICA, CostKind);
+      return VecCost + CommonCost;
+    };
+    return GetCostDiff(GetScalarCost, GetVectorCost);
+  }
   case Instruction::FNeg:
   case Instruction::Add:
   case Instruction::FAdd:
@@ -13964,8 +14092,16 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
       }
       TTI::OperandValueInfo Op1Info = TTI::getOperandInfo(Op1);
       TTI::OperandValueInfo Op2Info = TTI::getOperandInfo(Op2);
-      return TTI->getArithmeticInstrCost(ShuffleOrOp, OrigScalarTy, CostKind,
-                                         Op1Info, Op2Info, Operands);
+      InstructionCost ScalarCost = TTI->getArithmeticInstrCost(
+          ShuffleOrOp, OrigScalarTy, CostKind, Op1Info, Op2Info, Operands);
+      if (auto *I = dyn_cast<Instruction>(UniqueValues[Idx]);
+          I && (ShuffleOrOp == Instruction::FAdd ||
+                ShuffleOrOp == Instruction::FSub)) {
+        InstructionCost IntrinsicCost = GetFMulAddCost(E->getOperations(), I);
+        if (IntrinsicCost.isValid())
+          ScalarCost = IntrinsicCost;
+      }
+      return ScalarCost;
     };
     auto GetVectorCost = [=](InstructionCost CommonCost) {
       if (ShuffleOrOp == Instruction::And && It != MinBWs.end()) {
@@ -22594,11 +22730,21 @@ class HorizontalReduction {
   /// Try to find a reduction tree.
   bool matchAssociativeReduction(BoUpSLP &R, Instruction *Root,
                                  ScalarEvolution &SE, const DataLayout &DL,
-                                 const TargetLibraryInfo &TLI) {
+                                 const TargetLibraryInfo &TLI,
+                                 DominatorTree &DT, TargetTransformInfo &TTI) {
     RdxKind = HorizontalReduction::getRdxKind(Root);
     if (!isVectorizable(RdxKind, Root))
       return false;
 
+    // FMA reduction root - skip.
+    auto CheckForFMA = [&](Instruction *I) {
+      return RdxKind == RecurKind::FAdd &&
+             canConvertToFMA(I, getSameOpcode(I, TLI), DT, DL, TTI, TLI)
+                 .isValid();
+    };
+    if (CheckForFMA(Root))
+      return false;
+
     // Analyze "regular" integer/FP types for reductions - no target-specific
     // types or pointers.
     Type *Ty = Root->getType();
@@ -22636,7 +22782,7 @@ class HorizontalReduction {
       // Also, do not try to reduce const values, if the operation is not
      // foldable.
      if (!EdgeInst || Level > RecursionMaxDepth ||
-          getRdxKind(EdgeInst) != RdxKind ||
+          getRdxKind(EdgeInst) != RdxKind || CheckForFMA(EdgeInst) ||
          IsCmpSelMinMax != isCmpSelMinMax(EdgeInst) ||
          !hasRequiredNumberOfUses(IsCmpSelMinMax, EdgeInst) ||
          !isVectorizable(RdxKind, EdgeInst) ||
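The two reduction-matching changes above mean a horizontal fadd reduction is rejected when its root or one of its edge instructions is itself an FMA candidate, leaving the scalar code for the backend to fuse. A hedged sketch of the kind of chain that is now skipped (hypothetical IR, though the commute.ll test below shows the same shape):

  ; Previously packed and reduced via llvm.vector.reduce.fadd; now left
  ; scalar so each fmul/fadd pair can become a single fused multiply-add.
  %m0 = fmul fast float %a0, %b0
  %m1 = fmul fast float %a1, %b1
  %r  = fadd fast float %m0, %m1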
@@ -24205,13 +24351,13 @@ bool SLPVectorizerPass::vectorizeHorReduction(
   Stack.emplace(SelectRoot(), 0);
   SmallPtrSet<Value *, 8> VisitedInstrs;
   bool Res = false;
-  auto &&TryToReduce = [this, &R](Instruction *Inst) -> Value * {
+  auto TryToReduce = [this, &R, TTI = TTI](Instruction *Inst) -> Value * {
     if (R.isAnalyzedReductionRoot(Inst))
       return nullptr;
     if (!isReductionCandidate(Inst))
       return nullptr;
     HorizontalReduction HorRdx;
-    if (!HorRdx.matchAssociativeReduction(R, Inst, *SE, *DL, *TLI))
+    if (!HorRdx.matchAssociativeReduction(R, Inst, *SE, *DL, *TLI, *DT, *TTI))
       return nullptr;
     return HorRdx.tryToReduce(R, *DL, TTI, *TLI, AC);
   };
@@ -24277,6 +24423,12 @@ bool SLPVectorizerPass::tryToVectorize(Instruction *I, BoUpSLP &R) {
 
   if (!isa<BinaryOperator, CmpInst>(I) || isa<VectorType>(I->getType()))
     return false;
+  // Skip potential FMA candidates.
+  if ((I->getOpcode() == Instruction::FAdd ||
+       I->getOpcode() == Instruction::FSub) &&
+      canConvertToFMA(I, getSameOpcode(I, *TLI), *DT, *DL, *TTI, *TLI)
+          .isValid())
+    return false;
 
   Value *P = I->getParent();
 
llvm/test/Transforms/SLPVectorizer/AArch64/commute.ll

Lines changed: 20 additions & 14 deletions
@@ -8,15 +8,18 @@ target triple = "aarch64--linux-gnu"
 define void @test1(ptr nocapture readonly %J, i32 %xmin, i32 %ymin) {
 ; CHECK-LABEL: @test1(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[TMP0:%.*]] = insertelement <2 x i32> poison, i32 [[XMIN:%.*]], i32 0
-; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <2 x i32> [[TMP0]], i32 [[YMIN:%.*]], i32 1
 ; CHECK-NEXT:    br label [[FOR_BODY3_LR_PH:%.*]]
 ; CHECK:       for.body3.lr.ph:
-; CHECK-NEXT:    [[TMP2:%.*]] = sitofp <2 x i32> [[TMP1]] to <2 x float>
-; CHECK-NEXT:    [[TMP4:%.*]] = load <2 x float>, ptr [[J:%.*]], align 4
-; CHECK-NEXT:    [[TMP5:%.*]] = fsub fast <2 x float> [[TMP2]], [[TMP4]]
-; CHECK-NEXT:    [[TMP6:%.*]] = fmul fast <2 x float> [[TMP5]], [[TMP5]]
-; CHECK-NEXT:    [[ADD:%.*]] = call fast float @llvm.vector.reduce.fadd.v2f32(float 0.000000e+00, <2 x float> [[TMP6]])
+; CHECK-NEXT:    [[CONV5:%.*]] = sitofp i32 [[YMIN:%.*]] to float
+; CHECK-NEXT:    [[CONV:%.*]] = sitofp i32 [[XMIN:%.*]] to float
+; CHECK-NEXT:    [[TMP0:%.*]] = load float, ptr [[J:%.*]], align 4
+; CHECK-NEXT:    [[SUB:%.*]] = fsub fast float [[CONV]], [[TMP0]]
+; CHECK-NEXT:    [[ARRAYIDX9:%.*]] = getelementptr inbounds [[STRUCTA:%.*]], ptr [[J]], i64 0, i32 0, i64 1
+; CHECK-NEXT:    [[TMP1:%.*]] = load float, ptr [[ARRAYIDX9]], align 4
+; CHECK-NEXT:    [[SUB10:%.*]] = fsub fast float [[CONV5]], [[TMP1]]
+; CHECK-NEXT:    [[MUL11:%.*]] = fmul fast float [[SUB]], [[SUB]]
+; CHECK-NEXT:    [[MUL12:%.*]] = fmul fast float [[SUB10]], [[SUB10]]
+; CHECK-NEXT:    [[ADD:%.*]] = fadd fast float [[MUL11]], [[MUL12]]
 ; CHECK-NEXT:    [[CMP:%.*]] = fcmp oeq float [[ADD]], 0.000000e+00
 ; CHECK-NEXT:    br i1 [[CMP]], label [[FOR_BODY3_LR_PH]], label [[FOR_END27:%.*]]
 ; CHECK:       for.end27:
@@ -47,15 +50,18 @@ for.end27:
 define void @test2(ptr nocapture readonly %J, i32 %xmin, i32 %ymin) {
 ; CHECK-LABEL: @test2(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[TMP0:%.*]] = insertelement <2 x i32> poison, i32 [[XMIN:%.*]], i32 0
-; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <2 x i32> [[TMP0]], i32 [[YMIN:%.*]], i32 1
 ; CHECK-NEXT:    br label [[FOR_BODY3_LR_PH:%.*]]
 ; CHECK:       for.body3.lr.ph:
-; CHECK-NEXT:    [[TMP2:%.*]] = sitofp <2 x i32> [[TMP1]] to <2 x float>
-; CHECK-NEXT:    [[TMP4:%.*]] = load <2 x float>, ptr [[J:%.*]], align 4
-; CHECK-NEXT:    [[TMP5:%.*]] = fsub fast <2 x float> [[TMP2]], [[TMP4]]
-; CHECK-NEXT:    [[TMP6:%.*]] = fmul fast <2 x float> [[TMP5]], [[TMP5]]
-; CHECK-NEXT:    [[ADD:%.*]] = call fast float @llvm.vector.reduce.fadd.v2f32(float 0.000000e+00, <2 x float> [[TMP6]])
+; CHECK-NEXT:    [[CONV5:%.*]] = sitofp i32 [[YMIN:%.*]] to float
+; CHECK-NEXT:    [[CONV:%.*]] = sitofp i32 [[XMIN:%.*]] to float
+; CHECK-NEXT:    [[TMP0:%.*]] = load float, ptr [[J:%.*]], align 4
+; CHECK-NEXT:    [[SUB:%.*]] = fsub fast float [[CONV]], [[TMP0]]
+; CHECK-NEXT:    [[ARRAYIDX9:%.*]] = getelementptr inbounds [[STRUCTA:%.*]], ptr [[J]], i64 0, i32 0, i64 1
+; CHECK-NEXT:    [[TMP1:%.*]] = load float, ptr [[ARRAYIDX9]], align 4
+; CHECK-NEXT:    [[SUB10:%.*]] = fsub fast float [[CONV5]], [[TMP1]]
+; CHECK-NEXT:    [[MUL11:%.*]] = fmul fast float [[SUB]], [[SUB]]
+; CHECK-NEXT:    [[MUL12:%.*]] = fmul fast float [[SUB10]], [[SUB10]]
+; CHECK-NEXT:    [[ADD:%.*]] = fadd fast float [[MUL12]], [[MUL11]]
 ; CHECK-NEXT:    [[CMP:%.*]] = fcmp oeq float [[ADD]], 0.000000e+00
 ; CHECK-NEXT:    br i1 [[CMP]], label [[FOR_BODY3_LR_PH]], label [[FOR_END27:%.*]]
 ; CHECK:       for.end27:
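Both tests illustrate the reduction-root check in action: the fsub/fmul pairs are no longer packed into <2 x float> operations and reduced with llvm.vector.reduce.fadd; the scalars are kept so each fmul/fadd pair stays fusible for the AArch64 backend, with @test1 and @test2 differing only in the operand order of the final fadd.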
