@@ -3549,6 +3549,109 @@ static Value *foldOrOfInversions(BinaryOperator &I,
   return nullptr;
 }
 
+/// Match \p V as "lshr -> mask -> zext -> shl".
+///
+/// \p Int is the underlying integer being extracted from.
+/// \p Mask is a bitmask identifying which bits of the integer are being
+/// extracted. \p Offset identifies which bit of the result \p V corresponds to
+/// the least significant bit of \p Int.
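+///
+/// For example (illustrative IR, assuming all intermediate values are
+/// single-use):
+///   %s = lshr i32 %int, 8
+///   %m = and i32 %s, 255
+///   %z = zext i32 %m to i64
+///   %v = shl nuw i64 %z, 16
+/// matches \p V = %v with Int = %int, Mask = 0xff00, Offset = 8 and
+/// IsShlNUW = true.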
+static bool matchZExtedSubInteger(Value *V, Value *&Int, APInt &Mask,
+                                  uint64_t &Offset, bool &IsShlNUW,
+                                  bool &IsShlNSW) {
+  Value *ShlOp0;
+  uint64_t ShlAmt = 0;
+  if (!match(V, m_OneUse(m_Shl(m_Value(ShlOp0), m_ConstantInt(ShlAmt)))))
+    return false;
+
+  IsShlNUW = cast<BinaryOperator>(V)->hasNoUnsignedWrap();
+  IsShlNSW = cast<BinaryOperator>(V)->hasNoSignedWrap();
+
+  Value *ZExtOp0;
+  if (!match(ShlOp0, m_OneUse(m_ZExt(m_Value(ZExtOp0)))))
+    return false;
+
+  Value *MaskedOp0;
+  const APInt *ShiftedMaskConst = nullptr;
+  if (!match(ZExtOp0, m_CombineOr(m_OneUse(m_And(m_Value(MaskedOp0),
+                                                 m_APInt(ShiftedMaskConst))),
+                                  m_Value(MaskedOp0))))
+    return false;
+
+  uint64_t LShrAmt = 0;
+  if (!match(MaskedOp0,
+             m_CombineOr(m_OneUse(m_LShr(m_Value(Int), m_ConstantInt(LShrAmt))),
+                         m_Value(Int))))
+    return false;
+
+  if (LShrAmt > ShlAmt)
+    return false;
+  Offset = ShlAmt - LShrAmt;
+
+  Mask = ShiftedMaskConst ? ShiftedMaskConst->shl(LShrAmt)
+                          : APInt::getBitsSetFrom(
+                                Int->getType()->getScalarSizeInBits(), LShrAmt);
+
+  return true;
+}
+
+/// Try to fold the join of two scalar integers whose bits are unpacked and
+/// zexted from the same source integer.
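+///
+/// For example (illustrative, %x : i32 zero extended into an i64 result):
+///   or disjoint (shl (zext (and %x, 0xff)), 8),
+///               (shl (zext (and (lshr %x, 8), 0xff)), 16)
+/// folds to (shl (zext (and %x, 0xffff)), 8).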
+static Value *foldIntegerRepackThroughZExt(Value *Lhs, Value *Rhs,
+                                           InstCombiner::BuilderTy &Builder) {
+
+  Value *LhsInt, *RhsInt;
+  APInt LhsMask, RhsMask;
+  uint64_t LhsOffset, RhsOffset;
+  bool IsLhsShlNUW, IsLhsShlNSW, IsRhsShlNUW, IsRhsShlNSW;
+  if (!matchZExtedSubInteger(Lhs, LhsInt, LhsMask, LhsOffset, IsLhsShlNUW,
+                             IsLhsShlNSW))
+    return nullptr;
+  if (!matchZExtedSubInteger(Rhs, RhsInt, RhsMask, RhsOffset, IsRhsShlNUW,
+                             IsRhsShlNSW))
+    return nullptr;
+  if (LhsInt != RhsInt || LhsOffset != RhsOffset)
+    return nullptr;
+
+  APInt Mask = LhsMask | RhsMask;
+
+  Type *DestTy = Lhs->getType();
+  Value *Res = Builder.CreateShl(
+      Builder.CreateZExt(
+          Builder.CreateAnd(LhsInt, Mask, LhsInt->getName() + ".mask"), DestTy,
+          LhsInt->getName() + ".zext"),
+      ConstantInt::get(DestTy, LhsOffset), "", IsLhsShlNUW && IsRhsShlNUW,
+      IsLhsShlNSW && IsRhsShlNSW);
+  Res->takeName(Lhs);
+  return Res;
+}
+
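+/// Fold a disjoint or of \p LHS and \p RHS, or return nullptr if no fold
+/// applies.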
+Value *InstCombinerImpl::foldDisjointOr(Value *LHS, Value *RHS) {
+  if (Value *Res = foldIntegerRepackThroughZExt(LHS, RHS, Builder))
+    return Res;
+
+  return nullptr;
+}
+
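+/// Reassociate a disjoint or: if \p LHS or \p RHS is itself a one-use disjoint
+/// or, try to fold the other value with one of its operands and or the result
+/// back in disjointly with the remaining operand.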
+Value *InstCombinerImpl::reassociateDisjointOr(Value *LHS, Value *RHS) {
+
+  Value *X, *Y;
+  if (match(RHS, m_OneUse(m_DisjointOr(m_Value(X), m_Value(Y))))) {
+    if (Value *Res = foldDisjointOr(LHS, X))
+      return Builder.CreateOr(Res, Y, "", /*IsDisjoint=*/true);
+    if (Value *Res = foldDisjointOr(LHS, Y))
+      return Builder.CreateOr(Res, X, "", /*IsDisjoint=*/true);
+  }
+
+  if (match(LHS, m_OneUse(m_DisjointOr(m_Value(X), m_Value(Y))))) {
+    if (Value *Res = foldDisjointOr(X, RHS))
+      return Builder.CreateOr(Res, Y, "", /*IsDisjoint=*/true);
+    if (Value *Res = foldDisjointOr(Y, RHS))
+      return Builder.CreateOr(Res, X, "", /*IsDisjoint=*/true);
+  }
+
+  return nullptr;
+}
+
 /// Match \p V as "shufflevector -> bitcast" or "extractelement -> zext -> shl"
 /// patterns, which extract vector elements and pack them in the same relative
 /// positions.
@@ -3781,6 +3884,12 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
             foldAddLikeCommutative(I.getOperand(1), I.getOperand(0),
                                    /*NSW=*/true, /*NUW=*/true))
       return R;
+
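+    // Try to fold, or reassociate and then fold, disjoint ors that repack
+    // bits extracted from a common source integer.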
+    if (Value *Res = foldDisjointOr(I.getOperand(0), I.getOperand(1)))
+      return replaceInstUsesWith(I, Res);
+
+    if (Value *Res = reassociateDisjointOr(I.getOperand(0), I.getOperand(1)))
+      return replaceInstUsesWith(I, Res);
   }
 
   Value *X, *Y;