@@ -112,6 +112,7 @@ class VectorCombine {
112112 bool foldExtractedCmps (Instruction &I);
113113 bool foldSingleElementStore (Instruction &I);
114114 bool scalarizeLoadExtract (Instruction &I);
115+ bool foldPermuteOfBinops (Instruction &I);
115116 bool foldShuffleOfBinops (Instruction &I);
116117 bool foldShuffleOfCastops (Instruction &I);
117118 bool foldShuffleOfShuffles (Instruction &I);
@@ -1400,6 +1401,93 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
14001401 return true ;
14011402}
14021403
1404+ // / Try to convert "shuffle (binop (shuffle, shuffle)), undef"
1405+ // / --> "binop (shuffle), (shuffle)".
1406+ bool VectorCombine::foldPermuteOfBinops (Instruction &I) {
1407+ BinaryOperator *BinOp;
1408+ ArrayRef<int > OuterMask;
1409+ if (!match (&I,
1410+ m_Shuffle (m_OneUse (m_BinOp (BinOp)), m_Undef (), m_Mask (OuterMask))))
1411+ return false ;
1412+
1413+ // Don't introduce poison into div/rem.
1414+ if (llvm::is_contained (OuterMask, PoisonMaskElem) && BinOp->isIntDivRem ())
1415+ return false ;
1416+
1417+ Value *Op00, *Op01;
1418+ ArrayRef<int > Mask0;
1419+ if (!match (BinOp->getOperand (0 ),
1420+ m_OneUse (m_Shuffle (m_Value (Op00), m_Value (Op01), m_Mask (Mask0)))))
1421+ return false ;
1422+
1423+ Value *Op10, *Op11;
1424+ ArrayRef<int > Mask1;
1425+ if (!match (BinOp->getOperand (1 ),
1426+ m_OneUse (m_Shuffle (m_Value (Op10), m_Value (Op11), m_Mask (Mask1)))))
1427+ return false ;
1428+
1429+ Instruction::BinaryOps Opcode = BinOp->getOpcode ();
1430+ auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType ());
1431+ auto *BinOpTy = dyn_cast<FixedVectorType>(BinOp->getType ());
1432+ auto *Op0Ty = dyn_cast<FixedVectorType>(Op00->getType ());
1433+ auto *Op1Ty = dyn_cast<FixedVectorType>(Op10->getType ());
1434+ if (!ShuffleDstTy || !BinOpTy || !Op0Ty || !Op1Ty)
1435+ return false ;
1436+
1437+ unsigned NumSrcElts = BinOpTy->getNumElements ();
1438+
1439+ // Don't accept shuffles that reference the second (undef/poison) operand.
1440+ if (any_of (OuterMask, [NumSrcElts](int M) { return M >= (int )NumSrcElts; }))
1441+ return false ;
1442+
1443+ // Merge outer / inner shuffles.
1444+ SmallVector<int > NewMask0, NewMask1;
1445+ for (int M : OuterMask) {
1446+ NewMask0.push_back (M >= 0 ? Mask0[M] : -1 );
1447+ NewMask1.push_back (M >= 0 ? Mask1[M] : -1 );
1448+ }
1449+
1450+ // Try to merge shuffles across the binop if the new shuffles are not costly.
1451+ TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
1452+
1453+ InstructionCost OldCost =
1454+ TTI.getArithmeticInstrCost (Opcode, BinOpTy, CostKind) +
1455+ TTI.getShuffleCost (TargetTransformInfo::SK_PermuteSingleSrc, BinOpTy,
1456+ OuterMask, CostKind, 0 , nullptr , {BinOp}, &I) +
1457+ TTI.getShuffleCost (TargetTransformInfo::SK_PermuteTwoSrc, Op0Ty, Mask0,
1458+ CostKind, 0 , nullptr , {Op00, Op01},
1459+ cast<Instruction>(BinOp->getOperand (0 ))) +
1460+ TTI.getShuffleCost (TargetTransformInfo::SK_PermuteTwoSrc, Op1Ty, Mask1,
1461+ CostKind, 0 , nullptr , {Op10, Op11},
1462+ cast<Instruction>(BinOp->getOperand (1 )));
1463+
1464+ InstructionCost NewCost =
1465+ TTI.getShuffleCost (TargetTransformInfo::SK_PermuteTwoSrc, Op0Ty, NewMask0,
1466+ CostKind, 0 , nullptr , {Op00, Op01}) +
1467+ TTI.getShuffleCost (TargetTransformInfo::SK_PermuteTwoSrc, Op1Ty, NewMask1,
1468+ CostKind, 0 , nullptr , {Op10, Op11}) +
1469+ TTI.getArithmeticInstrCost (Opcode, ShuffleDstTy, CostKind);
1470+
1471+ LLVM_DEBUG (dbgs () << " Found a shuffle feeding a shuffled binop: " << I
1472+ << " \n OldCost: " << OldCost << " vs NewCost: " << NewCost
1473+ << " \n " );
1474+ if (NewCost >= OldCost)
1475+ return false ;
1476+
1477+ Value *Shuf0 = Builder.CreateShuffleVector (Op00, Op01, NewMask0);
1478+ Value *Shuf1 = Builder.CreateShuffleVector (Op10, Op11, NewMask1);
1479+ Value *NewBO = Builder.CreateBinOp (Opcode, Shuf0, Shuf1);
1480+
1481+ // Intersect flags from the old binops.
1482+ if (auto *NewInst = dyn_cast<Instruction>(NewBO))
1483+ NewInst->copyIRFlags (BinOp);
1484+
1485+ Worklist.pushValue (Shuf0);
1486+ Worklist.pushValue (Shuf1);
1487+ replaceValue (I, *NewBO);
1488+ return true ;
1489+ }
1490+
14031491// / Try to convert "shuffle (binop), (binop)" into "binop (shuffle), (shuffle)".
14041492bool VectorCombine::foldShuffleOfBinops (Instruction &I) {
14051493 BinaryOperator *B0, *B1;
@@ -2736,6 +2824,7 @@ bool VectorCombine::run() {
27362824 MadeChange |= foldInsExtFNeg (I);
27372825 break ;
27382826 case Instruction::ShuffleVector:
2827+ MadeChange |= foldPermuteOfBinops (I);
27392828 MadeChange |= foldShuffleOfBinops (I);
27402829 MadeChange |= foldShuffleOfCastops (I);
27412830 MadeChange |= foldShuffleOfShuffles (I);
0 commit comments