diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h index 8aa024a72afc8..b4918c2d1e8a1 100644 --- a/llvm/include/llvm/Analysis/ValueTracking.h +++ b/llvm/include/llvm/Analysis/ValueTracking.h @@ -1102,6 +1102,13 @@ bool mustExecuteUBIfPoisonOnPathTo(Instruction *Root, Instruction *OnPathTo, DominatorTree *DT); +/// Convert an integer comparison with a constant RHS into an equivalent +/// form with the strictness flipped predicate. Return the new predicate and +/// corresponding constant RHS if possible. Otherwise return std::nullopt. +/// E.g., (icmp sgt X, 0) -> (icmp sle X, 1). +std::optional> +getFlippedStrictnessPredicateAndConstant(CmpPredicate Pred, Constant *C); + /// Specific patterns of select instructions we can match. enum SelectPatternFlavor { SPF_UNKNOWN = 0, diff --git a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h index 71592058e3456..fa6b60cba15aa 100644 --- a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h +++ b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h @@ -184,12 +184,6 @@ class LLVM_LIBRARY_VISIBILITY InstCombiner { return ConstantExpr::getSub(C, ConstantInt::get(C->getType(), 1)); } - std::optional> static getFlippedStrictnessPredicateAndConstant(CmpPredicate - Pred, - Constant *C); - static bool shouldAvoidAbsorbingNotIntoSelect(const SelectInst &SI) { // a ? b : false and a ? true : b are the canonical form of logical and/or. // This includes !a ? b : false and !a ? true : b. Absorbing the not into diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp index 2f6e869ae7b73..0eb43dd581acc 100644 --- a/llvm/lib/Analysis/ValueTracking.cpp +++ b/llvm/lib/Analysis/ValueTracking.cpp @@ -8641,6 +8641,80 @@ SelectPatternResult llvm::getSelectPattern(CmpInst::Predicate Pred, } } +std::optional> +llvm::getFlippedStrictnessPredicateAndConstant(CmpPredicate Pred, Constant *C) { + assert(ICmpInst::isRelational(Pred) && ICmpInst::isIntPredicate(Pred) && + "Only for relational integer predicates."); + if (isa(C)) + return std::nullopt; + + Type *Type = C->getType(); + bool IsSigned = ICmpInst::isSigned(Pred); + + CmpInst::Predicate UnsignedPred = ICmpInst::getUnsignedPredicate(Pred); + bool WillIncrement = + UnsignedPred == ICmpInst::ICMP_ULE || UnsignedPred == ICmpInst::ICMP_UGT; + + // Check if the constant operand can be safely incremented/decremented + // without overflowing/underflowing. + auto ConstantIsOk = [WillIncrement, IsSigned](ConstantInt *C) { + return WillIncrement ? !C->isMaxValue(IsSigned) : !C->isMinValue(IsSigned); + }; + + Constant *SafeReplacementConstant = nullptr; + if (auto *CI = dyn_cast(C)) { + // Bail out if the constant can't be safely incremented/decremented. + if (!ConstantIsOk(CI)) + return std::nullopt; + } else if (auto *FVTy = dyn_cast(Type)) { + unsigned NumElts = FVTy->getNumElements(); + for (unsigned i = 0; i != NumElts; ++i) { + Constant *Elt = C->getAggregateElement(i); + if (!Elt) + return std::nullopt; + + if (isa(Elt)) + continue; + + // Bail out if we can't determine if this constant is min/max or if we + // know that this constant is min/max. + auto *CI = dyn_cast(Elt); + if (!CI || !ConstantIsOk(CI)) + return std::nullopt; + + if (!SafeReplacementConstant) + SafeReplacementConstant = CI; + } + } else if (isa(C->getType())) { + // Handle scalable splat + Value *SplatC = C->getSplatValue(); + auto *CI = dyn_cast_or_null(SplatC); + // Bail out if the constant can't be safely incremented/decremented. + if (!CI || !ConstantIsOk(CI)) + return std::nullopt; + } else { + // ConstantExpr? + return std::nullopt; + } + + // It may not be safe to change a compare predicate in the presence of + // undefined elements, so replace those elements with the first safe constant + // that we found. + // TODO: in case of poison, it is safe; let's replace undefs only. + if (C->containsUndefOrPoisonElement()) { + assert(SafeReplacementConstant && "Replacement constant not set"); + C = Constant::replaceUndefsWith(C, SafeReplacementConstant); + } + + CmpInst::Predicate NewPred = CmpInst::getFlippedStrictnessPredicate(Pred); + + // Increment or decrement the constant. + Constant *OneOrNegOne = ConstantInt::get(Type, WillIncrement ? 1 : -1, true); + Constant *NewC = ConstantExpr::getAdd(C, OneOrNegOne); + + return std::make_pair(NewPred, NewC); +} + static SelectPatternResult matchSelectPattern(CmpInst::Predicate Pred, FastMathFlags FMF, Value *CmpLHS, Value *CmpRHS, diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index 8b23583c51063..c2d659035877e 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -2485,9 +2485,8 @@ Instruction *InstCombinerImpl::foldICmpShlConstant(ICmpInst &Cmp, // icmp ule i64 (shl X, 32), 8589934592 -> // icmp ule i32 (trunc X, i32), 2 -> // icmp ult i32 (trunc X, i32), 3 - if (auto FlippedStrictness = - InstCombiner::getFlippedStrictnessPredicateAndConstant( - Pred, ConstantInt::get(ShType->getContext(), C))) { + if (auto FlippedStrictness = getFlippedStrictnessPredicateAndConstant( + Pred, ConstantInt::get(ShType->getContext(), C))) { CmpPred = FlippedStrictness->first; RHSC = cast(FlippedStrictness->second)->getValue(); } @@ -3280,8 +3279,7 @@ bool InstCombinerImpl::matchThreeWayIntCompare(SelectInst *SI, Value *&LHS, if (PredB == ICmpInst::ICMP_SGT && isa(RHS2)) { // x sgt C-1 <--> x sge C <--> not(x slt C) auto FlippedStrictness = - InstCombiner::getFlippedStrictnessPredicateAndConstant( - PredB, cast(RHS2)); + getFlippedStrictnessPredicateAndConstant(PredB, cast(RHS2)); if (!FlippedStrictness) return false; assert(FlippedStrictness->first == ICmpInst::ICMP_SGE && @@ -6908,79 +6906,6 @@ Instruction *InstCombinerImpl::foldICmpUsingBoolRange(ICmpInst &I) { return nullptr; } -std::optional> -InstCombiner::getFlippedStrictnessPredicateAndConstant(CmpPredicate Pred, - Constant *C) { - assert(ICmpInst::isRelational(Pred) && ICmpInst::isIntPredicate(Pred) && - "Only for relational integer predicates."); - - Type *Type = C->getType(); - bool IsSigned = ICmpInst::isSigned(Pred); - - CmpInst::Predicate UnsignedPred = ICmpInst::getUnsignedPredicate(Pred); - bool WillIncrement = - UnsignedPred == ICmpInst::ICMP_ULE || UnsignedPred == ICmpInst::ICMP_UGT; - - // Check if the constant operand can be safely incremented/decremented - // without overflowing/underflowing. - auto ConstantIsOk = [WillIncrement, IsSigned](ConstantInt *C) { - return WillIncrement ? !C->isMaxValue(IsSigned) : !C->isMinValue(IsSigned); - }; - - Constant *SafeReplacementConstant = nullptr; - if (auto *CI = dyn_cast(C)) { - // Bail out if the constant can't be safely incremented/decremented. - if (!ConstantIsOk(CI)) - return std::nullopt; - } else if (auto *FVTy = dyn_cast(Type)) { - unsigned NumElts = FVTy->getNumElements(); - for (unsigned i = 0; i != NumElts; ++i) { - Constant *Elt = C->getAggregateElement(i); - if (!Elt) - return std::nullopt; - - if (isa(Elt)) - continue; - - // Bail out if we can't determine if this constant is min/max or if we - // know that this constant is min/max. - auto *CI = dyn_cast(Elt); - if (!CI || !ConstantIsOk(CI)) - return std::nullopt; - - if (!SafeReplacementConstant) - SafeReplacementConstant = CI; - } - } else if (isa(C->getType())) { - // Handle scalable splat - Value *SplatC = C->getSplatValue(); - auto *CI = dyn_cast_or_null(SplatC); - // Bail out if the constant can't be safely incremented/decremented. - if (!CI || !ConstantIsOk(CI)) - return std::nullopt; - } else { - // ConstantExpr? - return std::nullopt; - } - - // It may not be safe to change a compare predicate in the presence of - // undefined elements, so replace those elements with the first safe constant - // that we found. - // TODO: in case of poison, it is safe; let's replace undefs only. - if (C->containsUndefOrPoisonElement()) { - assert(SafeReplacementConstant && "Replacement constant not set"); - C = Constant::replaceUndefsWith(C, SafeReplacementConstant); - } - - CmpInst::Predicate NewPred = CmpInst::getFlippedStrictnessPredicate(Pred); - - // Increment or decrement the constant. - Constant *OneOrNegOne = ConstantInt::get(Type, WillIncrement ? 1 : -1, true); - Constant *NewC = ConstantExpr::getAdd(C, OneOrNegOne); - - return std::make_pair(NewPred, NewC); -} - /// If we have an icmp le or icmp ge instruction with a constant operand, turn /// it into the appropriate icmp lt or icmp gt instruction. This transform /// allows them to be folded in visitICmpInst. @@ -6996,8 +6921,7 @@ static ICmpInst *canonicalizeCmpWithConstant(ICmpInst &I) { if (!Op1C) return nullptr; - auto FlippedStrictness = - InstCombiner::getFlippedStrictnessPredicateAndConstant(Pred, Op1C); + auto FlippedStrictness = getFlippedStrictnessPredicateAndConstant(Pred, Op1C); if (!FlippedStrictness) return nullptr; diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index 7fd91c72a2fb0..eca518aa64070 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -1689,8 +1689,7 @@ tryToReuseConstantFromSelectInComparison(SelectInst &Sel, ICmpInst &Cmp, return nullptr; // Check the constant we'd have with flipped-strictness predicate. - auto FlippedStrictness = - InstCombiner::getFlippedStrictnessPredicateAndConstant(Pred, C0); + auto FlippedStrictness = getFlippedStrictnessPredicateAndConstant(Pred, C0); if (!FlippedStrictness) return nullptr; @@ -1970,8 +1969,7 @@ static Value *foldSelectWithConstOpToBinOp(ICmpInst *Cmp, Value *TrueVal, Value *RHS; SelectPatternFlavor SPF; const DataLayout &DL = BOp->getDataLayout(); - auto Flipped = - InstCombiner::getFlippedStrictnessPredicateAndConstant(Predicate, C1); + auto Flipped = getFlippedStrictnessPredicateAndConstant(Predicate, C1); if (C3 == ConstantFoldBinaryOpOperands(Opcode, C1, C2, DL)) { SPF = getSelectPattern(Predicate).Flavor;