diff --git a/llvm/include/llvm/IR/Instructions.h b/llvm/include/llvm/IR/Instructions.h index b6575d4c85724..bc29a4801e4ff 100644 --- a/llvm/include/llvm/IR/Instructions.h +++ b/llvm/include/llvm/IR/Instructions.h @@ -48,6 +48,7 @@ class APInt; class BasicBlock; class ConstantInt; class DataLayout; +struct KnownBits; class StringRef; class Type; class Value; @@ -1305,6 +1306,11 @@ class ICmpInst: public CmpInst { static bool compare(const APInt &LHS, const APInt &RHS, ICmpInst::Predicate Pred); + /// Return result of `LHS Pred RHS`, if it can be determined from the + /// KnownBits. Otherwise return nullopt. + static std::optional compare(const KnownBits &LHS, const KnownBits &RHS, + ICmpInst::Predicate Pred); + // Methods for support type inquiry through isa, cast, and dyn_cast: static bool classof(const Instruction *I) { return I->getOpcode() == Instruction::ICmp; diff --git a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp index 0945e7334ac9d..1a1a1c28ef150 100644 --- a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp +++ b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp @@ -4442,40 +4442,7 @@ bool CombinerHelper::matchICmpToTrueFalseKnownBits(MachineInstr &MI, if (!KnownVal) { auto KnownLHS = KB->getKnownBits(MI.getOperand(2).getReg()); - switch (Pred) { - default: - llvm_unreachable("Unexpected G_ICMP predicate?"); - case CmpInst::ICMP_EQ: - KnownVal = KnownBits::eq(KnownLHS, KnownRHS); - break; - case CmpInst::ICMP_NE: - KnownVal = KnownBits::ne(KnownLHS, KnownRHS); - break; - case CmpInst::ICMP_SGE: - KnownVal = KnownBits::sge(KnownLHS, KnownRHS); - break; - case CmpInst::ICMP_SGT: - KnownVal = KnownBits::sgt(KnownLHS, KnownRHS); - break; - case CmpInst::ICMP_SLE: - KnownVal = KnownBits::sle(KnownLHS, KnownRHS); - break; - case CmpInst::ICMP_SLT: - KnownVal = KnownBits::slt(KnownLHS, KnownRHS); - break; - case CmpInst::ICMP_UGE: - KnownVal = KnownBits::uge(KnownLHS, KnownRHS); - break; - case CmpInst::ICMP_UGT: - KnownVal = KnownBits::ugt(KnownLHS, KnownRHS); - break; - case CmpInst::ICMP_ULE: - KnownVal = KnownBits::ule(KnownLHS, KnownRHS); - break; - case CmpInst::ICMP_ULT: - KnownVal = KnownBits::ult(KnownLHS, KnownRHS); - break; - } + KnownVal = ICmpInst::compare(KnownLHS, KnownRHS, Pred); } if (!KnownVal) diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp index 05e340ffa20a0..5b89a27126150 100644 --- a/llvm/lib/IR/Instructions.cpp +++ b/llvm/lib/IR/Instructions.cpp @@ -40,6 +40,7 @@ #include "llvm/Support/Casting.h" #include "llvm/Support/CheckedArithmetic.h" #include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/KnownBits.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/ModRef.h" #include "llvm/Support/TypeSize.h" @@ -3837,6 +3838,35 @@ bool FCmpInst::compare(const APFloat &LHS, const APFloat &RHS, } } +std::optional ICmpInst::compare(const KnownBits &LHS, + const KnownBits &RHS, + ICmpInst::Predicate Pred) { + switch (Pred) { + case ICmpInst::ICMP_EQ: + return KnownBits::eq(LHS, RHS); + case ICmpInst::ICMP_NE: + return KnownBits::ne(LHS, RHS); + case ICmpInst::ICMP_UGE: + return KnownBits::uge(LHS, RHS); + case ICmpInst::ICMP_UGT: + return KnownBits::ugt(LHS, RHS); + case ICmpInst::ICMP_ULE: + return KnownBits::ule(LHS, RHS); + case ICmpInst::ICMP_ULT: + return KnownBits::ult(LHS, RHS); + case ICmpInst::ICMP_SGE: + return KnownBits::sge(LHS, RHS); + case ICmpInst::ICMP_SGT: + return KnownBits::sgt(LHS, RHS); + case ICmpInst::ICMP_SLE: + return KnownBits::sle(LHS, RHS); + case ICmpInst::ICMP_SLT: + return KnownBits::slt(LHS, RHS); + default: + llvm_unreachable("Unexpected non-integer predicate."); + } +} + CmpInst::Predicate CmpInst::getFlippedSignednessPredicate(Predicate pred) { assert(CmpInst::isRelational(pred) && "Call only with non-equality predicates!");