@@ -112,6 +112,7 @@ class VectorCombine {
112112 bool foldExtractedCmps (Instruction &I);
113113 bool foldSingleElementStore (Instruction &I);
114114 bool scalarizeLoadExtract (Instruction &I);
115+ bool foldPermuteOfBinops (Instruction &I);
115116 bool foldShuffleOfBinops (Instruction &I);
116117 bool foldShuffleOfCastops (Instruction &I);
117118 bool foldShuffleOfShuffles (Instruction &I);
@@ -1400,6 +1401,100 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
14001401 return true ;
14011402}
14021403
1404+ // / Try to convert "shuffle (binop (shuffle, shuffle)), undef"
1405+ // / --> "binop (shuffle), (shuffle)".
1406+ bool VectorCombine::foldPermuteOfBinops (Instruction &I) {
1407+ BinaryOperator *BinOp;
1408+ ArrayRef<int > OuterMask;
1409+ if (!match (&I,
1410+ m_Shuffle (m_OneUse (m_BinOp (BinOp)), m_Undef (), m_Mask (OuterMask))))
1411+ return false ;
1412+
1413+ // Don't introduce poison into div/rem.
1414+ if (BinOp->isIntDivRem () && llvm::is_contained (OuterMask, PoisonMaskElem))
1415+ return false ;
1416+
1417+ Value *Op00, *Op01;
1418+ ArrayRef<int > Mask0;
1419+ if (!match (BinOp->getOperand (0 ),
1420+ m_OneUse (m_Shuffle (m_Value (Op00), m_Value (Op01), m_Mask (Mask0)))))
1421+ return false ;
1422+
1423+ Value *Op10, *Op11;
1424+ ArrayRef<int > Mask1;
1425+ if (!match (BinOp->getOperand (1 ),
1426+ m_OneUse (m_Shuffle (m_Value (Op10), m_Value (Op11), m_Mask (Mask1)))))
1427+ return false ;
1428+
1429+ Instruction::BinaryOps Opcode = BinOp->getOpcode ();
1430+ auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType ());
1431+ auto *BinOpTy = dyn_cast<FixedVectorType>(BinOp->getType ());
1432+ auto *Op0Ty = dyn_cast<FixedVectorType>(Op00->getType ());
1433+ auto *Op1Ty = dyn_cast<FixedVectorType>(Op10->getType ());
1434+ if (!ShuffleDstTy || !BinOpTy || !Op0Ty || !Op1Ty)
1435+ return false ;
1436+
1437+ unsigned NumSrcElts = BinOpTy->getNumElements ();
1438+
1439+ // Don't accept shuffles that reference the second (undef/poison) operand in
1440+ // div/rem..
1441+ if (BinOp->isIntDivRem () &&
1442+ any_of (OuterMask, [NumSrcElts](int M) { return M >= (int )NumSrcElts; }))
1443+ return false ;
1444+
1445+ // Merge outer / inner shuffles.
1446+ SmallVector<int > NewMask0, NewMask1;
1447+ for (int M : OuterMask) {
1448+ if (M < 0 || M >= (int )NumSrcElts) {
1449+ NewMask0.push_back (PoisonMaskElem);
1450+ NewMask1.push_back (PoisonMaskElem);
1451+ } else {
1452+ NewMask0.push_back (Mask0[M]);
1453+ NewMask1.push_back (Mask1[M]);
1454+ }
1455+ }
1456+
1457+ // Try to merge shuffles across the binop if the new shuffles are not costly.
1458+ TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
1459+
1460+ InstructionCost OldCost =
1461+ TTI.getArithmeticInstrCost (Opcode, BinOpTy, CostKind) +
1462+ TTI.getShuffleCost (TargetTransformInfo::SK_PermuteSingleSrc, BinOpTy,
1463+ OuterMask, CostKind, 0 , nullptr , {BinOp}, &I) +
1464+ TTI.getShuffleCost (TargetTransformInfo::SK_PermuteTwoSrc, Op0Ty, Mask0,
1465+ CostKind, 0 , nullptr , {Op00, Op01},
1466+ cast<Instruction>(BinOp->getOperand (0 ))) +
1467+ TTI.getShuffleCost (TargetTransformInfo::SK_PermuteTwoSrc, Op1Ty, Mask1,
1468+ CostKind, 0 , nullptr , {Op10, Op11},
1469+ cast<Instruction>(BinOp->getOperand (1 )));
1470+
1471+ InstructionCost NewCost =
1472+ TTI.getShuffleCost (TargetTransformInfo::SK_PermuteTwoSrc, Op0Ty, NewMask0,
1473+ CostKind, 0 , nullptr , {Op00, Op01}) +
1474+ TTI.getShuffleCost (TargetTransformInfo::SK_PermuteTwoSrc, Op1Ty, NewMask1,
1475+ CostKind, 0 , nullptr , {Op10, Op11}) +
1476+ TTI.getArithmeticInstrCost (Opcode, ShuffleDstTy, CostKind);
1477+
1478+ LLVM_DEBUG (dbgs () << " Found a shuffle feeding a shuffled binop: " << I
1479+ << " \n OldCost: " << OldCost << " vs NewCost: " << NewCost
1480+ << " \n " );
1481+ if (NewCost >= OldCost)
1482+ return false ;
1483+
1484+ Value *Shuf0 = Builder.CreateShuffleVector (Op00, Op01, NewMask0);
1485+ Value *Shuf1 = Builder.CreateShuffleVector (Op10, Op11, NewMask1);
1486+ Value *NewBO = Builder.CreateBinOp (Opcode, Shuf0, Shuf1);
1487+
1488+ // Intersect flags from the old binops.
1489+ if (auto *NewInst = dyn_cast<Instruction>(NewBO))
1490+ NewInst->copyIRFlags (BinOp);
1491+
1492+ Worklist.pushValue (Shuf0);
1493+ Worklist.pushValue (Shuf1);
1494+ replaceValue (I, *NewBO);
1495+ return true ;
1496+ }
1497+
14031498// / Try to convert "shuffle (binop), (binop)" into "binop (shuffle), (shuffle)".
14041499bool VectorCombine::foldShuffleOfBinops (Instruction &I) {
14051500 BinaryOperator *B0, *B1;
@@ -2736,6 +2831,7 @@ bool VectorCombine::run() {
27362831 MadeChange |= foldInsExtFNeg (I);
27372832 break ;
27382833 case Instruction::ShuffleVector:
2834+ MadeChange |= foldPermuteOfBinops (I);
27392835 MadeChange |= foldShuffleOfBinops (I);
27402836 MadeChange |= foldShuffleOfCastops (I);
27412837 MadeChange |= foldShuffleOfShuffles (I);
0 commit comments