Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions llvm/include/llvm/Analysis/CmpInstAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,24 +95,28 @@ namespace llvm {
/// Represents the operation icmp (X & Mask) pred C, where pred can only be
/// eq or ne.
struct DecomposedBitTest {
Value *X;
Value *X = nullptr;
CmpInst::Predicate Pred;
APInt Mask;
APInt C;
};

/// Decompose an icmp into the form ((X & Mask) pred C) if possible.
/// Unless \p AllowNonZeroC is true, C will always be 0.
/// Unless \p AllowNonZeroC is true, C will always be 0. If \p
/// DecomposeBitMask is specified, then, for equality predicates, this will
/// decompose bitmasking (e.g. implemented via `and`).
std::optional<DecomposedBitTest>
decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
bool LookThroughTrunc = true,
bool AllowNonZeroC = false);
bool LookThroughTrunc = true, bool AllowNonZeroC = false,
bool DecomposeBitMask = false);

/// Decompose an icmp into the form ((X & Mask) pred C) if
/// possible. Unless \p AllowNonZeroC is true, C will always be 0.
/// If \p DecomposeBitMask is specified, then, for equality predicates, this
/// will decompose bitmasking (e.g. implemented via `and`).
std::optional<DecomposedBitTest>
decomposeBitTest(Value *Cond, bool LookThroughTrunc = true,
bool AllowNonZeroC = false);
bool AllowNonZeroC = false, bool DecomposeBitMask = false);

} // end namespace llvm

Expand Down
34 changes: 26 additions & 8 deletions llvm/lib/Analysis/CmpInstAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,13 @@ Constant *llvm::getPredForFCmpCode(unsigned Code, Type *OpTy,

std::optional<DecomposedBitTest>
llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
bool LookThruTrunc, bool AllowNonZeroC) {
bool LookThruTrunc, bool AllowNonZeroC,
bool DecomposeBitMask) {
using namespace PatternMatch;

const APInt *OrigC;
if (!ICmpInst::isRelational(Pred) || !match(RHS, m_APIntAllowPoison(OrigC)))
if ((ICmpInst::isEquality(Pred) && !DecomposeBitMask) ||
!match(RHS, m_APIntAllowPoison(OrigC)))
return std::nullopt;

bool Inverted = false;
Expand All @@ -97,9 +99,10 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
}

DecomposedBitTest Result;

switch (Pred) {
default:
llvm_unreachable("Unexpected predicate");
return std::nullopt;
case ICmpInst::ICMP_SLT: {
// X < 0 is equivalent to (X & SignMask) != 0.
if (C.isZero()) {
Expand Down Expand Up @@ -128,7 +131,7 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,

return std::nullopt;
}
case ICmpInst::ICMP_ULT:
case ICmpInst::ICMP_ULT: {
// X <u 2^n is equivalent to (X & ~(2^n-1)) == 0.
if (C.isPowerOf2()) {
Result.Mask = -C;
Expand All @@ -147,6 +150,19 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,

return std::nullopt;
}
case ICmpInst::ICMP_EQ:
case ICmpInst::ICMP_NE: {
assert(DecomposeBitMask);
const APInt *AndC;
Value *AndVal;
if (match(LHS, m_And(m_Value(AndVal), m_APIntAllowPoison(AndC)))) {
Result = {AndVal /*X*/, Pred /*Pred*/, *AndC /*Mask*/, *OrigC /*C*/};
break;
}

return std::nullopt;
}
}

if (!AllowNonZeroC && !Result.C.isZero())
return std::nullopt;
Expand All @@ -159,23 +175,25 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
Result.X = X;
Result.Mask = Result.Mask.zext(X->getType()->getScalarSizeInBits());
Result.C = Result.C.zext(X->getType()->getScalarSizeInBits());
} else {
} else if (!Result.X) {
Result.X = LHS;
}

return Result;
}

std::optional<DecomposedBitTest>
llvm::decomposeBitTest(Value *Cond, bool LookThruTrunc, bool AllowNonZeroC) {
std::optional<DecomposedBitTest> llvm::decomposeBitTest(Value *Cond,
bool LookThruTrunc,
bool AllowNonZeroC,
bool DecomposeBitMask) {
using namespace PatternMatch;
if (auto *ICmp = dyn_cast<ICmpInst>(Cond)) {
// Don't allow pointers. Splat vectors are fine.
if (!ICmp->getOperand(0)->getType()->isIntOrIntVectorTy())
return std::nullopt;
return decomposeBitTestICmp(ICmp->getOperand(0), ICmp->getOperand(1),
ICmp->getPredicate(), LookThruTrunc,
AllowNonZeroC);
AllowNonZeroC, DecomposeBitMask);
}
Value *X;
if (Cond->getType()->isIntOrIntVectorTy(1) &&
Expand Down
14 changes: 4 additions & 10 deletions llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -875,22 +875,16 @@ static Value *foldSignedTruncationCheck(ICmpInst *ICmp0, ICmpInst *ICmp1,
APInt &UnsetBitsMask) -> bool {
CmpPredicate Pred = ICmp->getPredicate();
// Can it be decomposed into icmp eq (X & Mask), 0 ?
auto Res =
llvm::decomposeBitTestICmp(ICmp->getOperand(0), ICmp->getOperand(1),
Pred, /*LookThroughTrunc=*/false);
auto Res = llvm::decomposeBitTestICmp(
ICmp->getOperand(0), ICmp->getOperand(1), Pred,
/*LookThroughTrunc=*/true, /*AllowNonZeroC=*/false,
Copy link
Contributor

@andjo403 andjo403 Apr 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
/*LookThroughTrunc=*/true, /*AllowNonZeroC=*/false,
/*LookThroughTrunc=*/false, /*AllowNonZeroC=*/false,

I do not think this should start looking through trunc, as that case has special handling below.

/*DecomposeBitMask=*/true);
if (Res && Res->Pred == ICmpInst::ICMP_EQ) {
X = Res->X;
UnsetBitsMask = Res->Mask;
return true;
}

// Is it icmp eq (X & Mask), 0 already?
const APInt *Mask;
if (match(ICmp, m_ICmp(Pred, m_And(m_Value(X), m_APInt(Mask)), m_Zero())) &&
Pred == ICmpInst::ICMP_EQ) {
UnsetBitsMask = *Mask;
return true;
}
return false;
};

Expand Down
32 changes: 15 additions & 17 deletions llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3730,31 +3730,29 @@ static Value *foldSelectBitTest(SelectInst &Sel, Value *CondVal, Value *TrueVal,
Value *CmpLHS, *CmpRHS;

if (match(CondVal, m_ICmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS)))) {
if (ICmpInst::isEquality(Pred)) {
if (!match(CmpRHS, m_Zero()))
return nullptr;
auto Res = decomposeBitTestICmp(
CmpLHS, CmpRHS, Pred, /*LookThroughTrunc=*/true,
/*AllowNonZeroC=*/false, /*DecomposeBitMask=*/true);

V = CmpLHS;
const APInt *AndRHS;
if (!match(CmpLHS, m_And(m_Value(), m_Power2(AndRHS))))
return nullptr;
if (!Res)
return nullptr;

AndMask = *AndRHS;
} else if (auto Res = decomposeBitTestICmp(CmpLHS, CmpRHS, Pred)) {
assert(ICmpInst::isEquality(Res->Pred) && "Not equality test?");
AndMask = Res->Mask;
V = CmpLHS;
AndMask = Res->Mask;

if (!ICmpInst::isEquality(Pred)) {
V = Res->X;
KnownBits Known =
computeKnownBits(V, /*Depth=*/0, SQ.getWithInstruction(&Sel));
AndMask &= Known.getMaxValue();
if (!AndMask.isPowerOf2())
return nullptr;

Pred = Res->Pred;
CreateAnd = true;
} else {
return nullptr;
}

Pred = Res->Pred;

if (!AndMask.isPowerOf2())
return nullptr;

} else if (auto *Trunc = dyn_cast<TruncInst>(CondVal)) {
V = Trunc->getOperand(0);
AndMask = APInt(V->getType()->getScalarSizeInBits(), 1);
Expand Down
14 changes: 5 additions & 9 deletions llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2761,14 +2761,11 @@ static bool detectShiftUntilBitTestIdiom(Loop *CurLoop, Value *&BaseX,
m_LoopInvariant(m_Shl(m_One(), m_Value(BitPos)),
CurLoop))));
};
auto MatchConstantBitMask = [&]() {
return ICmpInst::isEquality(Pred) && match(CmpRHS, m_Zero()) &&
match(CmpLHS, m_And(m_Value(CurrX),
m_CombineAnd(m_Value(BitMask), m_Power2()))) &&
(BitPos = ConstantExpr::getExactLogBase2(cast<Constant>(BitMask)));
};

auto MatchDecomposableConstantBitMask = [&]() {
auto Res = llvm::decomposeBitTestICmp(CmpLHS, CmpRHS, Pred);
auto Res = llvm::decomposeBitTestICmp(
CmpLHS, CmpRHS, Pred, /*LookThroughTrunc=*/true,
/*AllowNonZeroC=*/false, /*DecomposeBitMask=*/true);
if (Res && Res->Mask.isPowerOf2()) {
assert(ICmpInst::isEquality(Res->Pred));
Pred = Res->Pred;
Expand All @@ -2780,8 +2777,7 @@ static bool detectShiftUntilBitTestIdiom(Loop *CurLoop, Value *&BaseX,
return false;
};

if (!MatchVariableBitMask() && !MatchConstantBitMask() &&
!MatchDecomposableConstantBitMask()) {
if (!MatchVariableBitMask() && !MatchDecomposableConstantBitMask()) {
LLVM_DEBUG(dbgs() << DEBUG_TYPE " Bad backedge comparison.\n");
return false;
}
Expand Down
Loading