@@ -114,6 +114,7 @@ class VectorCombine {
114114 bool foldShuffleOfBinops (Instruction &I);
115115 bool foldShuffleOfCastops (Instruction &I);
116116 bool foldShuffleOfShuffles (Instruction &I);
117+ bool foldShuffleToIdentity (Instruction &I);
117118 bool foldShuffleFromReductions (Instruction &I);
118119 bool foldTruncFromReductions (Instruction &I);
119120 bool foldSelectShuffle (Instruction &I, bool FromReduction = false );
@@ -1667,6 +1668,148 @@ bool VectorCombine::foldShuffleOfShuffles(Instruction &I) {
16671668 return true ;
16681669}
16691670
1671+ // Starting from a shuffle, look up through operands tracking the shuffled index
1672+ // of each lane. If we can simplify away the shuffles to identities then
1673+ // do so.
1674+ bool VectorCombine::foldShuffleToIdentity (Instruction &I) {
1675+ FixedVectorType *Ty = dyn_cast<FixedVectorType>(I.getType ());
1676+ if (!Ty || !isa<Instruction>(I.getOperand (0 )) ||
1677+ !isa<Instruction>(I.getOperand (1 )))
1678+ return false ;
1679+
1680+ using InstLane = std::pair<Value *, int >;
1681+
1682+ auto LookThroughShuffles = [](Value *V, int Lane) -> InstLane {
1683+ while (auto *SV = dyn_cast<ShuffleVectorInst>(V)) {
1684+ unsigned NumElts =
1685+ cast<FixedVectorType>(SV->getOperand (0 )->getType ())->getNumElements ();
1686+ int M = SV->getMaskValue (Lane);
1687+ if (M < 0 )
1688+ return {nullptr , -1 };
1689+ else if (M < (int )NumElts) {
1690+ V = SV->getOperand (0 );
1691+ Lane = M;
1692+ } else {
1693+ V = SV->getOperand (1 );
1694+ Lane = M - NumElts;
1695+ }
1696+ }
1697+ return InstLane{V, Lane};
1698+ };
1699+
1700+ auto GenerateInstLaneVectorFromOperand =
1701+ [&LookThroughShuffles](const SmallVector<InstLane> &Item, int Op) {
1702+ SmallVector<InstLane> NItem;
1703+ for (InstLane V : Item) {
1704+ NItem.emplace_back (
1705+ !V.first
1706+ ? InstLane{nullptr , -1 }
1707+ : LookThroughShuffles (
1708+ cast<Instruction>(V.first )->getOperand (Op), V.second ));
1709+ }
1710+ return NItem;
1711+ };
1712+
1713+ SmallVector<InstLane> Start;
1714+ for (unsigned M = 0 ; M < Ty->getNumElements (); ++M)
1715+ Start.push_back (LookThroughShuffles (&I, M));
1716+
1717+ SmallVector<SmallVector<InstLane>> Worklist;
1718+ Worklist.push_back (Start);
1719+ SmallPtrSet<Value *, 4 > IdentityLeafs, SplatLeafs;
1720+ unsigned NumVisited = 0 ;
1721+
1722+ while (!Worklist.empty ()) {
1723+ SmallVector<InstLane> Item = Worklist.pop_back_val ();
1724+ if (++NumVisited > MaxInstrsToScan)
1725+ return false ;
1726+
1727+ // If we found an undef first lane then bail out to keep things simple.
1728+ if (!Item[0 ].first )
1729+ return false ;
1730+
1731+ // Look for an identity value.
1732+ if (Item[0 ].second == 0 && Item[0 ].first ->getType () == Ty &&
1733+ all_of (drop_begin (enumerate(Item)), [&](const auto &E) {
1734+ return !E.value ().first || (E.value ().first == Item[0 ].first &&
1735+ E.value ().second == (int )E.index ());
1736+ })) {
1737+ IdentityLeafs.insert (Item[0 ].first );
1738+ continue ;
1739+ }
1740+ // Look for a splat value.
1741+ if (all_of (drop_begin (Item), [&](InstLane &IL) {
1742+ return !IL.first ||
1743+ (IL.first == Item[0 ].first && IL.second == Item[0 ].second );
1744+ })) {
1745+ SplatLeafs.insert (Item[0 ].first );
1746+ continue ;
1747+ }
1748+
1749+ // We need each element to be the same type of value, and check that each
1750+ // element has a single use.
1751+ if (!all_of (drop_begin (Item), [&](InstLane IL) {
1752+ if (!IL.first )
1753+ return true ;
1754+ if (isa<Instruction>(IL.first ) &&
1755+ !cast<Instruction>(IL.first )->hasOneUse ())
1756+ return false ;
1757+ return IL.first ->getValueID () == Item[0 ].first ->getValueID () &&
1758+ (!isa<IntrinsicInst>(IL.first ) ||
1759+ cast<IntrinsicInst>(IL.first )->getIntrinsicID () ==
1760+ cast<IntrinsicInst>(Item[0 ].first )->getIntrinsicID ());
1761+ }))
1762+ return false ;
1763+
1764+ // Check the operator is one that we support.
1765+ if (isa<BinaryOperator>(Item[0 ].first )) {
1766+ Worklist.push_back (GenerateInstLaneVectorFromOperand (Item, 0 ));
1767+ Worklist.push_back (GenerateInstLaneVectorFromOperand (Item, 1 ));
1768+ } else if (isa<UnaryOperator>(Item[0 ].first )) {
1769+ Worklist.push_back (GenerateInstLaneVectorFromOperand (Item, 0 ));
1770+ } else {
1771+ return false ;
1772+ }
1773+ }
1774+
1775+ // If we got this far, we know the shuffles are superfluous and can be
1776+ // removed. Scan through again and generate the new tree of instructions.
1777+ std::function<Value *(const SmallVector<InstLane> &)> generate =
1778+ [&](const SmallVector<InstLane> &Item) -> Value * {
1779+ if (IdentityLeafs.contains (Item[0 ].first ) &&
1780+ all_of (drop_begin (enumerate(Item)), [&](const auto &E) {
1781+ return !E.value ().first || (E.value ().first == Item[0 ].first &&
1782+ E.value ().second == (int )E.index ());
1783+ })) {
1784+ return Item[0 ].first ;
1785+ } else if (SplatLeafs.contains (Item[0 ].first )) {
1786+ if (auto ILI = dyn_cast<Instruction>(Item[0 ].first ))
1787+ Builder.SetInsertPoint (*ILI->getInsertionPointAfterDef ());
1788+ else if (isa<Argument>(Item[0 ].first ))
1789+ Builder.SetInsertPointPastAllocas (I.getParent ()->getParent ());
1790+ SmallVector<int , 16 > Mask (Ty->getNumElements (), Item[0 ].second );
1791+ return Builder.CreateShuffleVector (Item[0 ].first , Mask);
1792+ }
1793+
1794+ auto *I = cast<Instruction>(Item[0 ].first );
1795+ SmallVector<Value *> Ops;
1796+ unsigned E = I->getNumOperands ();
1797+ for (unsigned Idx = 0 ; Idx < E; Idx++)
1798+ Ops.push_back (generate (GenerateInstLaneVectorFromOperand (Item, Idx)));
1799+ Builder.SetInsertPoint (I);
1800+ if (auto BI = dyn_cast<BinaryOperator>(I))
1801+ return Builder.CreateBinOp ((Instruction::BinaryOps)BI->getOpcode (),
1802+ Ops[0 ], Ops[1 ]);
1803+ if (auto UI = dyn_cast<UnaryOperator>(I))
1804+ return Builder.CreateUnOp ((Instruction::UnaryOps)UI->getOpcode (), Ops[0 ]);
1805+ llvm_unreachable (" Unhandled instruction in generate" );
1806+ };
1807+
1808+ Value *V = generate (Start);
1809+ replaceValue (I, *V);
1810+ return true ;
1811+ }
1812+
16701813// / Given a commutative reduction, the order of the input lanes does not alter
16711814// / the results. We can use this to remove certain shuffles feeding the
16721815// / reduction, removing the need to shuffle at all.
@@ -2224,6 +2367,7 @@ bool VectorCombine::run() {
22242367 MadeChange |= foldShuffleOfCastops (I);
22252368 MadeChange |= foldShuffleOfShuffles (I);
22262369 MadeChange |= foldSelectShuffle (I);
2370+ MadeChange |= foldShuffleToIdentity (I);
22272371 break ;
22282372 case Instruction::BitCast:
22292373 MadeChange |= foldBitcastShuffle (I);
0 commit comments