@@ -6544,6 +6544,35 @@ bool InstCombinerImpl::replacedSelectWithOperand(SelectInst *SI,
65446544 return false ;
65456545}
65466546
6547+ static std::optional<bool > compareKnownBits (ICmpInst::Predicate Pred,
6548+ const KnownBits &Op0,
6549+ const KnownBits &Op1) {
6550+ switch (Pred) {
6551+ case ICmpInst::ICMP_EQ:
6552+ return KnownBits::eq (Op0, Op1);
6553+ case ICmpInst::ICMP_NE:
6554+ return KnownBits::ne (Op0, Op1);
6555+ case ICmpInst::ICMP_ULT:
6556+ return KnownBits::ult (Op0, Op1);
6557+ case ICmpInst::ICMP_ULE:
6558+ return KnownBits::ule (Op0, Op1);
6559+ case ICmpInst::ICMP_UGT:
6560+ return KnownBits::ugt (Op0, Op1);
6561+ case ICmpInst::ICMP_UGE:
6562+ return KnownBits::uge (Op0, Op1);
6563+ case ICmpInst::ICMP_SLT:
6564+ return KnownBits::slt (Op0, Op1);
6565+ case ICmpInst::ICMP_SLE:
6566+ return KnownBits::sle (Op0, Op1);
6567+ case ICmpInst::ICMP_SGT:
6568+ return KnownBits::sgt (Op0, Op1);
6569+ case ICmpInst::ICMP_SGE:
6570+ return KnownBits::sge (Op0, Op1);
6571+ default :
6572+ llvm_unreachable (" Unknown predicate" );
6573+ }
6574+ }
6575+
65476576// / Try to fold the comparison based on range information we can get by checking
65486577// / whether bits are known to be zero or one in the inputs.
65496578Instruction *InstCombinerImpl::foldICmpUsingKnownBits (ICmpInst &I) {
@@ -6576,6 +6605,16 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
65766605 return &I;
65776606 }
65786607
6608+ if (!isa<Constant>(Op0) && Op0Known.isConstant ())
6609+ return new ICmpInst (
6610+ Pred, ConstantExpr::getIntegerValue (Ty, Op0Known.getConstant ()), Op1);
6611+ if (!isa<Constant>(Op1) && Op1Known.isConstant ())
6612+ return new ICmpInst (
6613+ Pred, Op0, ConstantExpr::getIntegerValue (Ty, Op1Known.getConstant ()));
6614+
6615+ if (std::optional<bool > Res = compareKnownBits (Pred, Op0Known, Op1Known))
6616+ return replaceInstUsesWith (I, ConstantInt::getBool (I.getType (), *Res));
6617+
65796618 // Given the known and unknown bits, compute a range that the LHS could be
65806619 // in. Compute the Min, Max and RHS values based on the known bits. For the
65816620 // EQ and NE we use unsigned values.
@@ -6593,14 +6632,6 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
65936632 Op1Max = Op1Known.getMaxValue ();
65946633 }
65956634
6596- // If Min and Max are known to be the same, then SimplifyDemandedBits figured
6597- // out that the LHS or RHS is a constant. Constant fold this now, so that
6598- // code below can assume that Min != Max.
6599- if (!isa<Constant>(Op0) && Op0Min == Op0Max)
6600- return new ICmpInst (Pred, ConstantExpr::getIntegerValue (Ty, Op0Min), Op1);
6601- if (!isa<Constant>(Op1) && Op1Min == Op1Max)
6602- return new ICmpInst (Pred, Op0, ConstantExpr::getIntegerValue (Ty, Op1Min));
6603-
66046635 // Don't break up a clamp pattern -- (min(max X, Y), Z) -- by replacing a
66056636 // min/max canonical compare with some other compare. That could lead to
66066637 // conflict with select canonicalization and infinite looping.
@@ -6682,13 +6713,9 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
66826713 // simplify this comparison. For example, (x&4) < 8 is always true.
66836714 switch (Pred) {
66846715 default :
6685- llvm_unreachable ( " Unknown icmp opcode! " ) ;
6716+ break ;
66866717 case ICmpInst::ICMP_EQ:
66876718 case ICmpInst::ICMP_NE: {
6688- if (Op0Max.ult (Op1Min) || Op0Min.ugt (Op1Max))
6689- return replaceInstUsesWith (
6690- I, ConstantInt::getBool (I.getType (), Pred == CmpInst::ICMP_NE));
6691-
66926719 // If all bits are known zero except for one, then we know at most one bit
66936720 // is set. If the comparison is against zero, then this is a check to see if
66946721 // *that* bit is set.
@@ -6728,67 +6755,19 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
67286755 ConstantInt::getNullValue (Op1->getType ()));
67296756 break ;
67306757 }
6731- case ICmpInst::ICMP_ULT: {
6732- if (Op0Max.ult (Op1Min)) // A <u B -> true if max(A) < min(B)
6733- return replaceInstUsesWith (I, ConstantInt::getTrue (I.getType ()));
6734- if (Op0Min.uge (Op1Max)) // A <u B -> false if min(A) >= max(B)
6735- return replaceInstUsesWith (I, ConstantInt::getFalse (I.getType ()));
6736- break ;
6737- }
6738- case ICmpInst::ICMP_UGT: {
6739- if (Op0Min.ugt (Op1Max)) // A >u B -> true if min(A) > max(B)
6740- return replaceInstUsesWith (I, ConstantInt::getTrue (I.getType ()));
6741- if (Op0Max.ule (Op1Min)) // A >u B -> false if max(A) <= max(B)
6742- return replaceInstUsesWith (I, ConstantInt::getFalse (I.getType ()));
6743- break ;
6744- }
6745- case ICmpInst::ICMP_SLT: {
6746- if (Op0Max.slt (Op1Min)) // A <s B -> true if max(A) < min(C)
6747- return replaceInstUsesWith (I, ConstantInt::getTrue (I.getType ()));
6748- if (Op0Min.sge (Op1Max)) // A <s B -> false if min(A) >= max(C)
6749- return replaceInstUsesWith (I, ConstantInt::getFalse (I.getType ()));
6750- break ;
6751- }
6752- case ICmpInst::ICMP_SGT: {
6753- if (Op0Min.sgt (Op1Max)) // A >s B -> true if min(A) > max(B)
6754- return replaceInstUsesWith (I, ConstantInt::getTrue (I.getType ()));
6755- if (Op0Max.sle (Op1Min)) // A >s B -> false if max(A) <= min(B)
6756- return replaceInstUsesWith (I, ConstantInt::getFalse (I.getType ()));
6757- break ;
6758- }
67596758 case ICmpInst::ICMP_SGE:
6760- assert (!isa<ConstantInt>(Op1) && " ICMP_SGE with ConstantInt not folded!" );
6761- if (Op0Min.sge (Op1Max)) // A >=s B -> true if min(A) >= max(B)
6762- return replaceInstUsesWith (I, ConstantInt::getTrue (I.getType ()));
6763- if (Op0Max.slt (Op1Min)) // A >=s B -> false if max(A) < min(B)
6764- return replaceInstUsesWith (I, ConstantInt::getFalse (I.getType ()));
67656759 if (Op1Min == Op0Max) // A >=s B -> A == B if max(A) == min(B)
67666760 return new ICmpInst (ICmpInst::ICMP_EQ, Op0, Op1);
67676761 break ;
67686762 case ICmpInst::ICMP_SLE:
6769- assert (!isa<ConstantInt>(Op1) && " ICMP_SLE with ConstantInt not folded!" );
6770- if (Op0Max.sle (Op1Min)) // A <=s B -> true if max(A) <= min(B)
6771- return replaceInstUsesWith (I, ConstantInt::getTrue (I.getType ()));
6772- if (Op0Min.sgt (Op1Max)) // A <=s B -> false if min(A) > max(B)
6773- return replaceInstUsesWith (I, ConstantInt::getFalse (I.getType ()));
67746763 if (Op1Max == Op0Min) // A <=s B -> A == B if min(A) == max(B)
67756764 return new ICmpInst (ICmpInst::ICMP_EQ, Op0, Op1);
67766765 break ;
67776766 case ICmpInst::ICMP_UGE:
6778- assert (!isa<ConstantInt>(Op1) && " ICMP_UGE with ConstantInt not folded!" );
6779- if (Op0Min.uge (Op1Max)) // A >=u B -> true if min(A) >= max(B)
6780- return replaceInstUsesWith (I, ConstantInt::getTrue (I.getType ()));
6781- if (Op0Max.ult (Op1Min)) // A >=u B -> false if max(A) < min(B)
6782- return replaceInstUsesWith (I, ConstantInt::getFalse (I.getType ()));
67836767 if (Op1Min == Op0Max) // A >=u B -> A == B if max(A) == min(B)
67846768 return new ICmpInst (ICmpInst::ICMP_EQ, Op0, Op1);
67856769 break ;
67866770 case ICmpInst::ICMP_ULE:
6787- assert (!isa<ConstantInt>(Op1) && " ICMP_ULE with ConstantInt not folded!" );
6788- if (Op0Max.ule (Op1Min)) // A <=u B -> true if max(A) <= min(B)
6789- return replaceInstUsesWith (I, ConstantInt::getTrue (I.getType ()));
6790- if (Op0Min.ugt (Op1Max)) // A <=u B -> false if min(A) > max(B)
6791- return replaceInstUsesWith (I, ConstantInt::getFalse (I.getType ()));
67926771 if (Op1Max == Op0Min) // A <=u B -> A == B if min(A) == max(B)
67936772 return new ICmpInst (ICmpInst::ICMP_EQ, Op0, Op1);
67946773 break ;
0 commit comments