Skip to content

Commit ee59139

Browse files
committed
[InstCombine] Modify foldSelectICmpEq to only handle more useful and simple cases.
The original intent of the folds (from #71792) was to handle selects where the condition was a bitwise operator compared with zero. During review of the original PR (#73362), however, we ended up over-generalizing at the expense of code complexity, and ironically in a way that didn't actually fix the issue as reported. The goal of this PR is to simplify the code and only handle the compare-with-zero cases that actually show up in the real world. New code handles three cases: 1) `X & Y == 0` implies `X | Y == X ^ Y == X + Y` thus: - `X & Y == 0 ? X |/^/+ Y : X |/^/+ Y` -> `X |/^/+ Y` (the false arm) - https://alive2.llvm.org/ce/z/jjcduh 2) `X | Y == 0` implies `X == Y == 0` thus for `Op0` and `Op1` s.t. `0 Op0 0 == 0 Op1 0 == 0`: - `X | Y == 0 ? X Op0 Y : X Op1 Y` -> `X Op1 Y` - `X | Y == 0 ? 0 : X Op1 Y` -> `X Op1 Y` - https://alive2.llvm.org/ce/z/RBuFQE 3) `X ^ Y == 0` (`X == Y`) implies `X | Y == X & Y`: - `X ^ Y == 0 ? X | Y : X & Y` -> `X & Y` - `X ^ Y == 0 ? X & Y : X | Y` -> `X | Y` - https://alive2.llvm.org/ce/z/SJskbz
1 parent c56b743 commit ee59139

File tree

2 files changed

+183
-159
lines changed

2 files changed

+183
-159
lines changed

llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp

Lines changed: 67 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -1821,8 +1821,12 @@ static Instruction *foldSelectWithExtremeEqCond(Value *CmpLHS, Value *CmpRHS,
18211821
return new ICmpInst(Pred, CmpLHS, B);
18221822
}
18231823

1824+
// Fold (X Op0 Y) == 0 ? (X Op1 Y) : (X Op2 Y)
1825+
// -> (X Op2 Y)
1826+
// By proving that `(X Op1 Y) == (X Op2 Y)` in the context of `(X Op0 Y) == 0`.
18241827
static Instruction *foldSelectICmpEq(SelectInst &SI, ICmpInst *ICI,
18251828
InstCombinerImpl &IC) {
1829+
18261830
ICmpInst::Predicate Pred = ICI->getPredicate();
18271831
if (!ICmpInst::isEquality(Pred))
18281832
return nullptr;
@@ -1835,96 +1839,79 @@ static Instruction *foldSelectICmpEq(SelectInst &SI, ICmpInst *ICI,
18351839
if (Pred == ICmpInst::ICMP_NE)
18361840
std::swap(TrueVal, FalseVal);
18371841

1838-
if (Instruction *Res =
1839-
foldSelectWithExtremeEqCond(CmpLHS, CmpRHS, TrueVal, FalseVal))
1840-
return Res;
1842+
if (auto *R = foldSelectWithExtremeEqCond(CmpLHS, CmpRHS, TrueVal, FalseVal))
1843+
return R;
18411844

1842-
// Transform (X == C) ? X : Y -> (X == C) ? C : Y
1843-
// specific handling for Bitwise operation.
1844-
// x&y -> (x|y) ^ (x^y) or (x|y) & ~(x^y)
1845-
// x|y -> (x&y) | (x^y) or (x&y) ^ (x^y)
1846-
// x^y -> (x|y) ^ (x&y) or (x|y) & ~(x&y)
18471845
Value *X, *Y;
1848-
if (!match(CmpLHS, m_BitwiseLogic(m_Value(X), m_Value(Y))) ||
1849-
!match(TrueVal, m_c_BitwiseLogic(m_Specific(X), m_Specific(Y))))
1850-
return nullptr;
1851-
1852-
const unsigned AndOps = Instruction::And, OrOps = Instruction::Or,
1853-
XorOps = Instruction::Xor, NoOps = 0;
1854-
enum NotMask { None = 0, NotInner, NotRHS };
1855-
1856-
auto matchFalseVal = [&](unsigned OuterOpc, unsigned InnerOpc,
1857-
unsigned NotMask) {
1858-
auto matchInner = m_c_BinOp(InnerOpc, m_Specific(X), m_Specific(Y));
1859-
if (OuterOpc == NoOps)
1860-
return match(CmpRHS, m_Zero()) && match(FalseVal, matchInner);
1861-
1862-
if (NotMask == NotInner) {
1863-
return match(FalseVal, m_c_BinOp(OuterOpc, m_NotForbidPoison(matchInner),
1864-
m_Specific(CmpRHS)));
1865-
} else if (NotMask == NotRHS) {
1866-
return match(FalseVal, m_c_BinOp(OuterOpc, matchInner,
1867-
m_NotForbidPoison(m_Specific(CmpRHS))));
1868-
} else {
1869-
return match(FalseVal,
1870-
m_c_BinOp(OuterOpc, matchInner, m_Specific(CmpRHS)));
1871-
}
1872-
};
1873-
1874-
// (X&Y)==C ? X|Y : X^Y -> (X^Y)|C : X^Y or (X^Y)^ C : X^Y
1875-
// (X&Y)==C ? X^Y : X|Y -> (X|Y)^C : X|Y or (X|Y)&~C : X|Y
1876-
if (match(CmpLHS, m_And(m_Value(X), m_Value(Y)))) {
1877-
if (match(TrueVal, m_c_Or(m_Specific(X), m_Specific(Y)))) {
1878-
// (X&Y)==C ? X|Y : (X^Y)|C -> (X^Y)|C : (X^Y)|C -> (X^Y)|C
1879-
// (X&Y)==C ? X|Y : (X^Y)^C -> (X^Y)^C : (X^Y)^C -> (X^Y)^C
1880-
if (matchFalseVal(OrOps, XorOps, None) ||
1881-
matchFalseVal(XorOps, XorOps, None))
1882-
return IC.replaceInstUsesWith(SI, FalseVal);
1883-
} else if (match(TrueVal, m_c_Xor(m_Specific(X), m_Specific(Y)))) {
1884-
// (X&Y)==C ? X^Y : (X|Y)^ C -> (X|Y)^ C : (X|Y)^ C -> (X|Y)^ C
1885-
// (X&Y)==C ? X^Y : (X|Y)&~C -> (X|Y)&~C : (X|Y)&~C -> (X|Y)&~C
1886-
if (matchFalseVal(XorOps, OrOps, None) ||
1887-
matchFalseVal(AndOps, OrOps, NotRHS))
1846+
if (match(CmpRHS, m_Zero())) {
1847+
// (X & Y) == 0 ? X |/^/+ Y : X |/^/+ Y -> X |/^/+ Y (false arm)
1848+
// `(X & Y) == 0` implies no common bits which means:
1849+
// `X ^ Y == X | Y == X + Y`
1850+
// https://alive2.llvm.org/ce/z/jjcduh
1851+
if (match(CmpLHS, m_And(m_Value(X), m_Value(Y)))) {
1852+
auto MatchAddOrXor =
1853+
m_CombineOr(m_c_Add(m_Specific(X), m_Specific(Y)),
1854+
m_CombineOr(m_c_Or(m_Specific(X), m_Specific(Y)),
1855+
m_c_Xor(m_Specific(X), m_Specific(Y))));
1856+
if (match(TrueVal, MatchAddOrXor) && match(FalseVal, MatchAddOrXor))
18881857
return IC.replaceInstUsesWith(SI, FalseVal);
18891858
}
1890-
}
18911859

1892-
// (X|Y)==C ? X&Y : X^Y -> (X^Y)^C : X^Y or ~(X^Y)&C : X^Y
1893-
// (X|Y)==C ? X^Y : X&Y -> (X&Y)^C : X&Y or ~(X&Y)&C : X&Y
1894-
if (match(CmpLHS, m_Or(m_Value(X), m_Value(Y)))) {
1895-
if (match(TrueVal, m_c_And(m_Specific(X), m_Specific(Y)))) {
1896-
// (X|Y)==C ? X&Y: (X^Y)^C -> (X^Y)^C: (X^Y)^C -> (X^Y)^C
1897-
// (X|Y)==C ? X&Y:~(X^Y)&C ->~(X^Y)&C:~(X^Y)&C -> ~(X^Y)&C
1898-
if (matchFalseVal(XorOps, XorOps, None) ||
1899-
matchFalseVal(AndOps, XorOps, NotInner))
1900-
return IC.replaceInstUsesWith(SI, FalseVal);
1901-
} else if (match(TrueVal, m_c_Xor(m_Specific(X), m_Specific(Y)))) {
1902-
// (X|Y)==C ? X^Y : (X&Y)^C -> (X&Y)^C : (X&Y)^C -> (X&Y)^C
1903-
// (X|Y)==C ? X^Y :~(X&Y)&C -> ~(X&Y)&C :~(X&Y)&C -> ~(X&Y)&C
1904-
if (matchFalseVal(XorOps, AndOps, None) ||
1905-
matchFalseVal(AndOps, AndOps, NotInner))
1906-
return IC.replaceInstUsesWith(SI, FalseVal);
1907-
}
1908-
}
1860+
// (X | Y) == 0 ? X Op0 Y : X Op1 Y -> X Op1 Y
1861+
// For any `Op0` and `Op1` that are zero when `X` and `Y` are zero.
1862+
// https://alive2.llvm.org/ce/z/azHzBW
1863+
if (match(CmpLHS, m_Or(m_Value(X), m_Value(Y))) &&
1864+
(match(TrueVal, m_c_BinOp(m_Specific(X), m_Specific(Y))) ||
1865+
// In true arm we can also accept just `0`.
1866+
match(TrueVal, m_Zero())) &&
1867+
match(FalseVal, m_c_BinOp(m_Specific(X), m_Specific(Y)))) {
1868+
auto IsOpcZeroWithZeros = [](Value *V) {
1869+
auto *I = dyn_cast<Instruction>(V);
1870+
if (!I)
1871+
return false;
1872+
switch (I->getOpcode()) {
1873+
case Instruction::And:
1874+
case Instruction::Or:
1875+
case Instruction::Xor:
1876+
case Instruction::Mul:
1877+
case Instruction::Add:
1878+
case Instruction::Sub:
1879+
case Instruction::Shl:
1880+
case Instruction::AShr:
1881+
case Instruction::LShr:
1882+
return true;
1883+
default:
1884+
return false;
1885+
}
1886+
};
19091887

1910-
// (X^Y)==C ? X&Y : X|Y -> (X|Y)^C : X|Y or (X|Y)&~C : X|Y
1911-
// (X^Y)==C ? X|Y : X&Y -> (X&Y)|C : X&Y or (X&Y)^ C : X&Y
1912-
if (match(CmpLHS, m_Xor(m_Value(X), m_Value(Y)))) {
1913-
if ((match(TrueVal, m_c_And(m_Specific(X), m_Specific(Y))))) {
1914-
// (X^Y)==C ? X&Y : (X|Y)^C -> (X|Y)^C
1915-
// (X^Y)==C ? X&Y : (X|Y)&~C -> (X|Y)&~C
1916-
if (matchFalseVal(XorOps, OrOps, None) ||
1917-
matchFalseVal(AndOps, OrOps, NotRHS))
1918-
return IC.replaceInstUsesWith(SI, FalseVal);
1919-
} else if (match(TrueVal, m_c_Or(m_Specific(X), m_Specific(Y)))) {
1920-
// (X^Y)==C ? (X|Y) : (X&Y)|C -> (X&Y)|C
1921-
// (X^Y)==C ? (X|Y) : (X&Y)^C -> (X&Y)^C
1922-
if (matchFalseVal(OrOps, AndOps, None) ||
1923-
matchFalseVal(XorOps, AndOps, None))
1888+
if ((match(TrueVal, m_Zero()) || IsOpcZeroWithZeros(TrueVal)) &&
1889+
IsOpcZeroWithZeros(FalseVal))
19241890
return IC.replaceInstUsesWith(SI, FalseVal);
19251891
}
19261892
}
1893+
// (X == Y) ? X | Y : X & Y
1894+
// (X == Y) ? X & Y : X | Y
1895+
// If `X == Y` then `X == Y == X | Y == X & Y`.
1896+
// NB: `X == Y` is canonicalization of `(X ^ Y) == 0`.
1897+
// https://alive2.llvm.org/ce/z/SJskbz
1898+
X = CmpLHS;
1899+
Y = CmpRHS;
1900+
auto MatchOrAnd = m_CombineOr(m_c_Or(m_Specific(X), m_Specific(Y)),
1901+
m_c_And(m_Specific(X), m_Specific(Y)));
1902+
if (match(FalseVal, MatchOrAnd) &&
1903+
// In the true arm we can also just match `X` or `Y`.
1904+
(match(TrueVal, MatchOrAnd) || match(TrueVal, m_Specific(X)) ||
1905+
match(TrueVal, m_Specific(Y)))) {
1906+
// Can't preserve `or disjoint` here so rebuild.
1907+
auto *BO = dyn_cast<BinaryOperator>(FalseVal);
1908+
if (!BO)
1909+
return nullptr;
19271910

1911+
return IC.replaceInstUsesWith(
1912+
SI, IC.Builder.CreateBinOp(BO->getOpcode(), BO->getOperand(0),
1913+
BO->getOperand(1)));
1914+
}
19281915
return nullptr;
19291916
}
19301917

0 commit comments

Comments
 (0)