@@ -113,6 +113,7 @@ class VectorCombine {
113113 bool scalarizeLoadExtract (Instruction &I);
114114 bool foldShuffleOfBinops (Instruction &I);
115115 bool foldShuffleOfCastops (Instruction &I);
116+ bool foldShuffleOfShuffles (Instruction &I);
116117 bool foldShuffleFromReductions (Instruction &I);
117118 bool foldTruncFromReductions (Instruction &I);
118119 bool foldSelectShuffle (Instruction &I, bool FromReduction = false );
@@ -1552,6 +1553,86 @@ bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
15521553 return true ;
15531554}
15541555
1556+ // / Try to convert "shuffle (shuffle x, undef), (shuffle y, undef)"
1557+ // / into "shuffle x, y".
1558+ bool VectorCombine::foldShuffleOfShuffles (Instruction &I) {
1559+ Value *V0, *V1;
1560+ UndefValue *U0, *U1;
1561+ ArrayRef<int > OuterMask, InnerMask0, InnerMask1;
1562+ if (!match (&I, m_Shuffle (m_OneUse (m_Shuffle (m_Value (V0), m_UndefValue (U0),
1563+ m_Mask (InnerMask0))),
1564+ m_OneUse (m_Shuffle (m_Value (V1), m_UndefValue (U1),
1565+ m_Mask (InnerMask1))),
1566+ m_Mask (OuterMask))))
1567+ return false ;
1568+
1569+ auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType ());
1570+ auto *ShuffleSrcTy = dyn_cast<FixedVectorType>(V0->getType ());
1571+ auto *ShuffleImmTy = dyn_cast<FixedVectorType>(I.getOperand (0 )->getType ());
1572+ if (!ShuffleDstTy || !ShuffleSrcTy || !ShuffleImmTy ||
1573+ V0->getType () != V1->getType ())
1574+ return false ;
1575+
1576+ unsigned NumSrcElts = ShuffleSrcTy->getNumElements ();
1577+ unsigned NumImmElts = ShuffleImmTy->getNumElements ();
1578+
1579+ // Bail if either inner masks reference a RHS undef arg.
1580+ if ((!isa<PoisonValue>(U0) &&
1581+ any_of (InnerMask0, [&](int M) { return M >= (int )NumSrcElts; })) ||
1582+ (!isa<PoisonValue>(U1) &&
1583+ any_of (InnerMask1, [&](int M) { return M >= (int )NumSrcElts; })))
1584+ return false ;
1585+
1586+ // Merge shuffles - replace index to the RHS poison arg with PoisonMaskElem,
1587+ SmallVector<int , 16 > NewMask (OuterMask.begin (), OuterMask.end ());
1588+ for (int &M : NewMask) {
1589+ if (0 <= M && M < (int )NumImmElts) {
1590+ M = (InnerMask0[M] >= (int )NumSrcElts) ? PoisonMaskElem : InnerMask0[M];
1591+ } else if (M >= (int )NumImmElts) {
1592+ if (InnerMask1[M - NumImmElts] >= (int )NumSrcElts)
1593+ M = PoisonMaskElem;
1594+ else
1595+ M = InnerMask1[M - NumImmElts] + (V0 == V1 ? 0 : NumSrcElts);
1596+ }
1597+ }
1598+
1599+ // Have we folded to an Identity shuffle?
1600+ if (ShuffleVectorInst::isIdentityMask (NewMask, NumSrcElts)) {
1601+ replaceValue (I, *V0);
1602+ return true ;
1603+ }
1604+
1605+ // Try to merge the shuffles if the new shuffle is not costly.
1606+ TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
1607+
1608+ InstructionCost OldCost =
1609+ TTI.getShuffleCost (TargetTransformInfo::SK_PermuteSingleSrc, ShuffleSrcTy,
1610+ InnerMask0, CostKind) +
1611+ TTI.getShuffleCost (TargetTransformInfo::SK_PermuteSingleSrc, ShuffleSrcTy,
1612+ InnerMask1, CostKind) +
1613+ TTI.getShuffleCost (TargetTransformInfo::SK_PermuteTwoSrc, ShuffleImmTy,
1614+ OuterMask, CostKind, 0 , nullptr , std::nullopt , &I);
1615+
1616+ InstructionCost NewCost = TTI.getShuffleCost (
1617+ TargetTransformInfo::SK_PermuteTwoSrc, ShuffleSrcTy, NewMask, CostKind);
1618+
1619+ LLVM_DEBUG (dbgs () << " Found a shuffle feeding two shuffles: " << I
1620+ << " \n OldCost: " << OldCost << " vs NewCost: " << NewCost
1621+ << " \n " );
1622+ if (NewCost > OldCost)
1623+ return false ;
1624+
1625+ // Clear unused sources to poison.
1626+ if (none_of (NewMask, [&](int M) { return 0 <= M && M < (int )NumSrcElts; }))
1627+ V0 = PoisonValue::get (ShuffleSrcTy);
1628+ if (none_of (NewMask, [&](int M) { return (int )NumSrcElts <= M; }))
1629+ V1 = PoisonValue::get (ShuffleSrcTy);
1630+
1631+ Value *Shuf = Builder.CreateShuffleVector (V0, V1, NewMask);
1632+ replaceValue (I, *Shuf);
1633+ return true ;
1634+ }
1635+
15551636// / Given a commutative reduction, the order of the input lanes does not alter
15561637// / the results. We can use this to remove certain shuffles feeding the
15571638// / reduction, removing the need to shuffle at all.
@@ -2107,6 +2188,7 @@ bool VectorCombine::run() {
21072188 case Instruction::ShuffleVector:
21082189 MadeChange |= foldShuffleOfBinops (I);
21092190 MadeChange |= foldShuffleOfCastops (I);
2191+ MadeChange |= foldShuffleOfShuffles (I);
21102192 MadeChange |= foldSelectShuffle (I);
21112193 break ;
21122194 case Instruction::BitCast:
0 commit comments