167 changes: 161 additions & 6 deletions llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -3759,6 +3759,7 @@ class BoUpSLP {
enum CombinedOpcode {
NotCombinedOp = -1,
MinMax = Instruction::OtherOpsEnd + 1,
FMulAdd,
};
CombinedOpcode CombinedOp = NotCombinedOp;

@@ -3896,6 +3897,9 @@ class BoUpSLP {

bool hasState() const { return S.valid(); }

/// Returns the state of the operations.
const InstructionsState &getOperations() const { return S; }

/// When ReuseReorderShuffleIndices is empty it just returns position of \p
/// V within vector of Scalars. Otherwise, try to remap on its reuse index.
unsigned findLaneForValue(Value *V) const {
@@ -11553,6 +11557,84 @@ void BoUpSLP::reorderGatherNode(TreeEntry &TE) {
}
}

static InstructionCost canConvertToFMA(ArrayRef<Value *> VL,
const InstructionsState &S,
DominatorTree &DT, const DataLayout &DL,
TargetTransformInfo &TTI,
const TargetLibraryInfo &TLI) {
assert(all_of(VL,
[](Value *V) {
return V->getType()->getScalarType()->isFloatingPointTy();
}) &&
"Can only convert to FMA for floating point types");
assert(S.isAddSubLikeOp() && "Can only convert to FMA for add/sub");

auto CheckForContractable = [&](ArrayRef<Value *> VL) {
FastMathFlags FMF;
FMF.set();
for (Value *V : VL) {
auto *I = dyn_cast<Instruction>(V);
if (!I)
continue;
// TODO: support for copyable elements.
Instruction *MatchingI = S.getMatchingMainOpOrAltOp(I);
if (S.getMainOp() != MatchingI && S.getAltOp() != MatchingI)
continue;
if (auto *FPCI = dyn_cast<FPMathOperator>(I))
FMF &= FPCI->getFastMathFlags();
}
return FMF.allowContract();
};
if (!CheckForContractable(VL))
return InstructionCost::getInvalid();
// fmul also should be contractable
InstructionsCompatibilityAnalysis Analysis(DT, DL, TTI, TLI);
SmallVector<BoUpSLP::ValueList> Operands = Analysis.buildOperands(S, VL);

InstructionsState OpS = getSameOpcode(Operands.front(), TLI);
if (!OpS.valid())
return InstructionCost::getInvalid();
if (OpS.isAltShuffle() || OpS.getOpcode() != Instruction::FMul)
return InstructionCost::getInvalid();
if (!CheckForContractable(Operands.front()))
return InstructionCost::getInvalid();
// Compare the costs.
InstructionCost FMulPlusFaddCost = 0;
InstructionCost FMACost = 0;
constexpr TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
FastMathFlags FMF;
FMF.set();
for (Value *V : VL) {
auto *I = dyn_cast<Instruction>(V);
if (!I)
continue;
if (auto *FPCI = dyn_cast<FPMathOperator>(I))
FMF &= FPCI->getFastMathFlags();
FMulPlusFaddCost += TTI.getInstructionCost(I, CostKind);
}
for (auto [V, Op] : zip(VL, Operands.front())) {
auto *I = dyn_cast<Instruction>(Op);
if (!I || !I->hasOneUse()) {
FMACost += TTI.getInstructionCost(cast<Instruction>(V), CostKind);
if (I)
FMACost += TTI.getInstructionCost(I, CostKind);
continue;
}
if (auto *FPCI = dyn_cast<FPMathOperator>(I))
FMF &= FPCI->getFastMathFlags();
FMulPlusFaddCost += TTI.getInstructionCost(I, CostKind);
}
const unsigned NumOps =
count_if(zip(VL, Operands.front()), [](const auto &P) {
return isa<Instruction>(std::get<0>(P)) &&
isa<Instruction>(std::get<1>(P));
});
Type *Ty = VL.front()->getType();
IntrinsicCostAttributes ICA(Intrinsic::fmuladd, Ty, {Ty, Ty, Ty}, FMF);
Collaborator:
Should you be handling both Intrinsic::fma and Intrinsic::fmuladd?

Member Author:
I thought that fmuladd is correct here. If TTI sees that it can fold the instructions into an fma, it will return the same cost as the fma. Otherwise, it will return the sum of the original fadd+fmul costs, which allows the costs to be compared correctly.
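For readers following this cost comparison, here is a minimal scalar sketch (function and value names are illustrative, not taken from the patch) of the pattern canConvertToFMA looks for: a contractable fmul feeding a contractable fadd, whose combined cost is then queried through the llvm.fmuladd intrinsic.

define float @muladd_example(float %a, float %b, float %c) {
  %mul = fmul contract float %a, %b    ; single-use, contractable multiply
  %add = fadd contract float %mul, %c  ; contractable add consuming the multiply
  ret float %add
}

; The cost query above then corresponds to the intrinsic form:
;   %fma = call contract float @llvm.fmuladd.f32(float %a, float %b, float %c)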

FMACost += NumOps * TTI.getIntrinsicInstrCost(ICA, CostKind);
return FMACost < FMulPlusFaddCost ? FMACost : InstructionCost::getInvalid();
}

void BoUpSLP::transformNodes() {
constexpr TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
BaseGraphSize = VectorizableTree.size();
@@ -11912,6 +11994,25 @@ void BoUpSLP::transformNodes() {
}
break;
}
case Instruction::FSub:
case Instruction::FAdd: {
// Check if possible to convert (a*b)+c to fma.
if (E.State != TreeEntry::Vectorize ||
!E.getOperations().isAddSubLikeOp())
break;
if (!canConvertToFMA(E.Scalars, E.getOperations(), *DT, *DL, *TTI, *TLI)
.isValid())
break;
// This node is a fmuladd node.
E.CombinedOp = TreeEntry::FMulAdd;
TreeEntry *FMulEntry = getOperandEntry(&E, 0);
if (FMulEntry->UserTreeIndex &&
FMulEntry->State == TreeEntry::Vectorize) {
// The FMul node is part of the combined fmuladd node.
FMulEntry->State = TreeEntry::CombinedVectorize;
}
break;
}
default:
break;
}
@@ -13142,6 +13243,12 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
}
return IntrinsicCost;
};
auto GetFMulAddCost = [&, &TTI = *TTI](const InstructionsState &S,
Instruction *VI = nullptr) {
InstructionCost Cost = canConvertToFMA(VI ? ArrayRef<Value *>(VI) : VL, S,
*DT, *DL, TTI, *TLI);
return Cost;
};
switch (ShuffleOrOp) {
case Instruction::PHI: {
// Count reused scalars.
@@ -13482,6 +13589,30 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
};
return GetCostDiff(GetScalarCost, GetVectorCost);
}
case TreeEntry::FMulAdd: {
auto GetScalarCost = [&](unsigned Idx) {
if (isa<PoisonValue>(UniqueValues[Idx]))
return InstructionCost(TTI::TCC_Free);
return GetFMulAddCost(E->getOperations(),
cast<Instruction>(UniqueValues[Idx]));
};
auto GetVectorCost = [&, &TTI = *TTI](InstructionCost CommonCost) {
FastMathFlags FMF;
FMF.set();
for (Value *V : E->Scalars) {
if (auto *FPCI = dyn_cast<FPMathOperator>(V)) {
FMF &= FPCI->getFastMathFlags();
if (auto *FPCIOp = dyn_cast<FPMathOperator>(FPCI->getOperand(0)))
FMF &= FPCIOp->getFastMathFlags();
}
}
IntrinsicCostAttributes ICA(Intrinsic::fmuladd, VecTy,
{VecTy, VecTy, VecTy}, FMF);
InstructionCost VecCost = TTI.getIntrinsicInstrCost(ICA, CostKind);
return VecCost + CommonCost;
};
return GetCostDiff(GetScalarCost, GetVectorCost);
}
case Instruction::FNeg:
case Instruction::Add:
case Instruction::FAdd:
@@ -13519,8 +13650,16 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
}
TTI::OperandValueInfo Op1Info = TTI::getOperandInfo(Op1);
TTI::OperandValueInfo Op2Info = TTI::getOperandInfo(Op2);
return TTI->getArithmeticInstrCost(ShuffleOrOp, OrigScalarTy, CostKind,
Op1Info, Op2Info, Operands);
InstructionCost ScalarCost = TTI->getArithmeticInstrCost(
ShuffleOrOp, OrigScalarTy, CostKind, Op1Info, Op2Info, Operands);
if (auto *I = dyn_cast<Instruction>(UniqueValues[Idx]);
I && (ShuffleOrOp == Instruction::FAdd ||
ShuffleOrOp == Instruction::FSub)) {
InstructionCost IntrinsicCost = GetFMulAddCost(E->getOperations(), I);
if (IntrinsicCost.isValid())
ScalarCost = IntrinsicCost;
}
return ScalarCost;
};
auto GetVectorCost = [=](InstructionCost CommonCost) {
if (ShuffleOrOp == Instruction::And && It != MinBWs.end()) {
@@ -22081,11 +22220,21 @@ class HorizontalReduction {
/// Try to find a reduction tree.
bool matchAssociativeReduction(BoUpSLP &R, Instruction *Root,
ScalarEvolution &SE, const DataLayout &DL,
const TargetLibraryInfo &TLI) {
const TargetLibraryInfo &TLI,
DominatorTree &DT, TargetTransformInfo &TTI) {
RdxKind = HorizontalReduction::getRdxKind(Root);
if (!isVectorizable(RdxKind, Root))
return false;

// FMA reduction root - skip.
auto CheckForFMA = [&](Instruction *I) {
return RdxKind == RecurKind::FAdd &&
canConvertToFMA(I, getSameOpcode(I, TLI), DT, DL, TTI, TLI)
.isValid();
};
if (CheckForFMA(Root))
return false;

// Analyze "regular" integer/FP types for reductions - no target-specific
// types or pointers.
Type *Ty = Root->getType();
@@ -22123,7 +22272,7 @@ class HorizontalReduction {
// Also, do not try to reduce const values, if the operation is not
// foldable.
if (!EdgeInst || Level > RecursionMaxDepth ||
getRdxKind(EdgeInst) != RdxKind ||
getRdxKind(EdgeInst) != RdxKind || CheckForFMA(EdgeInst) ||
IsCmpSelMinMax != isCmpSelMinMax(EdgeInst) ||
!hasRequiredNumberOfUses(IsCmpSelMinMax, EdgeInst) ||
!isVectorizable(RdxKind, EdgeInst) ||
@@ -23686,13 +23835,13 @@ bool SLPVectorizerPass::vectorizeHorReduction(
Stack.emplace(SelectRoot(), 0);
SmallPtrSet<Value *, 8> VisitedInstrs;
bool Res = false;
auto &&TryToReduce = [this, &R](Instruction *Inst) -> Value * {
auto TryToReduce = [this, &R, TTI = TTI](Instruction *Inst) -> Value * {
if (R.isAnalyzedReductionRoot(Inst))
return nullptr;
if (!isReductionCandidate(Inst))
return nullptr;
HorizontalReduction HorRdx;
if (!HorRdx.matchAssociativeReduction(R, Inst, *SE, *DL, *TLI))
if (!HorRdx.matchAssociativeReduction(R, Inst, *SE, *DL, *TLI, *DT, *TTI))
return nullptr;
return HorRdx.tryToReduce(R, *DL, TTI, *TLI, AC);
};
@@ -23758,6 +23907,12 @@ bool SLPVectorizerPass::tryToVectorize(Instruction *I, BoUpSLP &R) {

if (!isa<BinaryOperator, CmpInst>(I) || isa<VectorType>(I->getType()))
return false;
// Skip potential FMA candidates.
if ((I->getOpcode() == Instruction::FAdd ||
I->getOpcode() == Instruction::FSub) &&
canConvertToFMA(I, getSameOpcode(I, *TLI), *DT, *DL, *TTI, *TLI)
.isValid())
return false;

Value *P = I->getParent();

34 changes: 20 additions & 14 deletions llvm/test/Transforms/SLPVectorizer/AArch64/commute.ll
@@ -8,15 +8,18 @@ target triple = "aarch64--linux-gnu"
define void @test1(ptr nocapture readonly %J, i32 %xmin, i32 %ymin) {
; CHECK-LABEL: @test1(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[TMP0:%.*]] = insertelement <2 x i32> poison, i32 [[XMIN:%.*]], i32 0
; CHECK-NEXT: [[TMP1:%.*]] = insertelement <2 x i32> [[TMP0]], i32 [[YMIN:%.*]], i32 1
; CHECK-NEXT: br label [[FOR_BODY3_LR_PH:%.*]]
; CHECK: for.body3.lr.ph:
; CHECK-NEXT: [[TMP2:%.*]] = sitofp <2 x i32> [[TMP1]] to <2 x float>
; CHECK-NEXT: [[TMP4:%.*]] = load <2 x float>, ptr [[J:%.*]], align 4
; CHECK-NEXT: [[TMP5:%.*]] = fsub fast <2 x float> [[TMP2]], [[TMP4]]
; CHECK-NEXT: [[TMP6:%.*]] = fmul fast <2 x float> [[TMP5]], [[TMP5]]
; CHECK-NEXT: [[ADD:%.*]] = call fast float @llvm.vector.reduce.fadd.v2f32(float 0.000000e+00, <2 x float> [[TMP6]])
; CHECK-NEXT: [[CONV5:%.*]] = sitofp i32 [[YMIN:%.*]] to float
; CHECK-NEXT: [[CONV:%.*]] = sitofp i32 [[XMIN:%.*]] to float
; CHECK-NEXT: [[TMP0:%.*]] = load float, ptr [[J:%.*]], align 4
; CHECK-NEXT: [[SUB:%.*]] = fsub fast float [[CONV]], [[TMP0]]
; CHECK-NEXT: [[ARRAYIDX9:%.*]] = getelementptr inbounds [[STRUCTA:%.*]], ptr [[J]], i64 0, i32 0, i64 1
; CHECK-NEXT: [[TMP1:%.*]] = load float, ptr [[ARRAYIDX9]], align 4
; CHECK-NEXT: [[SUB10:%.*]] = fsub fast float [[CONV5]], [[TMP1]]
; CHECK-NEXT: [[MUL11:%.*]] = fmul fast float [[SUB]], [[SUB]]
; CHECK-NEXT: [[MUL12:%.*]] = fmul fast float [[SUB10]], [[SUB10]]
; CHECK-NEXT: [[ADD:%.*]] = fadd fast float [[MUL11]], [[MUL12]]
; CHECK-NEXT: [[CMP:%.*]] = fcmp oeq float [[ADD]], 0.000000e+00
; CHECK-NEXT: br i1 [[CMP]], label [[FOR_BODY3_LR_PH]], label [[FOR_END27:%.*]]
; CHECK: for.end27:
@@ -47,15 +50,18 @@
define void @test2(ptr nocapture readonly %J, i32 %xmin, i32 %ymin) {
; CHECK-LABEL: @test2(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[TMP0:%.*]] = insertelement <2 x i32> poison, i32 [[XMIN:%.*]], i32 0
; CHECK-NEXT: [[TMP1:%.*]] = insertelement <2 x i32> [[TMP0]], i32 [[YMIN:%.*]], i32 1
; CHECK-NEXT: br label [[FOR_BODY3_LR_PH:%.*]]
; CHECK: for.body3.lr.ph:
; CHECK-NEXT: [[TMP2:%.*]] = sitofp <2 x i32> [[TMP1]] to <2 x float>
; CHECK-NEXT: [[TMP4:%.*]] = load <2 x float>, ptr [[J:%.*]], align 4
; CHECK-NEXT: [[TMP5:%.*]] = fsub fast <2 x float> [[TMP2]], [[TMP4]]
; CHECK-NEXT: [[TMP6:%.*]] = fmul fast <2 x float> [[TMP5]], [[TMP5]]
; CHECK-NEXT: [[ADD:%.*]] = call fast float @llvm.vector.reduce.fadd.v2f32(float 0.000000e+00, <2 x float> [[TMP6]])
; CHECK-NEXT: [[CONV5:%.*]] = sitofp i32 [[YMIN:%.*]] to float
; CHECK-NEXT: [[CONV:%.*]] = sitofp i32 [[XMIN:%.*]] to float
; CHECK-NEXT: [[TMP0:%.*]] = load float, ptr [[J:%.*]], align 4
; CHECK-NEXT: [[SUB:%.*]] = fsub fast float [[CONV]], [[TMP0]]
; CHECK-NEXT: [[ARRAYIDX9:%.*]] = getelementptr inbounds [[STRUCTA:%.*]], ptr [[J]], i64 0, i32 0, i64 1
; CHECK-NEXT: [[TMP1:%.*]] = load float, ptr [[ARRAYIDX9]], align 4
; CHECK-NEXT: [[SUB10:%.*]] = fsub fast float [[CONV5]], [[TMP1]]
; CHECK-NEXT: [[MUL11:%.*]] = fmul fast float [[SUB]], [[SUB]]
; CHECK-NEXT: [[MUL12:%.*]] = fmul fast float [[SUB10]], [[SUB10]]
; CHECK-NEXT: [[ADD:%.*]] = fadd fast float [[MUL12]], [[MUL11]]
; CHECK-NEXT: [[CMP:%.*]] = fcmp oeq float [[ADD]], 0.000000e+00
; CHECK-NEXT: br i1 [[CMP]], label [[FOR_BODY3_LR_PH]], label [[FOR_END27:%.*]]
; CHECK: for.end27: