Skip to content

Commit 740a37e

Browse files
committed
[CmpInstAnalysis] Decompose icmp eq (and x, C) C2
Change-Id: I1dd786a4652ccd2e486db6903e16e58ffa1a7959
1 parent ae0aa2d commit 740a37e

File tree

5 files changed

+59
-49
lines changed

5 files changed

+59
-49
lines changed

llvm/include/llvm/Analysis/CmpInstAnalysis.h

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -95,24 +95,28 @@ namespace llvm {
9595
/// Represents the operation icmp (X & Mask) pred C, where pred can only be
9696
/// eq or ne.
9797
struct DecomposedBitTest {
98-
Value *X;
98+
Value *X = nullptr;
9999
CmpInst::Predicate Pred;
100100
APInt Mask;
101101
APInt C;
102102
};
103103

104104
/// Decompose an icmp into the form ((X & Mask) pred C) if possible.
105-
/// Unless \p AllowNonZeroC is true, C will always be 0.
105+
/// Unless \p AllowNonZeroC is true, C will always be 0. If \p
106+
/// DecomposeBitMask is true, then equality predicates applied to an explicit
107+
/// bit mask, i.e. `icmp eq/ne (and X, Mask), C`, are decomposed as well.
106108
std::optional<DecomposedBitTest>
107109
decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
108-
bool LookThroughTrunc = true,
109-
bool AllowNonZeroC = false);
110+
bool LookThroughTrunc = true, bool AllowNonZeroC = false,
111+
bool DecomposeBitMask = false);
110112

111113
/// Decompose an icmp into the form ((X & Mask) pred C) if
112114
/// possible. Unless \p AllowNonZeroC is true, C will always be 0.
115+
/// If \p DecomposeBitMask is true, then equality predicates applied to an
116+
/// explicit bit mask `icmp eq/ne (and X, Mask), C` are decomposed as well.
113117
std::optional<DecomposedBitTest>
114118
decomposeBitTest(Value *Cond, bool LookThroughTrunc = true,
115-
bool AllowNonZeroC = false);
119+
bool AllowNonZeroC = false, bool DecomposeBitMask = false);
116120

117121
} // end namespace llvm
118122

llvm/lib/Analysis/CmpInstAnalysis.cpp

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -75,11 +75,13 @@ Constant *llvm::getPredForFCmpCode(unsigned Code, Type *OpTy,
7575

7676
std::optional<DecomposedBitTest>
7777
llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
78-
bool LookThruTrunc, bool AllowNonZeroC) {
78+
bool LookThruTrunc, bool AllowNonZeroC,
79+
bool DecomposeBitMask) {
7980
using namespace PatternMatch;
8081

8182
const APInt *OrigC;
82-
if (!ICmpInst::isRelational(Pred) || !match(RHS, m_APIntAllowPoison(OrigC)))
83+
if ((ICmpInst::isEquality(Pred) && !DecomposeBitMask) ||
84+
!match(RHS, m_APIntAllowPoison(OrigC)))
8385
return std::nullopt;
8486

8587
bool Inverted = false;
@@ -97,9 +99,10 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
9799
}
98100

99101
DecomposedBitTest Result;
102+
100103
switch (Pred) {
101104
default:
102-
llvm_unreachable("Unexpected predicate");
105+
return std::nullopt;
103106
case ICmpInst::ICMP_SLT: {
104107
// X < 0 is equivalent to (X & SignMask) != 0.
105108
if (C.isZero()) {
@@ -128,7 +131,7 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
128131

129132
return std::nullopt;
130133
}
131-
case ICmpInst::ICMP_ULT:
134+
case ICmpInst::ICMP_ULT: {
132135
// X <u 2^n is equivalent to (X & ~(2^n-1)) == 0.
133136
if (C.isPowerOf2()) {
134137
Result.Mask = -C;
@@ -147,6 +150,19 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
147150

148151
return std::nullopt;
149152
}
153+
case ICmpInst::ICMP_EQ:
154+
case ICmpInst::ICMP_NE: {
155+
assert(DecomposeBitMask);
156+
const APInt *AndC;
157+
Value *AndVal;
158+
if (match(LHS, m_And(m_Value(AndVal), m_APIntAllowPoison(AndC)))) {
159+
Result = {AndVal /*X*/, Pred /*Pred*/, *AndC /*Mask*/, *OrigC /*C*/};
160+
break;
161+
}
162+
163+
return std::nullopt;
164+
}
165+
}
150166

151167
if (!AllowNonZeroC && !Result.C.isZero())
152168
return std::nullopt;
@@ -159,23 +175,25 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
159175
Result.X = X;
160176
Result.Mask = Result.Mask.zext(X->getType()->getScalarSizeInBits());
161177
Result.C = Result.C.zext(X->getType()->getScalarSizeInBits());
162-
} else {
178+
} else if (!Result.X) {
163179
Result.X = LHS;
164180
}
165181

166182
return Result;
167183
}
168184

169-
std::optional<DecomposedBitTest>
170-
llvm::decomposeBitTest(Value *Cond, bool LookThruTrunc, bool AllowNonZeroC) {
185+
std::optional<DecomposedBitTest> llvm::decomposeBitTest(Value *Cond,
186+
bool LookThruTrunc,
187+
bool AllowNonZeroC,
188+
bool DecomposeBitMask) {
171189
using namespace PatternMatch;
172190
if (auto *ICmp = dyn_cast<ICmpInst>(Cond)) {
173191
// Don't allow pointers. Splat vectors are fine.
174192
if (!ICmp->getOperand(0)->getType()->isIntOrIntVectorTy())
175193
return std::nullopt;
176194
return decomposeBitTestICmp(ICmp->getOperand(0), ICmp->getOperand(1),
177195
ICmp->getPredicate(), LookThruTrunc,
178-
AllowNonZeroC);
196+
AllowNonZeroC, DecomposeBitMask);
179197
}
180198
Value *X;
181199
if (Cond->getType()->isIntOrIntVectorTy(1) &&

llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -875,22 +875,16 @@ static Value *foldSignedTruncationCheck(ICmpInst *ICmp0, ICmpInst *ICmp1,
875875
APInt &UnsetBitsMask) -> bool {
876876
CmpPredicate Pred = ICmp->getPredicate();
877877
// Can it be decomposed into icmp eq (X & Mask), 0 ?
878-
auto Res =
879-
llvm::decomposeBitTestICmp(ICmp->getOperand(0), ICmp->getOperand(1),
880-
Pred, /*LookThroughTrunc=*/false);
878+
auto Res = llvm::decomposeBitTestICmp(
879+
ICmp->getOperand(0), ICmp->getOperand(1), Pred,
880+
/*LookThroughTrunc=*/true, /*AllowNonZeroC=*/false,
881+
/*DecomposeBitMask=*/true);
881882
if (Res && Res->Pred == ICmpInst::ICMP_EQ) {
882883
X = Res->X;
883884
UnsetBitsMask = Res->Mask;
884885
return true;
885886
}
886887

887-
// Is it icmp eq (X & Mask), 0 already?
888-
const APInt *Mask;
889-
if (match(ICmp, m_ICmp(Pred, m_And(m_Value(X), m_APInt(Mask)), m_Zero())) &&
890-
Pred == ICmpInst::ICMP_EQ) {
891-
UnsetBitsMask = *Mask;
892-
return true;
893-
}
894888
return false;
895889
};
896890

llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3730,31 +3730,29 @@ static Value *foldSelectBitTest(SelectInst &Sel, Value *CondVal, Value *TrueVal,
37303730
Value *CmpLHS, *CmpRHS;
37313731

37323732
if (match(CondVal, m_ICmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS)))) {
3733-
if (ICmpInst::isEquality(Pred)) {
3734-
if (!match(CmpRHS, m_Zero()))
3735-
return nullptr;
3733+
auto Res = decomposeBitTestICmp(
3734+
CmpLHS, CmpRHS, Pred, /*LookThroughTrunc=*/true,
3735+
/*AllowNonZeroC=*/false, /*DecomposeBitMask=*/true);
37363736

3737-
V = CmpLHS;
3738-
const APInt *AndRHS;
3739-
if (!match(CmpLHS, m_And(m_Value(), m_Power2(AndRHS))))
3740-
return nullptr;
3737+
if (!Res)
3738+
return nullptr;
37413739

3742-
AndMask = *AndRHS;
3743-
} else if (auto Res = decomposeBitTestICmp(CmpLHS, CmpRHS, Pred)) {
3744-
assert(ICmpInst::isEquality(Res->Pred) && "Not equality test?");
3745-
AndMask = Res->Mask;
3740+
V = CmpLHS;
3741+
AndMask = Res->Mask;
3742+
3743+
if (!ICmpInst::isEquality(Pred)) {
37463744
V = Res->X;
37473745
KnownBits Known =
37483746
computeKnownBits(V, /*Depth=*/0, SQ.getWithInstruction(&Sel));
37493747
AndMask &= Known.getMaxValue();
3750-
if (!AndMask.isPowerOf2())
3751-
return nullptr;
3752-
3753-
Pred = Res->Pred;
37543748
CreateAnd = true;
3755-
} else {
3756-
return nullptr;
37573749
}
3750+
3751+
Pred = Res->Pred;
3752+
3753+
if (!AndMask.isPowerOf2())
3754+
return nullptr;
3755+
37583756
} else if (auto *Trunc = dyn_cast<TruncInst>(CondVal)) {
37593757
V = Trunc->getOperand(0);
37603758
AndMask = APInt(V->getType()->getScalarSizeInBits(), 1);

llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2761,14 +2761,11 @@ static bool detectShiftUntilBitTestIdiom(Loop *CurLoop, Value *&BaseX,
27612761
m_LoopInvariant(m_Shl(m_One(), m_Value(BitPos)),
27622762
CurLoop))));
27632763
};
2764-
auto MatchConstantBitMask = [&]() {
2765-
return ICmpInst::isEquality(Pred) && match(CmpRHS, m_Zero()) &&
2766-
match(CmpLHS, m_And(m_Value(CurrX),
2767-
m_CombineAnd(m_Value(BitMask), m_Power2()))) &&
2768-
(BitPos = ConstantExpr::getExactLogBase2(cast<Constant>(BitMask)));
2769-
};
2764+
27702765
auto MatchDecomposableConstantBitMask = [&]() {
2771-
auto Res = llvm::decomposeBitTestICmp(CmpLHS, CmpRHS, Pred);
2766+
auto Res = llvm::decomposeBitTestICmp(
2767+
CmpLHS, CmpRHS, Pred, /*LookThroughTrunc=*/true,
2768+
/*AllowNonZeroC=*/false, /*DecomposeBitMask=*/true);
27722769
if (Res && Res->Mask.isPowerOf2()) {
27732770
assert(ICmpInst::isEquality(Res->Pred));
27742771
Pred = Res->Pred;
@@ -2780,8 +2777,7 @@ static bool detectShiftUntilBitTestIdiom(Loop *CurLoop, Value *&BaseX,
27802777
return false;
27812778
};
27822779

2783-
if (!MatchVariableBitMask() && !MatchConstantBitMask() &&
2784-
!MatchDecomposableConstantBitMask()) {
2780+
if (!MatchVariableBitMask() && !MatchDecomposableConstantBitMask()) {
27852781
LLVM_DEBUG(dbgs() << DEBUG_TYPE " Bad backedge comparison.\n");
27862782
return false;
27872783
}

0 commit comments

Comments
 (0)