@@ -75,11 +75,13 @@ Constant *llvm::getPredForFCmpCode(unsigned Code, Type *OpTy,
7575
7676std::optional<DecomposedBitTest>
7777llvm::decomposeBitTestICmp (Value *LHS, Value *RHS, CmpInst::Predicate Pred,
78- bool LookThruTrunc, bool AllowNonZeroC) {
78+ bool LookThruTrunc, bool AllowNonZeroC,
79+ bool DecomposeBitMask) {
7980 using namespace PatternMatch ;
8081
8182 const APInt *OrigC;
82- if (!ICmpInst::isRelational (Pred) || !match (RHS, m_APIntAllowPoison (OrigC)))
83+ if ((ICmpInst::isEquality (Pred) && !DecomposeBitMask) ||
84+ !match (RHS, m_APIntAllowPoison (OrigC)))
8385 return std::nullopt ;
8486
8587 bool Inverted = false ;
@@ -97,9 +99,10 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
9799 }
98100
99101 DecomposedBitTest Result;
102+
100103 switch (Pred) {
101104 default :
102- llvm_unreachable ( " Unexpected predicate " ) ;
105+ return std:: nullopt ;
103106 case ICmpInst::ICMP_SLT: {
104107 // X < 0 is equivalent to (X & SignMask) != 0.
105108 if (C.isZero ()) {
@@ -128,7 +131,7 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
128131
129132 return std::nullopt ;
130133 }
131- case ICmpInst::ICMP_ULT:
134+ case ICmpInst::ICMP_ULT: {
132135 // X <u 2^n is equivalent to (X & ~(2^n-1)) == 0.
133136 if (C.isPowerOf2 ()) {
134137 Result.Mask = -C;
@@ -147,6 +150,19 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
147150
148151 return std::nullopt ;
149152 }
153+ case ICmpInst::ICMP_EQ:
154+ case ICmpInst::ICMP_NE: {
155+ assert (DecomposeBitMask);
156+ const APInt *AndC;
157+ Value *AndVal;
158+ if (match (LHS, m_And (m_Value (AndVal), m_APIntAllowPoison (AndC)))) {
159+ Result = {AndVal /* X*/ , Pred /* Pred*/ , *AndC /* Mask*/ , *OrigC /* C*/ };
160+ break ;
161+ }
162+
163+ return std::nullopt ;
164+ }
165+ }
150166
151167 if (!AllowNonZeroC && !Result.C .isZero ())
152168 return std::nullopt ;
@@ -159,23 +175,25 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
159175 Result.X = X;
160176 Result.Mask = Result.Mask .zext (X->getType ()->getScalarSizeInBits ());
161177 Result.C = Result.C .zext (X->getType ()->getScalarSizeInBits ());
162- } else {
178+ } else if (!Result. X ) {
163179 Result.X = LHS;
164180 }
165181
166182 return Result;
167183}
168184
169- std::optional<DecomposedBitTest>
170- llvm::decomposeBitTest (Value *Cond, bool LookThruTrunc, bool AllowNonZeroC) {
185+ std::optional<DecomposedBitTest> llvm::decomposeBitTest (Value *Cond,
186+ bool LookThruTrunc,
187+ bool AllowNonZeroC,
188+ bool DecomposeBitMask) {
171189 using namespace PatternMatch ;
172190 if (auto *ICmp = dyn_cast<ICmpInst>(Cond)) {
173191 // Don't allow pointers. Splat vectors are fine.
174192 if (!ICmp->getOperand (0 )->getType ()->isIntOrIntVectorTy ())
175193 return std::nullopt ;
176194 return decomposeBitTestICmp (ICmp->getOperand (0 ), ICmp->getOperand (1 ),
177195 ICmp->getPredicate (), LookThruTrunc,
178- AllowNonZeroC);
196+ AllowNonZeroC, DecomposeBitMask );
179197 }
180198 Value *X;
181199 if (Cond->getType ()->isIntOrIntVectorTy (1 ) &&
0 commit comments