Skip to content

Commit 8bf52d5

Browse files
committed
[InstCombine] Move foldLogOpOfMaskedICmps to make it possible to handle trunc to i1.
1 parent 570f030 commit 8bf52d5

File tree

1 file changed

+126
-106
lines changed

1 file changed

+126
-106
lines changed

llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp

Lines changed: 126 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -199,113 +199,132 @@ static bool decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate &Pre
199199
/// the right hand side as a pair.
200200
/// LHS and RHS are the left hand side and the right hand side ICmps and PredL
201201
/// and PredR are their predicates, respectively.
202-
static std::optional<std::pair<unsigned, unsigned>> getMaskedTypeForICmpPair(
203-
Value *&A, Value *&B, Value *&C, Value *&D, Value *&E, ICmpInst *LHS,
204-
ICmpInst *RHS, ICmpInst::Predicate &PredL, ICmpInst::Predicate &PredR) {
205-
// Don't allow pointers. Splat vectors are fine.
206-
if (!LHS->getOperand(0)->getType()->isIntOrIntVectorTy() ||
207-
!RHS->getOperand(0)->getType()->isIntOrIntVectorTy())
208-
return std::nullopt;
202+
static std::optional<std::pair<unsigned, unsigned>>
203+
getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C, Value *&D, Value *&E,
204+
Value *LHS, Value *RHS, ICmpInst::Predicate &PredL,
205+
ICmpInst::Predicate &PredR) {
209206

210-
// Here comes the tricky part:
211-
// LHS might be of the form L11 & L12 == X, X == L21 & L22,
212-
// and L11 & L12 == L21 & L22. The same goes for RHS.
213-
// Now we must find those components L** and R**, that are equal, so
214-
// that we can extract the parameters A, B, C, D, and E for the canonical
215-
// above.
216-
Value *L1 = LHS->getOperand(0);
217-
Value *L2 = LHS->getOperand(1);
218-
Value *L11, *L12, *L21, *L22;
219-
// Check whether the icmp can be decomposed into a bit test.
220-
if (decomposeBitTestICmp(L1, L2, PredL, L11, L12, L2)) {
221-
L21 = L22 = L1 = nullptr;
222-
} else {
223-
// Look for ANDs in the LHS icmp.
224-
if (!match(L1, m_And(m_Value(L11), m_Value(L12)))) {
225-
// Any icmp can be viewed as being trivially masked; if it allows us to
226-
// remove one, it's worth it.
227-
L11 = L1;
228-
L12 = Constant::getAllOnesValue(L1->getType());
229-
}
207+
Value *L1, *L11, *L12, *L2, *L21, *L22;
208+
if (auto *LHSCMP = dyn_cast<ICmpInst>(LHS)) {
209+
210+
// Don't allow pointers. Splat vectors are fine.
211+
if (!LHSCMP->getOperand(0)->getType()->isIntOrIntVectorTy())
212+
return std::nullopt;
213+
214+
PredL = LHSCMP->getPredicate();
215+
216+
// Here comes the tricky part:
217+
// LHS might be of the form L11 & L12 == X, X == L21 & L22,
218+
// and L11 & L12 == L21 & L22. The same goes for RHS.
219+
// Now we must find those components L** and R**, that are equal, so
220+
// that we can extract the parameters A, B, C, D, and E for the canonical
221+
// above.
222+
L1 = LHSCMP->getOperand(0);
223+
L2 = LHSCMP->getOperand(1);
224+
// Check whether the icmp can be decomposed into a bit test.
225+
if (decomposeBitTestICmp(L1, L2, PredL, L11, L12, L2)) {
226+
L21 = L22 = L1 = nullptr;
227+
} else {
228+
// Look for ANDs in the LHS icmp.
229+
if (!match(L1, m_And(m_Value(L11), m_Value(L12)))) {
230+
// Any icmp can be viewed as being trivially masked; if it allows us to
231+
// remove one, it's worth it.
232+
L11 = L1;
233+
L12 = Constant::getAllOnesValue(L1->getType());
234+
}
230235

231-
if (!match(L2, m_And(m_Value(L21), m_Value(L22)))) {
232-
L21 = L2;
233-
L22 = Constant::getAllOnesValue(L2->getType());
236+
if (!match(L2, m_And(m_Value(L21), m_Value(L22)))) {
237+
L21 = L2;
238+
L22 = Constant::getAllOnesValue(L2->getType());
239+
}
234240
}
235-
}
241+
// Bail if LHS was an icmp that can't be decomposed into an equality.
242+
if (!ICmpInst::isEquality(PredL))
243+
return std::nullopt;
236244

237-
// Bail if LHS was a icmp that can't be decomposed into an equality.
238-
if (!ICmpInst::isEquality(PredL))
245+
} else {
239246
return std::nullopt;
247+
}
240248

241-
Value *R1 = RHS->getOperand(0);
242-
Value *R2 = RHS->getOperand(1);
243249
Value *R11, *R12;
244-
bool Ok = false;
245-
if (decomposeBitTestICmp(R1, R2, PredR, R11, R12, R2)) {
246-
if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
247-
A = R11;
248-
D = R12;
249-
} else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
250-
A = R12;
251-
D = R11;
252-
} else {
250+
if (auto *RHSCMP = dyn_cast<ICmpInst>(RHS)) {
251+
252+
// Don't allow pointers. Splat vectors are fine.
253+
if (!RHSCMP->getOperand(0)->getType()->isIntOrIntVectorTy())
253254
return std::nullopt;
254-
}
255-
E = R2;
256-
R1 = nullptr;
257-
Ok = true;
258-
} else {
259-
if (!match(R1, m_And(m_Value(R11), m_Value(R12)))) {
260-
// As before, model no mask as a trivial mask if it'll let us do an
261-
// optimization.
262-
R11 = R1;
263-
R12 = Constant::getAllOnesValue(R1->getType());
264-
}
265255

266-
if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
267-
A = R11;
268-
D = R12;
269-
E = R2;
270-
Ok = true;
271-
} else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
272-
A = R12;
273-
D = R11;
256+
PredR = RHSCMP->getPredicate();
257+
258+
Value *R1 = RHSCMP->getOperand(0);
259+
Value *R2 = RHSCMP->getOperand(1);
260+
bool Ok = false;
261+
if (decomposeBitTestICmp(R1, R2, PredR, R11, R12, R2)) {
262+
if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
263+
A = R11;
264+
D = R12;
265+
} else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
266+
A = R12;
267+
D = R11;
268+
} else {
269+
return std::nullopt;
270+
}
274271
E = R2;
272+
R1 = nullptr;
275273
Ok = true;
276-
}
277-
278-
// Avoid matching against the -1 value we created for unmasked operand.
279-
if (Ok && match(A, m_AllOnes()))
280-
Ok = false;
281-
}
274+
} else {
275+
if (!match(R1, m_And(m_Value(R11), m_Value(R12)))) {
276+
// As before, model no mask as a trivial mask if it'll let us do an
277+
// optimization.
278+
R11 = R1;
279+
R12 = Constant::getAllOnesValue(R1->getType());
280+
}
282281

283-
// Bail if RHS was a icmp that can't be decomposed into an equality.
284-
if (!ICmpInst::isEquality(PredR))
285-
return std::nullopt;
282+
if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
283+
A = R11;
284+
D = R12;
285+
E = R2;
286+
Ok = true;
287+
} else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
288+
A = R12;
289+
D = R11;
290+
E = R2;
291+
Ok = true;
292+
}
286293

287-
// Look for ANDs on the right side of the RHS icmp.
288-
if (!Ok) {
289-
if (!match(R2, m_And(m_Value(R11), m_Value(R12)))) {
290-
R11 = R2;
291-
R12 = Constant::getAllOnesValue(R2->getType());
294+
// Avoid matching against the -1 value we created for unmasked operand.
295+
if (Ok && match(A, m_AllOnes()))
296+
Ok = false;
292297
}
293298

294-
if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
295-
A = R11;
296-
D = R12;
297-
E = R1;
298-
Ok = true;
299-
} else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
300-
A = R12;
301-
D = R11;
302-
E = R1;
303-
Ok = true;
304-
} else {
299+
// Bail if RHS was an icmp that can't be decomposed into an equality.
300+
if (!ICmpInst::isEquality(PredR))
305301
return std::nullopt;
306-
}
307302

308-
assert(Ok && "Failed to find AND on the right side of the RHS icmp.");
303+
// Look for ANDs on the right side of the RHS icmp.
304+
if (!Ok) {
305+
if (!match(R2, m_And(m_Value(R11), m_Value(R12)))) {
306+
R11 = R2;
307+
R12 = Constant::getAllOnesValue(R2->getType());
308+
}
309+
310+
if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
311+
A = R11;
312+
D = R12;
313+
E = R1;
314+
Ok = true;
315+
} else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
316+
A = R12;
317+
D = R11;
318+
E = R1;
319+
Ok = true;
320+
} else {
321+
return std::nullopt;
322+
}
323+
324+
assert(Ok && "Failed to find AND on the right side of the RHS icmp.");
325+
}
326+
} else {
327+
return std::nullopt;
309328
}
310329

311330
if (L11 == A) {
@@ -334,8 +353,8 @@ static std::optional<std::pair<unsigned, unsigned>> getMaskedTypeForICmpPair(
334353
/// (icmp (A & 12) != 0) & (icmp (A & 15) == 8) -> (icmp (A & 15) == 8).
335354
/// Also used for logical and/or, must be poison safe.
336355
static Value *foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed(
337-
ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, Value *A, Value *B, Value *D,
338-
Value *E, ICmpInst::Predicate PredL, ICmpInst::Predicate PredR,
356+
Value *LHS, Value *RHS, bool IsAnd, Value *A, Value *B, Value *D, Value *E,
357+
ICmpInst::Predicate PredL, ICmpInst::Predicate PredR,
339358
InstCombiner::BuilderTy &Builder) {
340359
// We are given the canonical form:
341360
// (icmp ne (A & B), 0) & (icmp eq (A & D), E).
@@ -458,7 +477,8 @@ static Value *foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed(
458477
// (icmp ne (A & 15), 0) & (icmp eq (A & 15), 8) -> (icmp eq (A & 15), 8).
459478
if (IsSuperSetOrEqual(BCst, DCst)) {
460479
// We can't guarantee that samesign holds after this fold.
461-
RHS->setSameSign(false);
480+
if (auto *ICmp = dyn_cast<ICmpInst>(RHS))
481+
ICmp->setSameSign(false);
462482
return RHS;
463483
}
464484
// Otherwise, B is a subset of D. If B and E have a common bit set,
@@ -467,7 +487,8 @@ static Value *foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed(
467487
assert(IsSubSetOrEqual(BCst, DCst) && "Precondition due to above code");
468488
if ((*BCst & ECst) != 0) {
469489
// We can't guarantee that samesign holds after this fold.
470-
RHS->setSameSign(false);
490+
if (auto *ICmp = dyn_cast<ICmpInst>(RHS))
491+
ICmp->setSameSign(false);
471492
return RHS;
472493
}
473494
// Otherwise, LHS and RHS contradict and the whole expression becomes false
@@ -482,8 +503,8 @@ static Value *foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed(
482503
/// aren't of the common mask pattern type.
483504
/// Also used for logical and/or, must be poison safe.
484505
static Value *foldLogOpOfMaskedICmpsAsymmetric(
485-
ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, Value *A, Value *B, Value *C,
486-
Value *D, Value *E, ICmpInst::Predicate PredL, ICmpInst::Predicate PredR,
506+
Value *LHS, Value *RHS, bool IsAnd, Value *A, Value *B, Value *C, Value *D,
507+
Value *E, ICmpInst::Predicate PredL, ICmpInst::Predicate PredR,
487508
unsigned LHSMask, unsigned RHSMask, InstCombiner::BuilderTy &Builder) {
488509
assert(ICmpInst::isEquality(PredL) && ICmpInst::isEquality(PredR) &&
489510
"Expected equality predicates for masked type of icmps.");
@@ -512,12 +533,12 @@ static Value *foldLogOpOfMaskedICmpsAsymmetric(
512533

513534
/// Try to fold (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E)
514535
/// into a single (icmp(A & X) ==/!= Y).
515-
static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,
536+
static Value *foldLogOpOfMaskedICmps(Value *LHS, Value *RHS, bool IsAnd,
516537
bool IsLogical,
517538
InstCombiner::BuilderTy &Builder,
518539
const SimplifyQuery &Q) {
519540
Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr, *E = nullptr;
520-
ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate();
541+
ICmpInst::Predicate PredL, PredR;
521542
std::optional<std::pair<unsigned, unsigned>> MaskPair =
522543
getMaskedTypeForICmpPair(A, B, C, D, E, LHS, RHS, PredL, PredR);
523544
if (!MaskPair)
@@ -1067,8 +1088,7 @@ static Value *foldPowerOf2AndShiftedMask(ICmpInst *Cmp0, ICmpInst *Cmp1,
10671088
if (!JoinedByAnd)
10681089
return nullptr;
10691090
Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr, *E = nullptr;
1070-
ICmpInst::Predicate CmpPred0 = Cmp0->getPredicate(),
1071-
CmpPred1 = Cmp1->getPredicate();
1091+
ICmpInst::Predicate CmpPred0, CmpPred1;
10721092
// Assuming P is a 2^n, getMaskedTypeForICmpPair will normalize (icmp X u<
10731093
// 2^n) into (icmp (X & ~(2^n-1)) == 0) and (icmp X s> -1) into (icmp (X &
10741094
// SignMask) == 0).
@@ -3325,12 +3345,6 @@ Value *InstCombinerImpl::foldAndOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
33253345
}
33263346
}
33273347

3328-
// handle (roughly):
3329-
// (icmp ne (A & B), C) | (icmp ne (A & D), E)
3330-
// (icmp eq (A & B), C) & (icmp eq (A & D), E)
3331-
if (Value *V = foldLogOpOfMaskedICmps(LHS, RHS, IsAnd, IsLogical, Builder, Q))
3332-
return V;
3333-
33343348
if (Value *V =
33353349
foldAndOrOfICmpEqConstantAndICmp(LHS, RHS, IsAnd, IsLogical, Builder))
33363350
return V;
@@ -3510,6 +3524,12 @@ Value *InstCombinerImpl::foldBooleanAndOr(Value *LHS, Value *RHS,
35103524
if (Value *Res = foldAndOrOfICmps(LHSCmp, RHSCmp, I, IsAnd, IsLogical))
35113525
return Res;
35123526

3527+
/// Try to fold (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E)
3528+
/// into a single (icmp(A & X) ==/!= Y).
3529+
if (Value *V = foldLogOpOfMaskedICmps(LHS, RHS, IsAnd, IsLogical, Builder,
3530+
SQ.getWithInstruction(&I)))
3531+
return V;
3532+
35133533
if (auto *LHSCmp = dyn_cast<FCmpInst>(LHS))
35143534
if (auto *RHSCmp = dyn_cast<FCmpInst>(RHS))
35153535
if (Value *Res = foldLogicOfFCmps(LHSCmp, RHSCmp, IsAnd, IsLogical))

0 commit comments

Comments
 (0)