Skip to content
69 changes: 69 additions & 0 deletions llvm/lib/Transforms/Vectorize/VectorCombine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ class VectorCombine {
bool foldConcatOfBoolMasks(Instruction &I);
bool foldPermuteOfBinops(Instruction &I);
bool foldShuffleOfBinops(Instruction &I);
bool foldShuffleOfSelects(Instruction &I);
bool foldShuffleOfCastops(Instruction &I);
bool foldShuffleOfShuffles(Instruction &I);
bool foldShuffleOfIntrinsics(Instruction &I);
Expand Down Expand Up @@ -1899,6 +1900,73 @@ bool VectorCombine::foldShuffleOfBinops(Instruction &I) {
return true;
}

/// Try to convert,
/// (shuffle(select(c1,t1,f1)), (select(c2,t2,f2)), m) into
/// (select (shuffle c1,c2,m), (shuffle t1,t2,m), (shuffle f1,f2,m))
bool VectorCombine::foldShuffleOfSelects(Instruction &I) {
ArrayRef<int> Mask;
Value *C1, *T1, *F1, *C2, *T2, *F2;
if (!match(&I, m_Shuffle(
m_OneUse(m_Select(m_Value(C1), m_Value(T1), m_Value(F1))),
m_OneUse(m_Select(m_Value(C2), m_Value(T2), m_Value(F2))),
m_Mask(Mask))))
return false;

auto *DstVecTy = dyn_cast<FixedVectorType>(I.getType());
auto *C1VecTy = dyn_cast<FixedVectorType>(C1->getType());
auto *C2VecTy = dyn_cast<FixedVectorType>(C2->getType());
if (!C1VecTy || !C2VecTy || C1VecTy != C2VecTy)
return false;

// SelectInsts must have the same FMF.
auto *Select0 = cast<Instruction>(I.getOperand(0));
if (auto *SI0FOp = dyn_cast<FPMathOperator>(Select0))
if (auto *SI1FOp = dyn_cast<FPMathOperator>((I.getOperand(1))))
if (SI0FOp->getFastMathFlags() != SI1FOp->getFastMathFlags())
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we assert that both/neither of the selects are FPMathOperator ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure, since we're trying to combine two Select statements, I thought they should have the same FMF. do I need to think more about FMFs?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd probably just replace it with:

if (auto *SI0FOp = dyn_cast<FPMathOperator>(Select0))
  if (SI0FOp->getFastMathFlags() != cast<FPMathOperator>((I.getOperand(1)))->getFastMathFlags())
    return false;

return false;

auto SK = TargetTransformInfo::SK_PermuteTwoSrc;
auto SelOp = Instruction::Select;
InstructionCost OldCost = TTI.getCmpSelInstrCost(
SelOp, T1->getType(), C1VecTy, CmpInst::BAD_ICMP_PREDICATE, CostKind);
OldCost += TTI.getCmpSelInstrCost(SelOp, T2->getType(), C2VecTy,
CmpInst::BAD_ICMP_PREDICATE, CostKind);
OldCost += TTI.getShuffleCost(SK, DstVecTy, Mask, CostKind, 0, nullptr,
{I.getOperand(0), I.getOperand(1)}, &I);

auto *C1C2VecTy = cast<FixedVectorType>(
toVectorTy(Type::getInt1Ty(I.getContext()), DstVecTy->getNumElements()));
InstructionCost NewCost =
TTI.getShuffleCost(SK, C1C2VecTy, Mask, CostKind, 0, nullptr, {C1, C2});
NewCost +=
TTI.getShuffleCost(SK, DstVecTy, Mask, CostKind, 0, nullptr, {T1, T2});
NewCost +=
TTI.getShuffleCost(SK, DstVecTy, Mask, CostKind, 0, nullptr, {F1, F2});
NewCost += TTI.getCmpSelInstrCost(SelOp, DstVecTy, DstVecTy,
CmpInst::BAD_ICMP_PREDICATE, CostKind);

LLVM_DEBUG(dbgs() << "Found a shuffle feeding two selects: " << I
<< "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
<< "\n");
if (NewCost > OldCost)
return false;

Value *ShuffleCmp = Builder.CreateShuffleVector(C1, C2, Mask);
Value *ShuffleTrue = Builder.CreateShuffleVector(T1, T2, Mask);
Value *ShuffleFalse = Builder.CreateShuffleVector(F1, F2, Mask);
Value *NewSel = Builder.CreateSelect(ShuffleCmp, ShuffleTrue, ShuffleFalse);

// We presuppose that the SelectInsts have the same FMF.
if (isa<FPMathOperator>(NewSel))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Avoid similar casts:

if (auto *SIFOp = dyn_cast<FPMathOperator>(NewSel))
  SIFOp->setFastMathFlags(Select0->getFastMathFlags());

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the logic was defective, so I corrected it... sorry abou it.

cast<Instruction>(NewSel)->setFastMathFlags(Select0->getFastMathFlags());

Worklist.pushValue(ShuffleCmp);
Worklist.pushValue(ShuffleTrue);
Worklist.pushValue(ShuffleFalse);
replaceValue(I, *NewSel);
return true;
}

/// Try to convert "shuffle (castop), (castop)" with a shared castop operand
/// into "castop (shuffle)".
bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
Expand Down Expand Up @@ -3352,6 +3420,7 @@ bool VectorCombine::run() {
case Instruction::ShuffleVector:
MadeChange |= foldPermuteOfBinops(I);
MadeChange |= foldShuffleOfBinops(I);
MadeChange |= foldShuffleOfSelects(I);
MadeChange |= foldShuffleOfCastops(I);
MadeChange |= foldShuffleOfShuffles(I);
MadeChange |= foldShuffleOfIntrinsics(I);
Expand Down
Loading