@@ -1592,17 +1592,21 @@ bool VectorCombine::foldPermuteOfBinops(Instruction &I) {
15921592 if (BinOp->isIntDivRem () && llvm::is_contained (OuterMask, PoisonMaskElem))
15931593 return false ;
15941594
1595- Value *Op00, *Op01;
1596- ArrayRef<int > Mask0;
1597- if (!match (BinOp->getOperand (0 ),
1598- m_OneUse (m_Shuffle (m_Value (Op00), m_Value (Op01), m_Mask (Mask0)))))
1595+ Value *Op00, *Op01, *Op10, *Op11;
1596+ ArrayRef<int > Mask0, Mask1;
1597+ bool Match0 =
1598+ match (BinOp->getOperand (0 ),
1599+ m_OneUse (m_Shuffle (m_Value (Op00), m_Value (Op01), m_Mask (Mask0))));
1600+ bool Match1 =
1601+ match (BinOp->getOperand (1 ),
1602+ m_OneUse (m_Shuffle (m_Value (Op10), m_Value (Op11), m_Mask (Mask1))));
1603+ if (!Match0 && !Match1)
15991604 return false ;
16001605
1601- Value *Op10, *Op11;
1602- ArrayRef<int > Mask1;
1603- if (!match (BinOp->getOperand (1 ),
1604- m_OneUse (m_Shuffle (m_Value (Op10), m_Value (Op11), m_Mask (Mask1)))))
1605- return false ;
1606+ Op00 = Match0 ? Op00 : BinOp->getOperand (0 );
1607+ Op01 = Match0 ? Op01 : BinOp->getOperand (0 );
1608+ Op10 = Match1 ? Op10 : BinOp->getOperand (1 );
1609+ Op11 = Match1 ? Op11 : BinOp->getOperand (1 );
16061610
16071611 Instruction::BinaryOps Opcode = BinOp->getOpcode ();
16081612 auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType ());
@@ -1620,37 +1624,46 @@ bool VectorCombine::foldPermuteOfBinops(Instruction &I) {
16201624 any_of (OuterMask, [NumSrcElts](int M) { return M >= (int )NumSrcElts; }))
16211625 return false ;
16221626
1623- // Merge outer / inner shuffles.
1627+ // Merge outer / inner (or identity if no match) shuffles.
16241628 SmallVector<int > NewMask0, NewMask1;
16251629 for (int M : OuterMask) {
16261630 if (M < 0 || M >= (int )NumSrcElts) {
16271631 NewMask0.push_back (PoisonMaskElem);
16281632 NewMask1.push_back (PoisonMaskElem);
16291633 } else {
1630- NewMask0.push_back (Mask0[M]);
1631- NewMask1.push_back (Mask1[M]);
1634+ NewMask0.push_back (Match0 ? Mask0[M] : M );
1635+ NewMask1.push_back (Match1 ? Mask1[M] : M );
16321636 }
16331637 }
16341638
1639+ unsigned NumOpElts = Op0Ty->getNumElements ();
1640+ bool IsIdentity0 = ShuffleVectorInst::isIdentityMask (NewMask0, NumOpElts);
1641+ bool IsIdentity1 = ShuffleVectorInst::isIdentityMask (NewMask1, NumOpElts);
1642+
16351643 // Try to merge shuffles across the binop if the new shuffles are not costly.
16361644 InstructionCost OldCost =
16371645 TTI.getArithmeticInstrCost (Opcode, BinOpTy, CostKind) +
16381646 TTI.getShuffleCost (TargetTransformInfo::SK_PermuteSingleSrc, BinOpTy,
1639- OuterMask, CostKind, 0 , nullptr , {BinOp}, &I) +
1640- TTI.getShuffleCost (TargetTransformInfo::SK_PermuteTwoSrc, Op0Ty, Mask0,
1641- CostKind, 0 , nullptr , {Op00, Op01},
1642- cast<Instruction>(BinOp->getOperand (0 ))) +
1643- TTI.getShuffleCost (TargetTransformInfo::SK_PermuteTwoSrc, Op1Ty, Mask1,
1644- CostKind, 0 , nullptr , {Op10, Op11},
1645- cast<Instruction>(BinOp->getOperand (1 )));
1647+ OuterMask, CostKind, 0 , nullptr , {BinOp}, &I);
1648+ if (Match0)
1649+ OldCost += TTI.getShuffleCost (TargetTransformInfo::SK_PermuteTwoSrc, Op0Ty,
1650+ Mask0, CostKind, 0 , nullptr , {Op00, Op01},
1651+ cast<Instruction>(BinOp->getOperand (0 )));
1652+ if (Match1)
1653+ OldCost += TTI.getShuffleCost (TargetTransformInfo::SK_PermuteTwoSrc, Op1Ty,
1654+ Mask1, CostKind, 0 , nullptr , {Op10, Op11},
1655+ cast<Instruction>(BinOp->getOperand (1 )));
16461656
16471657 InstructionCost NewCost =
1648- TTI.getShuffleCost (TargetTransformInfo::SK_PermuteTwoSrc, Op0Ty, NewMask0,
1649- CostKind, 0 , nullptr , {Op00, Op01}) +
1650- TTI.getShuffleCost (TargetTransformInfo::SK_PermuteTwoSrc, Op1Ty, NewMask1,
1651- CostKind, 0 , nullptr , {Op10, Op11}) +
16521658 TTI.getArithmeticInstrCost (Opcode, ShuffleDstTy, CostKind);
16531659
1660+ if (!IsIdentity0)
1661+ NewCost += TTI.getShuffleCost (TargetTransformInfo::SK_PermuteTwoSrc, Op0Ty,
1662+ NewMask0, CostKind, 0 , nullptr , {Op00, Op01});
1663+ if (!IsIdentity1)
1664+ NewCost += TTI.getShuffleCost (TargetTransformInfo::SK_PermuteTwoSrc, Op1Ty,
1665+ NewMask1, CostKind, 0 , nullptr , {Op10, Op11});
1666+
16541667 LLVM_DEBUG (dbgs () << " Found a shuffle feeding a shuffled binop: " << I
16551668 << " \n OldCost: " << OldCost << " vs NewCost: " << NewCost
16561669 << " \n " );
@@ -1659,16 +1672,18 @@ bool VectorCombine::foldPermuteOfBinops(Instruction &I) {
16591672 if (NewCost > OldCost)
16601673 return false ;
16611674
1662- Value *Shuf0 = Builder.CreateShuffleVector (Op00, Op01, NewMask0);
1663- Value *Shuf1 = Builder.CreateShuffleVector (Op10, Op11, NewMask1);
1664- Value *NewBO = Builder.CreateBinOp (Opcode, Shuf0, Shuf1);
1675+ Value *LHS =
1676+ IsIdentity0 ? Op00 : Builder.CreateShuffleVector (Op00, Op01, NewMask0);
1677+ Value *RHS =
1678+ IsIdentity1 ? Op10 : Builder.CreateShuffleVector (Op10, Op11, NewMask1);
1679+ Value *NewBO = Builder.CreateBinOp (Opcode, LHS, RHS);
16651680
16661681 // Intersect flags from the old binops.
16671682 if (auto *NewInst = dyn_cast<Instruction>(NewBO))
16681683 NewInst->copyIRFlags (BinOp);
16691684
1670- Worklist.pushValue (Shuf0 );
1671- Worklist.pushValue (Shuf1 );
1685+ Worklist.pushValue (LHS );
1686+ Worklist.pushValue (RHS );
16721687 replaceValue (I, *NewBO);
16731688 return true ;
16741689}
0 commit comments