@@ -119,63 +119,15 @@ static Instruction *foldSelectBinOpIdentity(SelectInst &Sel,
119119// / (shl (and (X, C1)), (log2(TC-FC) - log2(C1))) + FC
120120// / With some variations depending if FC is larger than TC, or the shift
121121// / isn't needed, or the bit widths don't match.
122- static Value *foldSelectICmpAnd (SelectInst &Sel, Value *CondVal,
123- InstCombiner::BuilderTy &Builder,
124- const SimplifyQuery &SQ) {
122+ static Value *foldSelectICmpAnd (SelectInst &Sel, Value *CondVal, Value *TrueVal,
123+ Value *FalseVal, Value *V, const APInt &AndMask,
124+ bool CreateAnd,
125+ InstCombiner::BuilderTy &Builder) {
125126 const APInt *SelTC, *SelFC;
126- if (!match (Sel.getTrueValue (), m_APInt (SelTC)) ||
127- !match (Sel.getFalseValue (), m_APInt (SelFC)))
127+ if (!match (TrueVal, m_APInt (SelTC)) || !match (FalseVal, m_APInt (SelFC)))
128128 return nullptr ;
129129
130- // If this is a vector select, we need a vector compare.
131130 Type *SelType = Sel.getType ();
132- if (SelType->isVectorTy () != CondVal->getType ()->isVectorTy ())
133- return nullptr ;
134-
135- Value *V;
136- APInt AndMask;
137- bool CreateAnd = false ;
138- CmpPredicate Pred;
139- Value *CmpLHS, *CmpRHS;
140-
141- if (match (CondVal, m_ICmp (Pred, m_Value (CmpLHS), m_Value (CmpRHS)))) {
142- if (ICmpInst::isEquality (Pred)) {
143- if (!match (CmpRHS, m_Zero ()))
144- return nullptr ;
145-
146- V = CmpLHS;
147- const APInt *AndRHS;
148- if (!match (V, m_And (m_Value (), m_Power2 (AndRHS))))
149- return nullptr ;
150-
151- AndMask = *AndRHS;
152- } else if (auto Res = decomposeBitTestICmp (CmpLHS, CmpRHS, Pred)) {
153- assert (ICmpInst::isEquality (Res->Pred ) && " Not equality test?" );
154- AndMask = Res->Mask ;
155- V = Res->X ;
156- KnownBits Known =
157- computeKnownBits (V, /* Depth=*/ 0 , SQ.getWithInstruction (&Sel));
158- AndMask &= Known.getMaxValue ();
159- if (!AndMask.isPowerOf2 ())
160- return nullptr ;
161-
162- Pred = Res->Pred ;
163- CreateAnd = true ;
164- } else {
165- return nullptr ;
166- }
167-
168- } else if (auto *Trunc = dyn_cast<TruncInst>(CondVal)) {
169- V = Trunc->getOperand (0 );
170- AndMask = APInt (V->getType ()->getScalarSizeInBits (), 1 );
171- Pred = ICmpInst::ICMP_NE;
172- CreateAnd = !Trunc->hasNoUnsignedWrap ();
173- } else {
174- return nullptr ;
175- }
176- if (Pred == ICmpInst::ICMP_NE)
177- std::swap (SelTC, SelFC);
178-
179131 // In general, when both constants are non-zero, we would need an offset to
180132 // replace the select. This would require more instructions than we started
181133 // with. But there's one special-case that we handle here because it can
@@ -762,60 +714,26 @@ static Value *foldSelectICmpLshrAshr(const ICmpInst *IC, Value *TrueVal,
762714// / 2. The select operands are reversed
763715// / 3. The magnitude of C2 and C1 are flipped
764716static Value *foldSelectICmpAndBinOp (Value *CondVal, Value *TrueVal,
765- Value *FalseVal,
717+ Value *FalseVal, Value *V,
718+ const APInt &AndMask, bool CreateAnd,
766719 InstCombiner::BuilderTy &Builder) {
767- // Only handle integer compares. Also, if this is a vector select, we need a
768- // vector compare.
769- if (!TrueVal->getType ()->isIntOrIntVectorTy () ||
770- TrueVal->getType ()->isVectorTy () != CondVal->getType ()->isVectorTy ())
771- return nullptr ;
772-
773- unsigned C1Log;
774- bool NeedAnd = false ;
775- CmpPredicate Pred;
776- Value *CmpLHS, *CmpRHS;
777-
778- if (match (CondVal, m_ICmp (Pred, m_Value (CmpLHS), m_Value (CmpRHS)))) {
779- if (ICmpInst::isEquality (Pred)) {
780- if (!match (CmpRHS, m_Zero ()))
781- return nullptr ;
782-
783- const APInt *C1;
784- if (!match (CmpLHS, m_And (m_Value (), m_Power2 (C1))))
785- return nullptr ;
786-
787- C1Log = C1->logBase2 ();
788- } else {
789- auto Res = decomposeBitTestICmp (CmpLHS, CmpRHS, Pred);
790- if (!Res || !Res->Mask .isPowerOf2 ())
791- return nullptr ;
792-
793- CmpLHS = Res->X ;
794- Pred = Res->Pred ;
795- C1Log = Res->Mask .logBase2 ();
796- NeedAnd = true ;
797- }
798- } else if (auto *Trunc = dyn_cast<TruncInst>(CondVal)) {
799- CmpLHS = Trunc->getOperand (0 );
800- C1Log = 0 ;
801- Pred = ICmpInst::ICMP_NE;
802- NeedAnd = !Trunc->hasNoUnsignedWrap ();
803- } else {
720+ // Only handle integer compares.
721+ if (!TrueVal->getType ()->isIntOrIntVectorTy ())
804722 return nullptr ;
805- }
806723
807- Value *Y, *V = CmpLHS;
724+ unsigned C1Log = AndMask.logBase2 ();
725+ Value *Y;
808726 BinaryOperator *BinOp;
809727 const APInt *C2;
810728 bool NeedXor;
811729 if (match (FalseVal, m_BinOp (m_Specific (TrueVal), m_Power2 (C2)))) {
812730 Y = TrueVal;
813731 BinOp = cast<BinaryOperator>(FalseVal);
814- NeedXor = Pred == ICmpInst::ICMP_NE ;
732+ NeedXor = false ;
815733 } else if (match (TrueVal, m_BinOp (m_Specific (FalseVal), m_Power2 (C2)))) {
816734 Y = FalseVal;
817735 BinOp = cast<BinaryOperator>(TrueVal);
818- NeedXor = Pred == ICmpInst::ICMP_EQ ;
736+ NeedXor = true ;
819737 } else {
820738 return nullptr ;
821739 }
@@ -834,14 +752,13 @@ static Value *foldSelectICmpAndBinOp(Value *CondVal, Value *TrueVal,
834752 V->getType ()->getScalarSizeInBits ();
835753
836754 // Make sure we don't create more instructions than we save.
837- if ((NeedShift + NeedXor + NeedZExtTrunc + NeedAnd ) >
755+ if ((NeedShift + NeedXor + NeedZExtTrunc + CreateAnd ) >
838756 (CondVal->hasOneUse () + BinOp->hasOneUse ()))
839757 return nullptr ;
840758
841- if (NeedAnd ) {
759+ if (CreateAnd ) {
842760 // Insert the AND instruction on the input to the truncate.
843- APInt C1 = APInt::getOneBitSet (V->getType ()->getScalarSizeInBits (), C1Log);
844- V = Builder.CreateAnd (V, ConstantInt::get (V->getType (), C1));
761+ V = Builder.CreateAnd (V, ConstantInt::get (V->getType (), AndMask));
845762 }
846763
847764 if (C2Log > C1Log) {
@@ -3797,6 +3714,70 @@ static Value *foldSelectIntoAddConstant(SelectInst &SI,
37973714 return nullptr ;
37983715}
37993716
3717+ static Value *foldSelectBitTest (SelectInst &Sel, Value *CondVal, Value *TrueVal,
3718+ Value *FalseVal,
3719+ InstCombiner::BuilderTy &Builder,
3720+ const SimplifyQuery &SQ) {
3721+ // If this is a vector select, we need a vector compare.
3722+ Type *SelType = Sel.getType ();
3723+ if (SelType->isVectorTy () != CondVal->getType ()->isVectorTy ())
3724+ return nullptr ;
3725+
3726+ Value *V;
3727+ APInt AndMask;
3728+ bool CreateAnd = false ;
3729+ CmpPredicate Pred;
3730+ Value *CmpLHS, *CmpRHS;
3731+
3732+ if (match (CondVal, m_ICmp (Pred, m_Value (CmpLHS), m_Value (CmpRHS)))) {
3733+ if (ICmpInst::isEquality (Pred)) {
3734+ if (!match (CmpRHS, m_Zero ()))
3735+ return nullptr ;
3736+
3737+ V = CmpLHS;
3738+ const APInt *AndRHS;
3739+ if (!match (CmpLHS, m_And (m_Value (), m_Power2 (AndRHS))))
3740+ return nullptr ;
3741+
3742+ AndMask = *AndRHS;
3743+ } else if (auto Res = decomposeBitTestICmp (CmpLHS, CmpRHS, Pred)) {
3744+ assert (ICmpInst::isEquality (Res->Pred ) && " Not equality test?" );
3745+ AndMask = Res->Mask ;
3746+ V = Res->X ;
3747+ KnownBits Known =
3748+ computeKnownBits (V, /* Depth=*/ 0 , SQ.getWithInstruction (&Sel));
3749+ AndMask &= Known.getMaxValue ();
3750+ if (!AndMask.isPowerOf2 ())
3751+ return nullptr ;
3752+
3753+ Pred = Res->Pred ;
3754+ CreateAnd = true ;
3755+ } else {
3756+ return nullptr ;
3757+ }
3758+ } else if (auto *Trunc = dyn_cast<TruncInst>(CondVal)) {
3759+ V = Trunc->getOperand (0 );
3760+ AndMask = APInt (V->getType ()->getScalarSizeInBits (), 1 );
3761+ Pred = ICmpInst::ICMP_NE;
3762+ CreateAnd = !Trunc->hasNoUnsignedWrap ();
3763+ } else {
3764+ return nullptr ;
3765+ }
3766+
3767+ if (Pred == ICmpInst::ICMP_NE)
3768+ std::swap (TrueVal, FalseVal);
3769+
3770+ if (Value *X = foldSelectICmpAnd (Sel, CondVal, TrueVal, FalseVal, V, AndMask,
3771+ CreateAnd, Builder))
3772+ return X;
3773+
3774+ if (Value *X = foldSelectICmpAndBinOp (CondVal, TrueVal, FalseVal, V, AndMask,
3775+ CreateAnd, Builder))
3776+ return X;
3777+
3778+ return nullptr ;
3779+ }
3780+
38003781Instruction *InstCombinerImpl::visitSelectInst (SelectInst &SI) {
38013782 Value *CondVal = SI.getCondition ();
38023783 Value *TrueVal = SI.getTrueValue ();
@@ -3969,10 +3950,7 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
39693950 if (Instruction *Result = foldSelectInstWithICmp (SI, ICI))
39703951 return Result;
39713952
3972- if (Value *V = foldSelectICmpAnd (SI, CondVal, Builder, SQ))
3973- return replaceInstUsesWith (SI, V);
3974-
3975- if (Value *V = foldSelectICmpAndBinOp (CondVal, TrueVal, FalseVal, Builder))
3953+ if (Value *V = foldSelectBitTest (SI, CondVal, TrueVal, FalseVal, Builder, SQ))
39763954 return replaceInstUsesWith (SI, V);
39773955
39783956 if (Instruction *Add = foldAddSubSelect (SI, Builder))
0 commit comments