Skip to content
72 changes: 72 additions & 0 deletions llvm/lib/Transforms/Vectorize/VectorCombine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ class VectorCombine {
bool foldConcatOfBoolMasks(Instruction &I);
bool foldPermuteOfBinops(Instruction &I);
bool foldShuffleOfBinops(Instruction &I);
bool foldShuffleOfSelects(Instruction &I);
bool foldShuffleOfCastops(Instruction &I);
bool foldShuffleOfShuffles(Instruction &I);
bool foldShuffleOfIntrinsics(Instruction &I);
Expand Down Expand Up @@ -1899,6 +1900,76 @@ bool VectorCombine::foldShuffleOfBinops(Instruction &I) {
return true;
}

/// Try to convert,
/// (shuffle(select(c1,t1,f1)), (select(c2,t2,f2)), m) into
/// (select (shuffle c1,c2,m), (shuffle t1,t2,m), (shuffle f1,f2,m))
bool VectorCombine::foldShuffleOfSelects(Instruction &I) {
ArrayRef<int> Mask;
Value *C1, *T1, *F1, *C2, *T2, *F2;
if (!match(&I, m_Shuffle(
m_OneUse(m_Select(m_Value(C1), m_Value(T1), m_Value(F1))),
m_OneUse(m_Select(m_Value(C2), m_Value(T2), m_Value(F2))),
m_Mask(Mask))))
return false;

auto *DstVecTy = dyn_cast<FixedVectorType>(I.getType());
auto *C1VecTy = dyn_cast<FixedVectorType>(C1->getType());
auto *C2VecTy = dyn_cast<FixedVectorType>(C2->getType());
if (!C1VecTy || !C2VecTy || C1VecTy != C2VecTy)
return false;

auto *SI0FOp = dyn_cast<FPMathOperator>(I.getOperand(0));
auto *SI1FOp = dyn_cast<FPMathOperator>(I.getOperand(1));
// SelectInsts must have the same FMF.
if (((SI0FOp == nullptr) != (SI1FOp == nullptr)) ||
((SI0FOp != nullptr) &&
(SI0FOp->getFastMathFlags() != SI1FOp->getFastMathFlags())))
return false;

auto SK = TargetTransformInfo::SK_PermuteTwoSrc;
auto SelOp = Instruction::Select;
InstructionCost OldCost = TTI.getCmpSelInstrCost(
SelOp, T1->getType(), C1VecTy, CmpInst::BAD_ICMP_PREDICATE, CostKind);
OldCost += TTI.getCmpSelInstrCost(SelOp, T2->getType(), C2VecTy,
CmpInst::BAD_ICMP_PREDICATE, CostKind);
OldCost += TTI.getShuffleCost(SK, DstVecTy, Mask, CostKind, 0, nullptr,
{I.getOperand(0), I.getOperand(1)}, &I);

auto *C1C2VecTy = cast<FixedVectorType>(
toVectorTy(Type::getInt1Ty(I.getContext()), DstVecTy->getNumElements()));
InstructionCost NewCost =
TTI.getShuffleCost(SK, C1C2VecTy, Mask, CostKind, 0, nullptr, {C1, C2});
NewCost +=
TTI.getShuffleCost(SK, DstVecTy, Mask, CostKind, 0, nullptr, {T1, T2});
NewCost +=
TTI.getShuffleCost(SK, DstVecTy, Mask, CostKind, 0, nullptr, {F1, F2});
NewCost += TTI.getCmpSelInstrCost(SelOp, DstVecTy, DstVecTy,
CmpInst::BAD_ICMP_PREDICATE, CostKind);

LLVM_DEBUG(dbgs() << "Found a shuffle feeding two selects: " << I
<< "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
<< "\n");
if (NewCost > OldCost)
return false;

Value *ShuffleCmp = Builder.CreateShuffleVector(C1, C2, Mask);
Value *ShuffleTrue = Builder.CreateShuffleVector(T1, T2, Mask);
Value *ShuffleFalse = Builder.CreateShuffleVector(F1, F2, Mask);
Value *NewSel;
// We presuppose that the SelectInsts have the same FMF.
if (SI0FOp)
NewSel = Builder.CreateSelectFMF(ShuffleCmp, ShuffleTrue, ShuffleFalse,
SI0FOp->getFastMathFlags());
else
NewSel = Builder.CreateSelect(ShuffleCmp, ShuffleTrue, ShuffleFalse);

Worklist.pushValue(ShuffleCmp);
Worklist.pushValue(ShuffleTrue);
Worklist.pushValue(ShuffleFalse);
replaceValue(I, *NewSel);
return true;
}

/// Try to convert "shuffle (castop), (castop)" with a shared castop operand
/// into "castop (shuffle)".
bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
Expand Down Expand Up @@ -3352,6 +3423,7 @@ bool VectorCombine::run() {
case Instruction::ShuffleVector:
MadeChange |= foldPermuteOfBinops(I);
MadeChange |= foldShuffleOfBinops(I);
MadeChange |= foldShuffleOfSelects(I);
MadeChange |= foldShuffleOfCastops(I);
MadeChange |= foldShuffleOfShuffles(I);
MadeChange |= foldShuffleOfIntrinsics(I);
Expand Down
Loading