[SLP]Initial FMAD support #149102

Merged
7 commits merged on Aug 7, 2025
Changes from 6 commits
1 change: 1 addition & 0 deletions llvm/docs/ReleaseNotes.md
@@ -73,6 +73,7 @@ Changes to Vectorizers

* Added initial support for copyable elements in SLP, which models copyable
elements as add <element>, 0, i.e. uses identity constants for missing lanes.
* SLP vectorizer supports initial recognition of the FMA pattern.

Changes to the AArch64 Backend
------------------------------
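As a rough illustration of the new release note (a hypothetical IR snippet, not taken from this patch), the recognized pattern is a contractable fmul feeding a contractable fadd or fsub, which the SLP cost model can now account for as a single llvm.fmuladd call:

; Both operations carry the 'contract' fast-math flag, so they may be fused.
define float @mul_then_add(float %a, float %b, float %c) {
  %mul = fmul contract float %a, %b
  %sum = fadd contract float %mul, %c
  ret float %sum
}

; Conceptual fused equivalent used in the cost comparison.
declare float @llvm.fmuladd.f32(float, float, float)
define float @fused(float %a, float %b, float %c) {
  %fma = call contract float @llvm.fmuladd.f32(float %a, float %b, float %c)
  ret float %fma
}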
167 changes: 161 additions & 6 deletions llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -3883,6 +3883,7 @@ class BoUpSLP {
enum CombinedOpcode {
NotCombinedOp = -1,
MinMax = Instruction::OtherOpsEnd + 1,
FMulAdd,
};
CombinedOpcode CombinedOp = NotCombinedOp;

@@ -4033,6 +4034,9 @@ class BoUpSLP {
/// Returns true if any scalar in the list is a copyable element.
bool hasCopyableElements() const { return !CopyableElements.empty(); }

/// Returns the state of the operations.
const InstructionsState &getOperations() const { return S; }

/// When ReuseReorderShuffleIndices is empty it just returns position of \p
/// V within vector of Scalars. Otherwise, try to remap on its reuse index.
unsigned findLaneForValue(Value *V) const {
@@ -11987,6 +11991,84 @@ void BoUpSLP::reorderGatherNode(TreeEntry &TE) {
}
}

static InstructionCost canConvertToFMA(ArrayRef<Value *> VL,
const InstructionsState &S,
DominatorTree &DT, const DataLayout &DL,
TargetTransformInfo &TTI,
const TargetLibraryInfo &TLI) {
assert(all_of(VL,
[](Value *V) {
return V->getType()->getScalarType()->isFloatingPointTy();
}) &&
"Can only convert to FMA for floating point types");
assert(S.isAddSubLikeOp() && "Can only convert to FMA for add/sub");

auto CheckForContractable = [&](ArrayRef<Value *> VL) {
FastMathFlags FMF;
FMF.set();
for (Value *V : VL) {
auto *I = dyn_cast<Instruction>(V);
if (!I)
continue;
// TODO: support for copyable elements.
Instruction *MatchingI = S.getMatchingMainOpOrAltOp(I);
if (S.getMainOp() != MatchingI && S.getAltOp() != MatchingI)
continue;
if (auto *FPCI = dyn_cast<FPMathOperator>(I))
FMF &= FPCI->getFastMathFlags();
}
return FMF.allowContract();
};
if (!CheckForContractable(VL))
return InstructionCost::getInvalid();
// The fmul should also be contractable.
InstructionsCompatibilityAnalysis Analysis(DT, DL, TTI, TLI);
SmallVector<BoUpSLP::ValueList> Operands = Analysis.buildOperands(S, VL);

InstructionsState OpS = getSameOpcode(Operands.front(), TLI);
if (!OpS.valid())
return InstructionCost::getInvalid();
if (OpS.isAltShuffle() || OpS.getOpcode() != Instruction::FMul)
return InstructionCost::getInvalid();
if (!CheckForContractable(Operands.front()))
return InstructionCost::getInvalid();
// Compare the costs.
InstructionCost FMulPlusFaddCost = 0;
InstructionCost FMACost = 0;
constexpr TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
FastMathFlags FMF;
FMF.set();
for (Value *V : VL) {
auto *I = dyn_cast<Instruction>(V);
if (!I)
continue;
if (auto *FPCI = dyn_cast<FPMathOperator>(I))
FMF &= FPCI->getFastMathFlags();
FMulPlusFaddCost += TTI.getInstructionCost(I, CostKind);
}
for (auto [V, Op] : zip(VL, Operands.front())) {
auto *I = dyn_cast<Instruction>(Op);
if (!I || !I->hasOneUse()) {
FMACost += TTI.getInstructionCost(cast<Instruction>(V), CostKind);
if (I)
FMACost += TTI.getInstructionCost(I, CostKind);
continue;
}
if (auto *FPCI = dyn_cast<FPMathOperator>(I))
FMF &= FPCI->getFastMathFlags();
FMulPlusFaddCost += TTI.getInstructionCost(I, CostKind);
}
const unsigned NumOps =
count_if(zip(VL, Operands.front()), [](const auto &P) {
return isa<Instruction>(std::get<0>(P)) &&
isa<Instruction>(std::get<1>(P));
});
Type *Ty = VL.front()->getType();
IntrinsicCostAttributes ICA(Intrinsic::fmuladd, Ty, {Ty, Ty, Ty}, FMF);
FMACost += NumOps * TTI.getIntrinsicInstrCost(ICA, CostKind);
return FMACost < FMulPlusFaddCost ? FMACost : InstructionCost::getInvalid();
}

void BoUpSLP::transformNodes() {
constexpr TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
BaseGraphSize = VectorizableTree.size();
@@ -12355,6 +12437,25 @@ void BoUpSLP::transformNodes() {
}
break;
}
case Instruction::FSub:
case Instruction::FAdd: {
// Check if possible to convert (a*b)+c to fma.
if (E.State != TreeEntry::Vectorize ||
!E.getOperations().isAddSubLikeOp())
break;
if (!canConvertToFMA(E.Scalars, E.getOperations(), *DT, *DL, *TTI, *TLI)
.isValid())
break;
// This node is a fmuladd node.
E.CombinedOp = TreeEntry::FMulAdd;
TreeEntry *FMulEntry = getOperandEntry(&E, 0);
if (FMulEntry->UserTreeIndex &&
FMulEntry->State == TreeEntry::Vectorize) {
// The FMul node is part of the combined fmuladd node.
FMulEntry->State = TreeEntry::CombinedVectorize;
}
break;
}
default:
break;
}
@@ -13587,6 +13688,12 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
}
return IntrinsicCost;
};
auto GetFMulAddCost = [&, &TTI = *TTI](const InstructionsState &S,
Instruction *VI = nullptr) {
InstructionCost Cost = canConvertToFMA(VI ? ArrayRef<Value *>(VI) : VL, S,
*DT, *DL, TTI, *TLI);
return Cost;
};
switch (ShuffleOrOp) {
case Instruction::PHI: {
// Count reused scalars.
@@ -13927,6 +14034,30 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
};
return GetCostDiff(GetScalarCost, GetVectorCost);
}
case TreeEntry::FMulAdd: {
auto GetScalarCost = [&](unsigned Idx) {
if (isa<PoisonValue>(UniqueValues[Idx]))
return InstructionCost(TTI::TCC_Free);
return GetFMulAddCost(E->getOperations(),
cast<Instruction>(UniqueValues[Idx]));
};
auto GetVectorCost = [&, &TTI = *TTI](InstructionCost CommonCost) {
FastMathFlags FMF;
FMF.set();
for (Value *V : E->Scalars) {
if (auto *FPCI = dyn_cast<FPMathOperator>(V)) {
FMF &= FPCI->getFastMathFlags();
if (auto *FPCIOp = dyn_cast<FPMathOperator>(FPCI->getOperand(0)))
FMF &= FPCIOp->getFastMathFlags();
}
}
IntrinsicCostAttributes ICA(Intrinsic::fmuladd, VecTy,
{VecTy, VecTy, VecTy}, FMF);
InstructionCost VecCost = TTI.getIntrinsicInstrCost(ICA, CostKind);
return VecCost + CommonCost;
};
return GetCostDiff(GetScalarCost, GetVectorCost);
}
case Instruction::FNeg:
case Instruction::Add:
case Instruction::FAdd:
@@ -13964,8 +14095,16 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
}
TTI::OperandValueInfo Op1Info = TTI::getOperandInfo(Op1);
TTI::OperandValueInfo Op2Info = TTI::getOperandInfo(Op2);
return TTI->getArithmeticInstrCost(ShuffleOrOp, OrigScalarTy, CostKind,
Op1Info, Op2Info, Operands);
InstructionCost ScalarCost = TTI->getArithmeticInstrCost(
ShuffleOrOp, OrigScalarTy, CostKind, Op1Info, Op2Info, Operands);
if (auto *I = dyn_cast<Instruction>(UniqueValues[Idx]);
I && (ShuffleOrOp == Instruction::FAdd ||
ShuffleOrOp == Instruction::FSub)) {
InstructionCost IntrinsicCost = GetFMulAddCost(E->getOperations(), I);
if (IntrinsicCost.isValid())
ScalarCost = IntrinsicCost;
}
return ScalarCost;
};
auto GetVectorCost = [=](InstructionCost CommonCost) {
if (ShuffleOrOp == Instruction::And && It != MinBWs.end()) {
@@ -22593,11 +22732,21 @@ class HorizontalReduction {
/// Try to find a reduction tree.
bool matchAssociativeReduction(BoUpSLP &R, Instruction *Root,
ScalarEvolution &SE, const DataLayout &DL,
const TargetLibraryInfo &TLI) {
const TargetLibraryInfo &TLI,
DominatorTree &DT, TargetTransformInfo &TTI) {
RdxKind = HorizontalReduction::getRdxKind(Root);
if (!isVectorizable(RdxKind, Root))
return false;

// FMA reduction root - skip.
auto CheckForFMA = [&](Instruction *I) {
return RdxKind == RecurKind::FAdd &&
canConvertToFMA(I, getSameOpcode(I, TLI), DT, DL, TTI, TLI)
.isValid();
};
if (CheckForFMA(Root))
return false;

// Analyze "regular" integer/FP types for reductions - no target-specific
// types or pointers.
Type *Ty = Root->getType();
@@ -22635,7 +22784,7 @@ class HorizontalReduction {
// Also, do not try to reduce const values, if the operation is not
// foldable.
if (!EdgeInst || Level > RecursionMaxDepth ||
getRdxKind(EdgeInst) != RdxKind ||
getRdxKind(EdgeInst) != RdxKind || CheckForFMA(EdgeInst) ||
IsCmpSelMinMax != isCmpSelMinMax(EdgeInst) ||
!hasRequiredNumberOfUses(IsCmpSelMinMax, EdgeInst) ||
!isVectorizable(RdxKind, EdgeInst) ||
@@ -24204,13 +24353,13 @@ bool SLPVectorizerPass::vectorizeHorReduction(
Stack.emplace(SelectRoot(), 0);
SmallPtrSet<Value *, 8> VisitedInstrs;
bool Res = false;
auto &&TryToReduce = [this, &R](Instruction *Inst) -> Value * {
auto TryToReduce = [this, &R, TTI = TTI](Instruction *Inst) -> Value * {
if (R.isAnalyzedReductionRoot(Inst))
return nullptr;
if (!isReductionCandidate(Inst))
return nullptr;
HorizontalReduction HorRdx;
if (!HorRdx.matchAssociativeReduction(R, Inst, *SE, *DL, *TLI))
if (!HorRdx.matchAssociativeReduction(R, Inst, *SE, *DL, *TLI, *DT, *TTI))
return nullptr;
return HorRdx.tryToReduce(R, *DL, TTI, *TLI, AC);
};
@@ -24276,6 +24425,12 @@ bool SLPVectorizerPass::tryToVectorize(Instruction *I, BoUpSLP &R) {

if (!isa<BinaryOperator, CmpInst>(I) || isa<VectorType>(I->getType()))
return false;
// Skip potential FMA candidates.
if ((I->getOpcode() == Instruction::FAdd ||
I->getOpcode() == Instruction::FSub) &&
canConvertToFMA(I, getSameOpcode(I, *TLI), *DT, *DL, *TTI, *TLI)
.isValid())
return false;

Value *P = I->getParent();

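The AArch64 test updates below follow from the new skips: an fadd whose operands are one-use contractable fmuls now looks like an FMA candidate, so horizontal-reduction matching and tryToVectorize leave the scalar fsub/fmul/fadd chain for backend FMA formation instead of forming a <2 x float> reduction. A minimal sketch of the shape that is now skipped (hypothetical IR, simplified from the test, not part of the patch):

; The fadd is a potential reduction root, but each operand is a one-use
; contractable fmul, so canConvertToFMA() succeeds and the reduction matcher
; now leaves the chain alone for the backend to fuse.
define float @sum_of_squares(float %x, float %y) {
  %xx = fmul fast float %x, %x
  %yy = fmul fast float %y, %y
  %sum = fadd fast float %xx, %yy
  ret float %sum
}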
34 changes: 20 additions & 14 deletions llvm/test/Transforms/SLPVectorizer/AArch64/commute.ll
@@ -8,15 +8,18 @@ target triple = "aarch64--linux-gnu"
define void @test1(ptr nocapture readonly %J, i32 %xmin, i32 %ymin) {
; CHECK-LABEL: @test1(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[TMP0:%.*]] = insertelement <2 x i32> poison, i32 [[XMIN:%.*]], i32 0
; CHECK-NEXT: [[TMP1:%.*]] = insertelement <2 x i32> [[TMP0]], i32 [[YMIN:%.*]], i32 1
; CHECK-NEXT: br label [[FOR_BODY3_LR_PH:%.*]]
; CHECK: for.body3.lr.ph:
; CHECK-NEXT: [[TMP2:%.*]] = sitofp <2 x i32> [[TMP1]] to <2 x float>
; CHECK-NEXT: [[TMP4:%.*]] = load <2 x float>, ptr [[J:%.*]], align 4
; CHECK-NEXT: [[TMP5:%.*]] = fsub fast <2 x float> [[TMP2]], [[TMP4]]
; CHECK-NEXT: [[TMP6:%.*]] = fmul fast <2 x float> [[TMP5]], [[TMP5]]
; CHECK-NEXT: [[ADD:%.*]] = call fast float @llvm.vector.reduce.fadd.v2f32(float 0.000000e+00, <2 x float> [[TMP6]])
; CHECK-NEXT: [[CONV5:%.*]] = sitofp i32 [[YMIN:%.*]] to float
; CHECK-NEXT: [[CONV:%.*]] = sitofp i32 [[XMIN:%.*]] to float
; CHECK-NEXT: [[TMP0:%.*]] = load float, ptr [[J:%.*]], align 4
; CHECK-NEXT: [[SUB:%.*]] = fsub fast float [[CONV]], [[TMP0]]
; CHECK-NEXT: [[ARRAYIDX9:%.*]] = getelementptr inbounds [[STRUCTA:%.*]], ptr [[J]], i64 0, i32 0, i64 1
; CHECK-NEXT: [[TMP1:%.*]] = load float, ptr [[ARRAYIDX9]], align 4
; CHECK-NEXT: [[SUB10:%.*]] = fsub fast float [[CONV5]], [[TMP1]]
; CHECK-NEXT: [[MUL11:%.*]] = fmul fast float [[SUB]], [[SUB]]
; CHECK-NEXT: [[MUL12:%.*]] = fmul fast float [[SUB10]], [[SUB10]]
; CHECK-NEXT: [[ADD:%.*]] = fadd fast float [[MUL11]], [[MUL12]]
; CHECK-NEXT: [[CMP:%.*]] = fcmp oeq float [[ADD]], 0.000000e+00
; CHECK-NEXT: br i1 [[CMP]], label [[FOR_BODY3_LR_PH]], label [[FOR_END27:%.*]]
; CHECK: for.end27:
@@ -47,15 +50,18 @@ for.end27:
define void @test2(ptr nocapture readonly %J, i32 %xmin, i32 %ymin) {
; CHECK-LABEL: @test2(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[TMP0:%.*]] = insertelement <2 x i32> poison, i32 [[XMIN:%.*]], i32 0
; CHECK-NEXT: [[TMP1:%.*]] = insertelement <2 x i32> [[TMP0]], i32 [[YMIN:%.*]], i32 1
; CHECK-NEXT: br label [[FOR_BODY3_LR_PH:%.*]]
; CHECK: for.body3.lr.ph:
; CHECK-NEXT: [[TMP2:%.*]] = sitofp <2 x i32> [[TMP1]] to <2 x float>
; CHECK-NEXT: [[TMP4:%.*]] = load <2 x float>, ptr [[J:%.*]], align 4
; CHECK-NEXT: [[TMP5:%.*]] = fsub fast <2 x float> [[TMP2]], [[TMP4]]
; CHECK-NEXT: [[TMP6:%.*]] = fmul fast <2 x float> [[TMP5]], [[TMP5]]
; CHECK-NEXT: [[ADD:%.*]] = call fast float @llvm.vector.reduce.fadd.v2f32(float 0.000000e+00, <2 x float> [[TMP6]])
; CHECK-NEXT: [[CONV5:%.*]] = sitofp i32 [[YMIN:%.*]] to float
; CHECK-NEXT: [[CONV:%.*]] = sitofp i32 [[XMIN:%.*]] to float
; CHECK-NEXT: [[TMP0:%.*]] = load float, ptr [[J:%.*]], align 4
; CHECK-NEXT: [[SUB:%.*]] = fsub fast float [[CONV]], [[TMP0]]
; CHECK-NEXT: [[ARRAYIDX9:%.*]] = getelementptr inbounds [[STRUCTA:%.*]], ptr [[J]], i64 0, i32 0, i64 1
; CHECK-NEXT: [[TMP1:%.*]] = load float, ptr [[ARRAYIDX9]], align 4
; CHECK-NEXT: [[SUB10:%.*]] = fsub fast float [[CONV5]], [[TMP1]]
; CHECK-NEXT: [[MUL11:%.*]] = fmul fast float [[SUB]], [[SUB]]
; CHECK-NEXT: [[MUL12:%.*]] = fmul fast float [[SUB10]], [[SUB10]]
; CHECK-NEXT: [[ADD:%.*]] = fadd fast float [[MUL12]], [[MUL11]]
; CHECK-NEXT: [[CMP:%.*]] = fcmp oeq float [[ADD]], 0.000000e+00
; CHECK-NEXT: br i1 [[CMP]], label [[FOR_BODY3_LR_PH]], label [[FOR_END27:%.*]]
; CHECK: for.end27: