Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions llvm/include/llvm/Analysis/ValueTracking.h
Original file line number Diff line number Diff line change
Expand Up @@ -1102,6 +1102,13 @@ bool mustExecuteUBIfPoisonOnPathTo(Instruction *Root,
Instruction *OnPathTo,
DominatorTree *DT);

/// Convert an integer comparison with a constant RHS into an equivalent
/// form with the strictness flipped predicate. Return the new predicate and
/// corresponding constant RHS if possible. Otherwise return std::nullopt.
/// E.g., (icmp sgt X, 0) -> (icmp sle X, 1).
std::optional<std::pair<CmpPredicate, Constant *>>
getFlippedStrictnessPredicateAndConstant(CmpPredicate Pred, Constant *C);

/// Specific patterns of select instructions we can match.
enum SelectPatternFlavor {
SPF_UNKNOWN = 0,
Expand Down
6 changes: 0 additions & 6 deletions llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
Original file line number Diff line number Diff line change
Expand Up @@ -184,12 +184,6 @@ class LLVM_LIBRARY_VISIBILITY InstCombiner {
return ConstantExpr::getSub(C, ConstantInt::get(C->getType(), 1));
}

std::optional<std::pair<
CmpPredicate,
Constant *>> static getFlippedStrictnessPredicateAndConstant(CmpPredicate
Pred,
Constant *C);

static bool shouldAvoidAbsorbingNotIntoSelect(const SelectInst &SI) {
// a ? b : false and a ? true : b are the canonical form of logical and/or.
// This includes !a ? b : false and !a ? true : b. Absorbing the not into
Expand Down
74 changes: 74 additions & 0 deletions llvm/lib/Analysis/ValueTracking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8641,6 +8641,80 @@ SelectPatternResult llvm::getSelectPattern(CmpInst::Predicate Pred,
}
}

std::optional<std::pair<CmpPredicate, Constant *>>
llvm::getFlippedStrictnessPredicateAndConstant(CmpPredicate Pred, Constant *C) {
assert(ICmpInst::isRelational(Pred) && ICmpInst::isIntPredicate(Pred) &&
"Only for relational integer predicates.");
if (isa<UndefValue>(C))
return std::nullopt;

Type *Type = C->getType();
bool IsSigned = ICmpInst::isSigned(Pred);

CmpInst::Predicate UnsignedPred = ICmpInst::getUnsignedPredicate(Pred);
bool WillIncrement =
UnsignedPred == ICmpInst::ICMP_ULE || UnsignedPred == ICmpInst::ICMP_UGT;

// Check if the constant operand can be safely incremented/decremented
// without overflowing/underflowing.
auto ConstantIsOk = [WillIncrement, IsSigned](ConstantInt *C) {
return WillIncrement ? !C->isMaxValue(IsSigned) : !C->isMinValue(IsSigned);
};

Constant *SafeReplacementConstant = nullptr;
if (auto *CI = dyn_cast<ConstantInt>(C)) {
// Bail out if the constant can't be safely incremented/decremented.
if (!ConstantIsOk(CI))
return std::nullopt;
} else if (auto *FVTy = dyn_cast<FixedVectorType>(Type)) {
unsigned NumElts = FVTy->getNumElements();
for (unsigned i = 0; i != NumElts; ++i) {
Constant *Elt = C->getAggregateElement(i);
if (!Elt)
return std::nullopt;

if (isa<UndefValue>(Elt))
continue;

// Bail out if we can't determine if this constant is min/max or if we
// know that this constant is min/max.
auto *CI = dyn_cast<ConstantInt>(Elt);
if (!CI || !ConstantIsOk(CI))
return std::nullopt;

if (!SafeReplacementConstant)
SafeReplacementConstant = CI;
}
} else if (isa<VectorType>(C->getType())) {
// Handle scalable splat
Value *SplatC = C->getSplatValue();
auto *CI = dyn_cast_or_null<ConstantInt>(SplatC);
// Bail out if the constant can't be safely incremented/decremented.
if (!CI || !ConstantIsOk(CI))
return std::nullopt;
} else {
// ConstantExpr?
return std::nullopt;
}

// It may not be safe to change a compare predicate in the presence of
// undefined elements, so replace those elements with the first safe constant
// that we found.
// TODO: in case of poison, it is safe; let's replace undefs only.
if (C->containsUndefOrPoisonElement()) {
assert(SafeReplacementConstant && "Replacement constant not set");
C = Constant::replaceUndefsWith(C, SafeReplacementConstant);
}

CmpInst::Predicate NewPred = CmpInst::getFlippedStrictnessPredicate(Pred);

// Increment or decrement the constant.
Constant *OneOrNegOne = ConstantInt::get(Type, WillIncrement ? 1 : -1, true);
Constant *NewC = ConstantExpr::getAdd(C, OneOrNegOne);

return std::make_pair(NewPred, NewC);
}

static SelectPatternResult matchSelectPattern(CmpInst::Predicate Pred,
FastMathFlags FMF,
Value *CmpLHS, Value *CmpRHS,
Expand Down
84 changes: 4 additions & 80 deletions llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2485,9 +2485,8 @@ Instruction *InstCombinerImpl::foldICmpShlConstant(ICmpInst &Cmp,
// icmp ule i64 (shl X, 32), 8589934592 ->
// icmp ule i32 (trunc X, i32), 2 ->
// icmp ult i32 (trunc X, i32), 3
if (auto FlippedStrictness =
InstCombiner::getFlippedStrictnessPredicateAndConstant(
Pred, ConstantInt::get(ShType->getContext(), C))) {
if (auto FlippedStrictness = getFlippedStrictnessPredicateAndConstant(
Pred, ConstantInt::get(ShType->getContext(), C))) {
CmpPred = FlippedStrictness->first;
RHSC = cast<ConstantInt>(FlippedStrictness->second)->getValue();
}
Expand Down Expand Up @@ -3280,8 +3279,7 @@ bool InstCombinerImpl::matchThreeWayIntCompare(SelectInst *SI, Value *&LHS,
if (PredB == ICmpInst::ICMP_SGT && isa<Constant>(RHS2)) {
// x sgt C-1 <--> x sge C <--> not(x slt C)
auto FlippedStrictness =
InstCombiner::getFlippedStrictnessPredicateAndConstant(
PredB, cast<Constant>(RHS2));
getFlippedStrictnessPredicateAndConstant(PredB, cast<Constant>(RHS2));
if (!FlippedStrictness)
return false;
assert(FlippedStrictness->first == ICmpInst::ICMP_SGE &&
Expand Down Expand Up @@ -6908,79 +6906,6 @@ Instruction *InstCombinerImpl::foldICmpUsingBoolRange(ICmpInst &I) {
return nullptr;
}

std::optional<std::pair<CmpPredicate, Constant *>>
InstCombiner::getFlippedStrictnessPredicateAndConstant(CmpPredicate Pred,
Constant *C) {
assert(ICmpInst::isRelational(Pred) && ICmpInst::isIntPredicate(Pred) &&
"Only for relational integer predicates.");

Type *Type = C->getType();
bool IsSigned = ICmpInst::isSigned(Pred);

CmpInst::Predicate UnsignedPred = ICmpInst::getUnsignedPredicate(Pred);
bool WillIncrement =
UnsignedPred == ICmpInst::ICMP_ULE || UnsignedPred == ICmpInst::ICMP_UGT;

// Check if the constant operand can be safely incremented/decremented
// without overflowing/underflowing.
auto ConstantIsOk = [WillIncrement, IsSigned](ConstantInt *C) {
return WillIncrement ? !C->isMaxValue(IsSigned) : !C->isMinValue(IsSigned);
};

Constant *SafeReplacementConstant = nullptr;
if (auto *CI = dyn_cast<ConstantInt>(C)) {
// Bail out if the constant can't be safely incremented/decremented.
if (!ConstantIsOk(CI))
return std::nullopt;
} else if (auto *FVTy = dyn_cast<FixedVectorType>(Type)) {
unsigned NumElts = FVTy->getNumElements();
for (unsigned i = 0; i != NumElts; ++i) {
Constant *Elt = C->getAggregateElement(i);
if (!Elt)
return std::nullopt;

if (isa<UndefValue>(Elt))
continue;

// Bail out if we can't determine if this constant is min/max or if we
// know that this constant is min/max.
auto *CI = dyn_cast<ConstantInt>(Elt);
if (!CI || !ConstantIsOk(CI))
return std::nullopt;

if (!SafeReplacementConstant)
SafeReplacementConstant = CI;
}
} else if (isa<VectorType>(C->getType())) {
// Handle scalable splat
Value *SplatC = C->getSplatValue();
auto *CI = dyn_cast_or_null<ConstantInt>(SplatC);
// Bail out if the constant can't be safely incremented/decremented.
if (!CI || !ConstantIsOk(CI))
return std::nullopt;
} else {
// ConstantExpr?
return std::nullopt;
}

// It may not be safe to change a compare predicate in the presence of
// undefined elements, so replace those elements with the first safe constant
// that we found.
// TODO: in case of poison, it is safe; let's replace undefs only.
if (C->containsUndefOrPoisonElement()) {
assert(SafeReplacementConstant && "Replacement constant not set");
C = Constant::replaceUndefsWith(C, SafeReplacementConstant);
}

CmpInst::Predicate NewPred = CmpInst::getFlippedStrictnessPredicate(Pred);

// Increment or decrement the constant.
Constant *OneOrNegOne = ConstantInt::get(Type, WillIncrement ? 1 : -1, true);
Constant *NewC = ConstantExpr::getAdd(C, OneOrNegOne);

return std::make_pair(NewPred, NewC);
}

/// If we have an icmp le or icmp ge instruction with a constant operand, turn
/// it into the appropriate icmp lt or icmp gt instruction. This transform
/// allows them to be folded in visitICmpInst.
Expand All @@ -6996,8 +6921,7 @@ static ICmpInst *canonicalizeCmpWithConstant(ICmpInst &I) {
if (!Op1C)
return nullptr;

auto FlippedStrictness =
InstCombiner::getFlippedStrictnessPredicateAndConstant(Pred, Op1C);
auto FlippedStrictness = getFlippedStrictnessPredicateAndConstant(Pred, Op1C);
if (!FlippedStrictness)
return nullptr;

Expand Down
6 changes: 2 additions & 4 deletions llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1689,8 +1689,7 @@ tryToReuseConstantFromSelectInComparison(SelectInst &Sel, ICmpInst &Cmp,
return nullptr;

// Check the constant we'd have with flipped-strictness predicate.
auto FlippedStrictness =
InstCombiner::getFlippedStrictnessPredicateAndConstant(Pred, C0);
auto FlippedStrictness = getFlippedStrictnessPredicateAndConstant(Pred, C0);
if (!FlippedStrictness)
return nullptr;

Expand Down Expand Up @@ -1970,8 +1969,7 @@ static Value *foldSelectWithConstOpToBinOp(ICmpInst *Cmp, Value *TrueVal,
Value *RHS;
SelectPatternFlavor SPF;
const DataLayout &DL = BOp->getDataLayout();
auto Flipped =
InstCombiner::getFlippedStrictnessPredicateAndConstant(Predicate, C1);
auto Flipped = getFlippedStrictnessPredicateAndConstant(Predicate, C1);

if (C3 == ConstantFoldBinaryOpOperands(Opcode, C1, C2, DL)) {
SPF = getSelectPattern(Predicate).Flavor;
Expand Down
Loading