@@ -179,10 +179,10 @@ static unsigned conjugateICmpMask(unsigned Mask) {
179179}
180180
181181// Adapts the external decomposeBitTestICmp for local use.
182- static bool decomposeBitTestICmp (Value *LHS, Value *RHS , CmpInst::Predicate &Pred,
182+ static bool decomposeBitTestICmp (Value *Cond , CmpInst::Predicate &Pred,
183183 Value *&X, Value *&Y, Value *&Z) {
184- auto Res = llvm::decomposeBitTestICmp (
185- LHS, RHS, Pred, /* LookThroughTrunc= */ true , /* AllowNonZeroC=*/ true );
184+ auto Res = llvm::decomposeBitTest (Cond, /* LookThroughTrunc= */ true ,
185+ /* AllowNonZeroC=*/ true );
186186 if (!Res)
187187 return false ;
188188
@@ -198,27 +198,34 @@ static bool decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate &Pre
198198// / the right hand side as a pair.
199199// / LHS and RHS are the left hand side and the right hand side ICmps and PredL
200200// / and PredR are their predicates, respectively.
201- static std::optional<std::pair<unsigned , unsigned >> getMaskedTypeForICmpPair (
202- Value *&A, Value *&B, Value *&C, Value *&D, Value *&E, ICmpInst *LHS,
203- ICmpInst *RHS, ICmpInst::Predicate &PredL, ICmpInst::Predicate &PredR) {
204- // Don't allow pointers. Splat vectors are fine.
205- if (!LHS->getOperand (0 )->getType ()->isIntOrIntVectorTy () ||
206- !RHS->getOperand (0 )->getType ()->isIntOrIntVectorTy ())
207- return std::nullopt ;
201+ static std::optional<std::pair<unsigned , unsigned >>
202+ getMaskedTypeForICmpPair (Value *&A, Value *&B, Value *&C, Value *&D, Value *&E,
203+ Value *LHS, Value *RHS, ICmpInst::Predicate &PredL,
204+ ICmpInst::Predicate &PredR) {
208205
209206 // Here comes the tricky part:
210207 // LHS might be of the form L11 & L12 == X, X == L21 & L22,
211208 // and L11 & L12 == L21 & L22. The same goes for RHS.
212209 // Now we must find those components L** and R**, that are equal, so
213210 // that we can extract the parameters A, B, C, D, and E for the canonical
214211 // above.
215- Value *L1 = LHS->getOperand (0 );
216- Value *L2 = LHS->getOperand (1 );
217- Value *L11, *L12, *L21, *L22;
212+
218213 // Check whether the icmp can be decomposed into a bit test.
219- if (decomposeBitTestICmp (L1, L2, PredL, L11, L12, L2)) {
214+ Value *L1, *L11, *L12, *L2, *L21, *L22;
215+ if (decomposeBitTestICmp (LHS, PredL, L11, L12, L2)) {
220216 L21 = L22 = L1 = nullptr ;
221217 } else {
218+ auto *LHSCMP = dyn_cast<ICmpInst>(LHS);
219+ if (!LHSCMP)
220+ return std::nullopt ;
221+
222+ // Don't allow pointers. Splat vectors are fine.
223+ if (!LHSCMP->getOperand (0 )->getType ()->isIntOrIntVectorTy ())
224+ return std::nullopt ;
225+
226+ PredL = LHSCMP->getPredicate ();
227+ L1 = LHSCMP->getOperand (0 );
228+ L2 = LHSCMP->getOperand (1 );
222229 // Look for ANDs in the LHS icmp.
223230 if (!match (L1, m_And (m_Value (L11), m_Value (L12)))) {
224231 // Any icmp can be viewed as being trivially masked; if it allows us to
@@ -237,11 +244,8 @@ static std::optional<std::pair<unsigned, unsigned>> getMaskedTypeForICmpPair(
237244 if (!ICmpInst::isEquality (PredL))
238245 return std::nullopt ;
239246
240- Value *R1 = RHS->getOperand (0 );
241- Value *R2 = RHS->getOperand (1 );
242- Value *R11, *R12;
243- bool Ok = false ;
244- if (decomposeBitTestICmp (R1, R2, PredR, R11, R12, R2)) {
247+ Value *R11, *R12, *R2;
248+ if (decomposeBitTestICmp (RHS, PredR, R11, R12, R2)) {
245249 if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
246250 A = R11;
247251 D = R12;
@@ -252,9 +256,19 @@ static std::optional<std::pair<unsigned, unsigned>> getMaskedTypeForICmpPair(
252256 return std::nullopt ;
253257 }
254258 E = R2;
255- R1 = nullptr ;
256- Ok = true ;
257259 } else {
260+ auto *RHSCMP = dyn_cast<ICmpInst>(RHS);
261+ if (!RHSCMP)
262+ return std::nullopt ;
263+ // Don't allow pointers. Splat vectors are fine.
264+ if (!RHSCMP->getOperand (0 )->getType ()->isIntOrIntVectorTy ())
265+ return std::nullopt ;
266+
267+ PredR = RHSCMP->getPredicate ();
268+
269+ Value *R1 = RHSCMP->getOperand (0 );
270+ R2 = RHSCMP->getOperand (1 );
271+ bool Ok = false ;
258272 if (!match (R1, m_And (m_Value (R11), m_Value (R12)))) {
259273 // As before, model no mask as a trivial mask if it'll let us do an
260274 // optimization.
@@ -277,36 +291,32 @@ static std::optional<std::pair<unsigned, unsigned>> getMaskedTypeForICmpPair(
277291 // Avoid matching against the -1 value we created for unmasked operand.
278292 if (Ok && match (A, m_AllOnes ()))
279293 Ok = false ;
294+
295+ // Look for ANDs on the right side of the RHS icmp.
296+ if (!Ok) {
297+ if (!match (R2, m_And (m_Value (R11), m_Value (R12)))) {
298+ R11 = R2;
299+ R12 = Constant::getAllOnesValue (R2->getType ());
300+ }
301+
302+ if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
303+ A = R11;
304+ D = R12;
305+ E = R1;
306+ } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
307+ A = R12;
308+ D = R11;
309+ E = R1;
310+ } else {
311+ return std::nullopt ;
312+ }
313+ }
280314 }
281315
282316 // Bail if RHS was a icmp that can't be decomposed into an equality.
283317 if (!ICmpInst::isEquality (PredR))
284318 return std::nullopt ;
285319
286- // Look for ANDs on the right side of the RHS icmp.
287- if (!Ok) {
288- if (!match (R2, m_And (m_Value (R11), m_Value (R12)))) {
289- R11 = R2;
290- R12 = Constant::getAllOnesValue (R2->getType ());
291- }
292-
293- if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
294- A = R11;
295- D = R12;
296- E = R1;
297- Ok = true ;
298- } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
299- A = R12;
300- D = R11;
301- E = R1;
302- Ok = true ;
303- } else {
304- return std::nullopt ;
305- }
306-
307- assert (Ok && " Failed to find AND on the right side of the RHS icmp." );
308- }
309-
310320 if (L11 == A) {
311321 B = L12;
312322 C = L2;
@@ -333,8 +343,8 @@ static std::optional<std::pair<unsigned, unsigned>> getMaskedTypeForICmpPair(
333343// / (icmp (A & 12) != 0) & (icmp (A & 15) == 8) -> (icmp (A & 15) == 8).
334344// / Also used for logical and/or, must be poison safe.
335345static Value *foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed (
336- ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, Value *A, Value *B, Value *D,
337- Value *E, ICmpInst::Predicate PredL, ICmpInst::Predicate PredR,
346+ Value *LHS, Value *RHS, bool IsAnd, Value *A, Value *B, Value *D, Value *E ,
347+ ICmpInst::Predicate PredL, ICmpInst::Predicate PredR,
338348 InstCombiner::BuilderTy &Builder) {
339349 // We are given the canonical form:
340350 // (icmp ne (A & B), 0) & (icmp eq (A & D), E).
@@ -457,7 +467,8 @@ static Value *foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed(
457467 // (icmp ne (A & 15), 0) & (icmp eq (A & 15), 8) -> (icmp eq (A & 15), 8).
458468 if (IsSuperSetOrEqual (BCst, DCst)) {
459469 // We can't guarantee that samesign hold after this fold.
460- RHS->setSameSign (false );
470+ if (auto *ICmp = dyn_cast<ICmpInst>(RHS))
471+ ICmp->setSameSign (false );
461472 return RHS;
462473 }
463474 // Otherwise, B is a subset of D. If B and E have a common bit set,
@@ -466,7 +477,8 @@ static Value *foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed(
466477 assert (IsSubSetOrEqual (BCst, DCst) && " Precondition due to above code" );
467478 if ((*BCst & ECst) != 0 ) {
468479 // We can't guarantee that samesign hold after this fold.
469- RHS->setSameSign (false );
480+ if (auto *ICmp = dyn_cast<ICmpInst>(RHS))
481+ ICmp->setSameSign (false );
470482 return RHS;
471483 }
472484 // Otherwise, LHS and RHS contradict and the whole expression becomes false
@@ -481,8 +493,8 @@ static Value *foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed(
481493// / aren't of the common mask pattern type.
482494// / Also used for logical and/or, must be poison safe.
483495static Value *foldLogOpOfMaskedICmpsAsymmetric (
484- ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, Value *A, Value *B, Value *C,
485- Value *D, Value * E, ICmpInst::Predicate PredL, ICmpInst::Predicate PredR,
496+ Value *LHS, Value *RHS, bool IsAnd, Value *A, Value *B, Value *C, Value *D ,
497+ Value *E, ICmpInst::Predicate PredL, ICmpInst::Predicate PredR,
486498 unsigned LHSMask, unsigned RHSMask, InstCombiner::BuilderTy &Builder) {
487499 assert (ICmpInst::isEquality (PredL) && ICmpInst::isEquality (PredR) &&
488500 " Expected equality predicates for masked type of icmps." );
@@ -511,12 +523,12 @@ static Value *foldLogOpOfMaskedICmpsAsymmetric(
511523
512524// / Try to fold (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E)
513525// / into a single (icmp(A & X) ==/!= Y).
514- static Value *foldLogOpOfMaskedICmps (ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,
526+ static Value *foldLogOpOfMaskedICmps (Value *LHS, Value *RHS, bool IsAnd,
515527 bool IsLogical,
516528 InstCombiner::BuilderTy &Builder,
517529 const SimplifyQuery &Q) {
518530 Value *A = nullptr , *B = nullptr , *C = nullptr , *D = nullptr , *E = nullptr ;
519- ICmpInst::Predicate PredL = LHS-> getPredicate () , PredR = RHS-> getPredicate () ;
531+ ICmpInst::Predicate PredL, PredR;
520532 std::optional<std::pair<unsigned , unsigned >> MaskPair =
521533 getMaskedTypeForICmpPair (A, B, C, D, E, LHS, RHS, PredL, PredR);
522534 if (!MaskPair)
@@ -1066,8 +1078,7 @@ static Value *foldPowerOf2AndShiftedMask(ICmpInst *Cmp0, ICmpInst *Cmp1,
10661078 if (!JoinedByAnd)
10671079 return nullptr ;
10681080 Value *A = nullptr , *B = nullptr , *C = nullptr , *D = nullptr , *E = nullptr ;
1069- ICmpInst::Predicate CmpPred0 = Cmp0->getPredicate (),
1070- CmpPred1 = Cmp1->getPredicate ();
1081+ ICmpInst::Predicate CmpPred0, CmpPred1;
10711082 // Assuming P is a 2^n, getMaskedTypeForICmpPair will normalize (icmp X u<
10721083 // 2^n) into (icmp (X & ~(2^n-1)) == 0) and (icmp X s> -1) into (icmp (X &
10731084 // SignMask) == 0).
0 commit comments