@@ -119,7 +119,7 @@ static Instruction *foldSelectBinOpIdentity(SelectInst &Sel,
119119// / (shl (and (X, C1)), (log2(TC-FC) - log2(C1))) + FC
120120// / With some variations depending if FC is larger than TC, or the shift
121121// / isn't needed, or the bit widths don't match.
122- static Value *foldSelectICmpAnd (SelectInst &Sel, ICmpInst *Cmp ,
122+ static Value *foldSelectICmpAnd (SelectInst &Sel, Value *CondVal ,
123123 InstCombiner::BuilderTy &Builder,
124124 const SimplifyQuery &SQ) {
125125 const APInt *SelTC, *SelFC;
@@ -129,36 +129,47 @@ static Value *foldSelectICmpAnd(SelectInst &Sel, ICmpInst *Cmp,
129129
130130 // If this is a vector select, we need a vector compare.
131131 Type *SelType = Sel.getType ();
132- if (SelType->isVectorTy () != Cmp ->getType ()->isVectorTy ())
132+ if (SelType->isVectorTy () != CondVal ->getType ()->isVectorTy ())
133133 return nullptr ;
134134
135135 Value *V;
136136 APInt AndMask;
137137 bool CreateAnd = false ;
138- ICmpInst::Predicate Pred = Cmp->getPredicate ();
139- if (ICmpInst::isEquality (Pred)) {
140- if (!match (Cmp->getOperand (1 ), m_Zero ()))
141- return nullptr ;
138+ CmpPredicate Pred;
139+ Value *CmpLHS, *CmpRHS;
142140
143- V = Cmp->getOperand (0 );
144- const APInt *AndRHS;
145- if (!match (V, m_And (m_Value (), m_Power2 (AndRHS))))
146- return nullptr ;
141+ if (match (CondVal, m_ICmp (Pred, m_Value (CmpLHS), m_Value (CmpRHS)))) {
142+ if (ICmpInst::isEquality (Pred)) {
143+ if (!match (CmpRHS, m_Zero ()))
144+ return nullptr ;
145+
146+ V = CmpLHS;
147+ const APInt *AndRHS;
148+ if (!match (V, m_And (m_Value (), m_Power2 (AndRHS))))
149+ return nullptr ;
147150
148- AndMask = *AndRHS;
149- } else if (auto Res = decomposeBitTestICmp (Cmp->getOperand (0 ),
150- Cmp->getOperand (1 ), Pred)) {
151- assert (ICmpInst::isEquality (Res->Pred ) && " Not equality test?" );
152- AndMask = Res->Mask ;
153- V = Res->X ;
154- KnownBits Known =
155- computeKnownBits (V, /* Depth=*/ 0 , SQ.getWithInstruction (&Sel));
156- AndMask &= Known.getMaxValue ();
157- if (!AndMask.isPowerOf2 ())
151+ AndMask = *AndRHS;
152+ } else if (auto Res = decomposeBitTestICmp (CmpLHS, CmpRHS, Pred)) {
153+ assert (ICmpInst::isEquality (Res->Pred ) && " Not equality test?" );
154+ AndMask = Res->Mask ;
155+ V = Res->X ;
156+ KnownBits Known =
157+ computeKnownBits (V, /* Depth=*/ 0 , SQ.getWithInstruction (&Sel));
158+ AndMask &= Known.getMaxValue ();
159+ if (!AndMask.isPowerOf2 ())
160+ return nullptr ;
161+
162+ Pred = Res->Pred ;
163+ CreateAnd = true ;
164+ } else {
158165 return nullptr ;
166+ }
159167
160- Pred = Res->Pred ;
161- CreateAnd = true ;
168+ } else if (auto *Trunc = dyn_cast<TruncInst>(CondVal)) {
169+ V = Trunc->getOperand (0 );
170+ AndMask = APInt (V->getType ()->getScalarSizeInBits (), 1 );
171+ Pred = ICmpInst::ICMP_NE;
172+ CreateAnd = !Trunc->hasNoUnsignedWrap ();
162173 } else {
163174 return nullptr ;
164175 }
@@ -176,7 +187,7 @@ static Value *foldSelectICmpAnd(SelectInst &Sel, ICmpInst *Cmp,
176187 return nullptr ;
177188 // If we have to create an 'and', then we must kill the cmp to not
178189 // increase the instruction count.
179- if (CreateAnd && !Cmp ->hasOneUse ())
190+ if (CreateAnd && !CondVal ->hasOneUse ())
180191 return nullptr ;
181192
182193 // (V & AndMaskC) == 0 ? TC : FC --> TC | (V & AndMaskC)
@@ -217,7 +228,7 @@ static Value *foldSelectICmpAnd(SelectInst &Sel, ICmpInst *Cmp,
217228 // a 'select' + 'icmp', then this transformation would result in more
218229 // instructions and potentially interfere with other folding.
219230 if (CreateAnd + ShouldNotVal + NeedShift + NeedZExtTrunc >
220- 1 + Cmp ->hasOneUse ())
231+ 1 + CondVal ->hasOneUse ())
221232 return nullptr ;
222233
223234 // Insert the 'and' instruction on the input to the truncate.
@@ -1961,9 +1972,6 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
19611972 tryToReuseConstantFromSelectInComparison (SI, *ICI, *this ))
19621973 return NewSel;
19631974
1964- if (Value *V = foldSelectICmpAnd (SI, ICI, Builder, SQ))
1965- return replaceInstUsesWith (SI, V);
1966-
19671975 // NOTE: if we wanted to, this is where to detect integer MIN/MAX
19681976 bool Changed = false ;
19691977 Value *TrueVal = SI.getTrueValue ();
@@ -3961,6 +3969,9 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
39613969 if (Instruction *Result = foldSelectInstWithICmp (SI, ICI))
39623970 return Result;
39633971
3972+ if (Value *V = foldSelectICmpAnd (SI, CondVal, Builder, SQ))
3973+ return replaceInstUsesWith (SI, V);
3974+
39643975 if (Value *V = foldSelectICmpAndBinOp (CondVal, TrueVal, FalseVal, Builder))
39653976 return replaceInstUsesWith (SI, V);
39663977
0 commit comments