@@ -199,113 +199,132 @@ static bool decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate &Pre
199199// / the right hand side as a pair.
200200// / LHS and RHS are the left hand side and the right hand side ICmps and PredL
201201// / and PredR are their predicates, respectively.
202- static std::optional<std::pair<unsigned , unsigned >> getMaskedTypeForICmpPair (
203- Value *&A, Value *&B, Value *&C, Value *&D, Value *&E, ICmpInst *LHS,
204- ICmpInst *RHS, ICmpInst::Predicate &PredL, ICmpInst::Predicate &PredR) {
205- // Don't allow pointers. Splat vectors are fine.
206- if (!LHS->getOperand (0 )->getType ()->isIntOrIntVectorTy () ||
207- !RHS->getOperand (0 )->getType ()->isIntOrIntVectorTy ())
208- return std::nullopt ;
202+ static std::optional<std::pair<unsigned , unsigned >>
203+ getMaskedTypeForICmpPair (Value *&A, Value *&B, Value *&C, Value *&D, Value *&E,
204+ Value *LHS, Value *RHS, ICmpInst::Predicate &PredL,
205+ ICmpInst::Predicate &PredR) {
209206
210- // Here comes the tricky part:
211- // LHS might be of the form L11 & L12 == X, X == L21 & L22,
212- // and L11 & L12 == L21 & L22. The same goes for RHS.
213- // Now we must find those components L** and R**, that are equal, so
214- // that we can extract the parameters A, B, C, D, and E for the canonical
215- // above.
216- Value *L1 = LHS->getOperand (0 );
217- Value *L2 = LHS->getOperand (1 );
218- Value *L11, *L12, *L21, *L22;
219- // Check whether the icmp can be decomposed into a bit test.
220- if (decomposeBitTestICmp (L1, L2, PredL, L11, L12, L2)) {
221- L21 = L22 = L1 = nullptr ;
222- } else {
223- // Look for ANDs in the LHS icmp.
224- if (!match (L1, m_And (m_Value (L11), m_Value (L12)))) {
225- // Any icmp can be viewed as being trivially masked; if it allows us to
226- // remove one, it's worth it.
227- L11 = L1;
228- L12 = Constant::getAllOnesValue (L1->getType ());
229- }
207+ Value *L1, *L11, *L12, *L2, *L21, *L22;
208+ if (auto *LHSCMP = dyn_cast<ICmpInst>(LHS)) {
209+
210+ // Don't allow pointers. Splat vectors are fine.
211+ if (!LHSCMP->getOperand (0 )->getType ()->isIntOrIntVectorTy ())
212+ return std::nullopt ;
213+
214+ PredL = LHSCMP->getPredicate ();
215+
216+ // Here comes the tricky part:
217+ // LHS might be of the form L11 & L12 == X, X == L21 & L22,
218+ // and L11 & L12 == L21 & L22. The same goes for RHS.
219+ // Now we must find those components L** and R**, that are equal, so
220+ // that we can extract the parameters A, B, C, D, and E for the canonical
221+ // above.
222+ L1 = LHSCMP->getOperand (0 );
223+ L2 = LHSCMP->getOperand (1 );
224+ // Check whether the icmp can be decomposed into a bit test.
225+ if (decomposeBitTestICmp (L1, L2, PredL, L11, L12, L2)) {
226+ L21 = L22 = L1 = nullptr ;
227+ } else {
228+ // Look for ANDs in the LHS icmp.
229+ if (!match (L1, m_And (m_Value (L11), m_Value (L12)))) {
230+ // Any icmp can be viewed as being trivially masked; if it allows us to
231+ // remove one, it's worth it.
232+ L11 = L1;
233+ L12 = Constant::getAllOnesValue (L1->getType ());
234+ }
230235
231- if (!match (L2, m_And (m_Value (L21), m_Value (L22)))) {
232- L21 = L2;
233- L22 = Constant::getAllOnesValue (L2->getType ());
236+ if (!match (L2, m_And (m_Value (L21), m_Value (L22)))) {
237+ L21 = L2;
238+ L22 = Constant::getAllOnesValue (L2->getType ());
239+ }
234240 }
235- }
241+ // Bail if LHS was a icmp that can't be decomposed into an equality.
242+ if (!ICmpInst::isEquality (PredL))
243+ return std::nullopt ;
236244
237- // Bail if LHS was a icmp that can't be decomposed into an equality.
238- if (!ICmpInst::isEquality (PredL))
245+ } else {
239246 return std::nullopt ;
247+ }
240248
241- Value *R1 = RHS->getOperand (0 );
242- Value *R2 = RHS->getOperand (1 );
243249 Value *R11, *R12;
244- bool Ok = false ;
245- if (decomposeBitTestICmp (R1, R2, PredR, R11, R12, R2)) {
246- if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
247- A = R11;
248- D = R12;
249- } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
250- A = R12;
251- D = R11;
252- } else {
250+ if (auto *RHSCMP = dyn_cast<ICmpInst>(RHS)) {
251+
252+ // Don't allow pointers. Splat vectors are fine.
253+ if (!RHSCMP->getOperand (0 )->getType ()->isIntOrIntVectorTy ())
253254 return std::nullopt ;
254- }
255- E = R2;
256- R1 = nullptr ;
257- Ok = true ;
258- } else {
259- if (!match (R1, m_And (m_Value (R11), m_Value (R12)))) {
260- // As before, model no mask as a trivial mask if it'll let us do an
261- // optimization.
262- R11 = R1;
263- R12 = Constant::getAllOnesValue (R1->getType ());
264- }
265255
266- if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
267- A = R11;
268- D = R12;
269- E = R2;
270- Ok = true ;
271- } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
272- A = R12;
273- D = R11;
256+ PredR = RHSCMP->getPredicate ();
257+
258+ Value *R1 = RHSCMP->getOperand (0 );
259+ Value *R2 = RHSCMP->getOperand (1 );
260+ bool Ok = false ;
261+ if (decomposeBitTestICmp (R1, R2, PredR, R11, R12, R2)) {
262+ if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
263+ A = R11;
264+ D = R12;
265+ } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
266+ A = R12;
267+ D = R11;
268+ } else {
269+ return std::nullopt ;
270+ }
274271 E = R2;
272+ R1 = nullptr ;
275273 Ok = true ;
276- }
277-
278- // Avoid matching against the -1 value we created for unmasked operand.
279- if (Ok && match (A, m_AllOnes ()))
280- Ok = false ;
281- }
274+ } else {
275+ if (!match (R1, m_And (m_Value (R11), m_Value (R12)))) {
276+ // As before, model no mask as a trivial mask if it'll let us do an
277+ // optimization.
278+ R11 = R1;
279+ R12 = Constant::getAllOnesValue (R1->getType ());
280+ }
282281
283- // Bail if RHS was a icmp that can't be decomposed into an equality.
284- if (!ICmpInst::isEquality (PredR))
285- return std::nullopt ;
282+ if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
283+ A = R11;
284+ D = R12;
285+ E = R2;
286+ Ok = true ;
287+ } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
288+ A = R12;
289+ D = R11;
290+ E = R2;
291+ Ok = true ;
292+ }
286293
287- // Look for ANDs on the right side of the RHS icmp.
288- if (!Ok) {
289- if (!match (R2, m_And (m_Value (R11), m_Value (R12)))) {
290- R11 = R2;
291- R12 = Constant::getAllOnesValue (R2->getType ());
294+ // Avoid matching against the -1 value we created for unmasked operand.
295+ if (Ok && match (A, m_AllOnes ()))
296+ Ok = false ;
292297 }
293298
294- if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
295- A = R11;
296- D = R12;
297- E = R1;
298- Ok = true ;
299- } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
300- A = R12;
301- D = R11;
302- E = R1;
303- Ok = true ;
304- } else {
299+ // Bail if RHS was a icmp that can't be decomposed into an equality.
300+ if (!ICmpInst::isEquality (PredR))
305301 return std::nullopt ;
306- }
307302
308- assert (Ok && " Failed to find AND on the right side of the RHS icmp." );
303+ // Look for ANDs on the right side of the RHS icmp.
304+ if (!Ok) {
305+ if (!match (R2, m_And (m_Value (R11), m_Value (R12)))) {
306+ R11 = R2;
307+ R12 = Constant::getAllOnesValue (R2->getType ());
308+ }
309+
310+ if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
311+ A = R11;
312+ D = R12;
313+ E = R1;
314+ Ok = true ;
315+ } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
316+ A = R12;
317+ D = R11;
318+ E = R1;
319+ Ok = true ;
320+ } else {
321+ return std::nullopt ;
322+ }
323+
324+ assert (Ok && " Failed to find AND on the right side of the RHS icmp." );
325+ }
326+ } else {
327+ return std::nullopt ;
309328 }
310329
311330 if (L11 == A) {
@@ -334,8 +353,8 @@ static std::optional<std::pair<unsigned, unsigned>> getMaskedTypeForICmpPair(
334353// / (icmp (A & 12) != 0) & (icmp (A & 15) == 8) -> (icmp (A & 15) == 8).
335354// / Also used for logical and/or, must be poison safe.
336355static Value *foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed (
337- ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, Value *A, Value *B, Value *D,
338- Value *E, ICmpInst::Predicate PredL, ICmpInst::Predicate PredR,
356+ Value *LHS, Value *RHS, bool IsAnd, Value *A, Value *B, Value *D, Value *E ,
357+ ICmpInst::Predicate PredL, ICmpInst::Predicate PredR,
339358 InstCombiner::BuilderTy &Builder) {
340359 // We are given the canonical form:
341360 // (icmp ne (A & B), 0) & (icmp eq (A & D), E).
@@ -458,7 +477,8 @@ static Value *foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed(
458477 // (icmp ne (A & 15), 0) & (icmp eq (A & 15), 8) -> (icmp eq (A & 15), 8).
459478 if (IsSuperSetOrEqual (BCst, DCst)) {
460479 // We can't guarantee that samesign hold after this fold.
461- RHS->setSameSign (false );
480+ if (auto *ICmp = dyn_cast<ICmpInst>(RHS))
481+ ICmp->setSameSign (false );
462482 return RHS;
463483 }
464484 // Otherwise, B is a subset of D. If B and E have a common bit set,
@@ -467,7 +487,8 @@ static Value *foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed(
467487 assert (IsSubSetOrEqual (BCst, DCst) && " Precondition due to above code" );
468488 if ((*BCst & ECst) != 0 ) {
469489 // We can't guarantee that samesign hold after this fold.
470- RHS->setSameSign (false );
490+ if (auto *ICmp = dyn_cast<ICmpInst>(RHS))
491+ ICmp->setSameSign (false );
471492 return RHS;
472493 }
473494 // Otherwise, LHS and RHS contradict and the whole expression becomes false
@@ -482,8 +503,8 @@ static Value *foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed(
482503// / aren't of the common mask pattern type.
483504// / Also used for logical and/or, must be poison safe.
484505static Value *foldLogOpOfMaskedICmpsAsymmetric (
485- ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, Value *A, Value *B, Value *C,
486- Value *D, Value * E, ICmpInst::Predicate PredL, ICmpInst::Predicate PredR,
506+ Value *LHS, Value *RHS, bool IsAnd, Value *A, Value *B, Value *C, Value *D ,
507+ Value *E, ICmpInst::Predicate PredL, ICmpInst::Predicate PredR,
487508 unsigned LHSMask, unsigned RHSMask, InstCombiner::BuilderTy &Builder) {
488509 assert (ICmpInst::isEquality (PredL) && ICmpInst::isEquality (PredR) &&
489510 " Expected equality predicates for masked type of icmps." );
@@ -512,12 +533,12 @@ static Value *foldLogOpOfMaskedICmpsAsymmetric(
512533
513534// / Try to fold (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E)
514535// / into a single (icmp(A & X) ==/!= Y).
515- static Value *foldLogOpOfMaskedICmps (ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,
536+ static Value *foldLogOpOfMaskedICmps (Value *LHS, Value *RHS, bool IsAnd,
516537 bool IsLogical,
517538 InstCombiner::BuilderTy &Builder,
518539 const SimplifyQuery &Q) {
519540 Value *A = nullptr , *B = nullptr , *C = nullptr , *D = nullptr , *E = nullptr ;
520- ICmpInst::Predicate PredL = LHS-> getPredicate () , PredR = RHS-> getPredicate () ;
541+ ICmpInst::Predicate PredL, PredR;
521542 std::optional<std::pair<unsigned , unsigned >> MaskPair =
522543 getMaskedTypeForICmpPair (A, B, C, D, E, LHS, RHS, PredL, PredR);
523544 if (!MaskPair)
@@ -1067,8 +1088,7 @@ static Value *foldPowerOf2AndShiftedMask(ICmpInst *Cmp0, ICmpInst *Cmp1,
10671088 if (!JoinedByAnd)
10681089 return nullptr ;
10691090 Value *A = nullptr , *B = nullptr , *C = nullptr , *D = nullptr , *E = nullptr ;
1070- ICmpInst::Predicate CmpPred0 = Cmp0->getPredicate (),
1071- CmpPred1 = Cmp1->getPredicate ();
1091+ ICmpInst::Predicate CmpPred0, CmpPred1;
10721092 // Assuming P is a 2^n, getMaskedTypeForICmpPair will normalize (icmp X u<
10731093 // 2^n) into (icmp (X & ~(2^n-1)) == 0) and (icmp X s> -1) into (icmp (X &
10741094 // SignMask) == 0).
@@ -3325,12 +3345,6 @@ Value *InstCombinerImpl::foldAndOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
33253345 }
33263346 }
33273347
3328- // handle (roughly):
3329- // (icmp ne (A & B), C) | (icmp ne (A & D), E)
3330- // (icmp eq (A & B), C) & (icmp eq (A & D), E)
3331- if (Value *V = foldLogOpOfMaskedICmps (LHS, RHS, IsAnd, IsLogical, Builder, Q))
3332- return V;
3333-
33343348 if (Value *V =
33353349 foldAndOrOfICmpEqConstantAndICmp (LHS, RHS, IsAnd, IsLogical, Builder))
33363350 return V;
@@ -3510,6 +3524,12 @@ Value *InstCombinerImpl::foldBooleanAndOr(Value *LHS, Value *RHS,
35103524 if (Value *Res = foldAndOrOfICmps (LHSCmp, RHSCmp, I, IsAnd, IsLogical))
35113525 return Res;
35123526
3527+ // / Try to fold (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E)
3528+ // / into a single (icmp(A & X) ==/!= Y).
3529+ if (Value *V = foldLogOpOfMaskedICmps (LHS, RHS, IsAnd, IsLogical, Builder,
3530+ SQ.getWithInstruction (&I)))
3531+ return V;
3532+
35133533 if (auto *LHSCmp = dyn_cast<FCmpInst>(LHS))
35143534 if (auto *RHSCmp = dyn_cast<FCmpInst>(RHS))
35153535 if (Value *Res = foldLogicOfFCmps (LHSCmp, RHSCmp, IsAnd, IsLogical))
0 commit comments