@@ -115,6 +115,7 @@ class VectorCombine {
   bool foldExtractedCmps(Instruction &I);
   bool foldSingleElementStore(Instruction &I);
   bool scalarizeLoadExtract(Instruction &I);
+  bool foldConcatOfBoolMasks(Instruction &I);
   bool foldPermuteOfBinops(Instruction &I);
   bool foldShuffleOfBinops(Instruction &I);
   bool foldShuffleOfCastops(Instruction &I);
@@ -1423,6 +1424,113 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
   return true;
 }
 
+/// Try to fold "(or (zext (bitcast X)), (shl (zext (bitcast Y)), C))"
+/// to "(bitcast (concat X, Y))"
+/// where X/Y are bitcasted from i1 mask vectors.
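+///
+/// For example (a sketch; %x/%y are illustrative <4 x i1> sources on a
+/// little-endian target), the scalar chain
+///   %bx = bitcast <4 x i1> %x to i4
+///   %by = bitcast <4 x i1> %y to i4
+///   %zx = zext i4 %bx to i8
+///   %zy = zext i4 %by to i8
+///   %sh = shl i8 %zy, 4
+///   %or = or disjoint i8 %zx, %sh
+/// becomes
+///   %cc = shufflevector <4 x i1> %x, <4 x i1> %y,
+///                       <8 x i32> <i32 0, i32 1, ..., i32 7>
+///   %or = bitcast <8 x i1> %cc to i8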
+bool VectorCombine::foldConcatOfBoolMasks(Instruction &I) {
+  Type *Ty = I.getType();
+  if (!Ty->isIntegerTy())
+    return false;
+
+  // TODO: Add big-endian test coverage
+  if (DL->isBigEndian())
+    return false;
+
+  // Restrict to disjoint cases so the mask vectors don't overlap.
+  Instruction *X, *Y;
+  if (!match(&I, m_DisjointOr(m_Instruction(X), m_Instruction(Y))))
+    return false;
+
+  // Allow both sources to contain shl, to handle the more generic pattern:
+  // "(or (shl (zext (bitcast X)), C1), (shl (zext (bitcast Y)), C2))"
+  Value *SrcX;
+  uint64_t ShAmtX = 0;
+  if (!match(X, m_OneUse(m_ZExt(m_OneUse(m_BitCast(m_Value(SrcX)))))) &&
+      !match(X, m_OneUse(
+                    m_Shl(m_OneUse(m_ZExt(m_OneUse(m_BitCast(m_Value(SrcX))))),
+                          m_ConstantInt(ShAmtX)))))
+    return false;
+
+  Value *SrcY;
+  uint64_t ShAmtY = 0;
+  if (!match(Y, m_OneUse(m_ZExt(m_OneUse(m_BitCast(m_Value(SrcY)))))) &&
+      !match(Y, m_OneUse(
+                    m_Shl(m_OneUse(m_ZExt(m_OneUse(m_BitCast(m_Value(SrcY))))),
+                          m_ConstantInt(ShAmtY)))))
+    return false;
+
+  // Canonicalize larger shift to the RHS.
+  if (ShAmtX > ShAmtY) {
+    std::swap(X, Y);
+    std::swap(SrcX, SrcY);
+    std::swap(ShAmtX, ShAmtY);
+  }
+
+  // Ensure both sources have matching vXi1 bool mask types, and that the
+  // shift difference equals the mask width so they can simply be
+  // concatenated.
+  uint64_t ShAmtDiff = ShAmtY - ShAmtX;
+  unsigned NumSHL = (ShAmtX > 0) + (ShAmtY > 0);
+  unsigned BitWidth = Ty->getPrimitiveSizeInBits();
+  auto *MaskTy = dyn_cast<FixedVectorType>(SrcX->getType());
+  if (!MaskTy || SrcX->getType() != SrcY->getType() ||
+      !MaskTy->getElementType()->isIntegerTy(1) ||
+      MaskTy->getNumElements() != ShAmtDiff ||
+      MaskTy->getNumElements() > (BitWidth / 2))
+    return false;
+
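+  // The concatenated mask type, and scalar integer types matching the
+  // concatenated and single-source mask widths (the latter only used for
+  // cost modelling below).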
+  auto *ConcatTy = FixedVectorType::getDoubleElementsVectorType(MaskTy);
+  auto *ConcatIntTy =
+      Type::getIntNTy(Ty->getContext(), ConcatTy->getNumElements());
+  auto *MaskIntTy = Type::getIntNTy(Ty->getContext(), ShAmtDiff);
+
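+  // Identity mask selecting every lane of SrcX followed by every lane of
+  // SrcY.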
+  SmallVector<int, 32> ConcatMask(ConcatTy->getNumElements());
+  std::iota(ConcatMask.begin(), ConcatMask.end(), 0);
+
+  // TODO: Is it worth supporting multi-use cases?
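+  // Cost of the existing scalar sequence: the or, any shifts, and the
+  // zext/bitcast pair for each source.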
+  InstructionCost OldCost = 0;
+  OldCost += TTI.getArithmeticInstrCost(Instruction::Or, Ty, CostKind);
+  OldCost +=
+      NumSHL * TTI.getArithmeticInstrCost(Instruction::Shl, Ty, CostKind);
+  OldCost += 2 * TTI.getCastInstrCost(Instruction::ZExt, Ty, MaskIntTy,
+                                      TTI::CastContextHint::None, CostKind);
+  OldCost += 2 * TTI.getCastInstrCost(Instruction::BitCast, MaskIntTy, MaskTy,
+                                      TTI::CastContextHint::None, CostKind);
+
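+  // Cost of the replacement: a two-source shuffle and a bitcast, plus any
+  // residual zext/shl.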
+  InstructionCost NewCost = 0;
+  NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, MaskTy,
+                                ConcatMask, CostKind);
+  NewCost += TTI.getCastInstrCost(Instruction::BitCast, ConcatIntTy, ConcatTy,
+                                  TTI::CastContextHint::None, CostKind);
+  if (Ty != ConcatIntTy)
+    NewCost += TTI.getCastInstrCost(Instruction::ZExt, Ty, ConcatIntTy,
+                                    TTI::CastContextHint::None, CostKind);
+  if (ShAmtX > 0)
+    NewCost += TTI.getArithmeticInstrCost(Instruction::Shl, Ty, CostKind);
+
+  if (NewCost > OldCost)
+    return false;
+
+  // Build the bool mask concatenation, bitcast back to a scalar integer, and
+  // apply any residual zero-extension or shifting.
+  Value *Concat = Builder.CreateShuffleVector(SrcX, SrcY, ConcatMask);
+  Worklist.pushValue(Concat);
+
+  Value *Result = Builder.CreateBitCast(Concat, ConcatIntTy);
+
+  if (Ty != ConcatIntTy) {
+    Worklist.pushValue(Result);
+    Result = Builder.CreateZExt(Result, Ty);
+  }
+
+  if (ShAmtX > 0) {
+    Worklist.pushValue(Result);
+    Result = Builder.CreateShl(Result, ShAmtX);
+  }
+
+  replaceValue(I, *Result);
+  return true;
+}
+
 /// Try to convert "shuffle (binop (shuffle, shuffle)), undef"
 ///           --> "binop (shuffle), (shuffle)".
 bool VectorCombine::foldPermuteOfBinops(Instruction &I) {
@@ -2908,6 +3016,9 @@ bool VectorCombine::run() {
     if (TryEarlyFoldsOnly)
       return;
 
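+    // Concatenations of bool masks are rooted at a scalar integer 'or'.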
+    if (I.getType()->isIntegerTy())
+      MadeChange |= foldConcatOfBoolMasks(I);
+
     // Otherwise, try folds that improve codegen but may interfere with
     // early IR canonicalizations.
     // The type checking is for run-time efficiency. We can avoid wasting time