@@ -112,6 +112,7 @@ class VectorCombine {
112112 bool foldExtractedCmps (Instruction &I);
113113 bool foldSingleElementStore (Instruction &I);
114114 bool scalarizeLoadExtract (Instruction &I);
115+ bool foldPermuteOfBinops (Instruction &I);
115116 bool foldShuffleOfBinops (Instruction &I);
116117 bool foldShuffleOfCastops (Instruction &I);
117118 bool foldShuffleOfShuffles (Instruction &I);
@@ -1400,6 +1401,92 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
14001401 return true ;
14011402}
14021403
1404+ // / Try to convert "shuffle (binop (shuffle, shuffle)), undef" into "binop (shuffle), (shuffle)".
1405+ bool VectorCombine::foldPermuteOfBinops (Instruction &I) {
1406+ BinaryOperator *BinOp;
1407+ ArrayRef<int > OuterMask;
1408+ if (!match (&I,
1409+ m_Shuffle (m_OneUse (m_BinOp (BinOp)), m_Undef (), m_Mask (OuterMask))))
1410+ return false ;
1411+
1412+ // Don't introduce poison into div/rem.
1413+ if (llvm::is_contained (OuterMask, PoisonMaskElem) && BinOp->isIntDivRem ())
1414+ return false ;
1415+
1416+ Value *Op00, *Op01;
1417+ ArrayRef<int > Mask0;
1418+ if (!match (BinOp->getOperand (0 ),
1419+ m_OneUse (m_Shuffle (m_Value (Op00), m_Value (Op01), m_Mask (Mask0)))))
1420+ return false ;
1421+
1422+ Value *Op10, *Op11;
1423+ ArrayRef<int > Mask1;
1424+ if (!match (BinOp->getOperand (1 ),
1425+ m_OneUse (m_Shuffle (m_Value (Op10), m_Value (Op11), m_Mask (Mask1)))))
1426+ return false ;
1427+
1428+ Instruction::BinaryOps Opcode = BinOp->getOpcode ();
1429+ auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType ());
1430+ auto *BinOpTy = dyn_cast<FixedVectorType>(BinOp->getType ());
1431+ auto *Op0Ty = dyn_cast<FixedVectorType>(Op00->getType ());
1432+ auto *Op1Ty = dyn_cast<FixedVectorType>(Op10->getType ());
1433+ if (!ShuffleDstTy || !BinOpTy || !Op0Ty || !Op1Ty)
1434+ return false ;
1435+
1436+ unsigned NumSrcElts = BinOpTy->getNumElements ();
1437+
1438+ // Don't accept shuffles that reference the second (undef/poison) operand.
1439+ if (any_of (OuterMask, [NumSrcElts](int M) { return M >= (int )NumSrcElts; }))
1440+ return false ;
1441+
1442+ // Merge outer / inner shuffles.
1443+ SmallVector<int > NewMask0, NewMask1;
1444+ for (int M : OuterMask) {
1445+ NewMask0.push_back (M >= 0 ? Mask0[M] : -1 );
1446+ NewMask1.push_back (M >= 0 ? Mask1[M] : -1 );
1447+ }
1448+
1449+ // Try to merge shuffles across the binop if the new shuffles are not costly.
1450+ TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
1451+
1452+ InstructionCost OldCost =
1453+ TTI.getArithmeticInstrCost (Opcode, BinOpTy, CostKind) +
1454+ TTI.getShuffleCost (TargetTransformInfo::SK_PermuteSingleSrc, BinOpTy,
1455+ OuterMask, CostKind, 0 , nullptr , {BinOp}, &I) +
1456+ TTI.getShuffleCost (TargetTransformInfo::SK_PermuteTwoSrc, Op0Ty, Mask0,
1457+ CostKind, 0 , nullptr , {Op00, Op01},
1458+ cast<Instruction>(BinOp->getOperand (0 ))) +
1459+ TTI.getShuffleCost (TargetTransformInfo::SK_PermuteTwoSrc, Op1Ty, Mask1,
1460+ CostKind, 0 , nullptr , {Op10, Op11},
1461+ cast<Instruction>(BinOp->getOperand (1 )));
1462+
1463+ InstructionCost NewCost =
1464+ TTI.getShuffleCost (TargetTransformInfo::SK_PermuteTwoSrc, Op0Ty,
1465+ NewMask0, CostKind, 0 , nullptr , {Op00, Op01}) +
1466+ TTI.getShuffleCost (TargetTransformInfo::SK_PermuteTwoSrc, Op1Ty,
1467+ NewMask1, CostKind, 0 , nullptr , {Op10, Op11}) +
1468+ TTI.getArithmeticInstrCost (Opcode, ShuffleDstTy, CostKind);
1469+
1470+ LLVM_DEBUG (dbgs () << " Found a shuffle feeding a shuffled binop: " << I
1471+ << " \n OldCost: " << OldCost << " vs NewCost: " << NewCost
1472+ << " \n " );
1473+ if (NewCost >= OldCost)
1474+ return false ;
1475+
1476+ Value *Shuf0 = Builder.CreateShuffleVector (Op00, Op01, NewMask0);
1477+ Value *Shuf1 = Builder.CreateShuffleVector (Op10, Op11, NewMask1);
1478+ Value *NewBO = Builder.CreateBinOp (Opcode, Shuf0, Shuf1);
1479+
1480+ // Intersect flags from the old binops.
1481+ if (auto *NewInst = dyn_cast<Instruction>(NewBO))
1482+ NewInst->copyIRFlags (BinOp);
1483+
1484+ Worklist.pushValue (Shuf0);
1485+ Worklist.pushValue (Shuf1);
1486+ replaceValue (I, *NewBO);
1487+ return true ;
1488+ }
1489+
14031490// / Try to convert "shuffle (binop), (binop)" into "binop (shuffle), (shuffle)".
14041491bool VectorCombine::foldShuffleOfBinops (Instruction &I) {
14051492 BinaryOperator *B0, *B1;
@@ -2736,6 +2823,7 @@ bool VectorCombine::run() {
27362823 MadeChange |= foldInsExtFNeg (I);
27372824 break ;
27382825 case Instruction::ShuffleVector:
2826+ MadeChange |= foldPermuteOfBinops (I);
27392827 MadeChange |= foldShuffleOfBinops (I);
27402828 MadeChange |= foldShuffleOfCastops (I);
27412829 MadeChange |= foldShuffleOfShuffles (I);
0 commit comments