Skip to content

Commit 93e555f

Browse files
goldsteinntstellar
authored andcommitted
[InstCombine] Fix buggy (mul X, Y) -> (shl X, Log2(Y)) transform PR62175
Bug was because we recognized patterns like `(shl 4, Z)` as a power of 2 we could take Log2 of (`2 + Z`), but doing `(shl X, (2 + Z))` can cause a poison shift. https://alive2.llvm.org/ce/z/yuJm_k The fix is to verify that `Log2(Y)` will be a non-poisonous shift amount. We can do this with: `nsw` flag: - https://alive2.llvm.org/ce/z/yyyJBr - https://alive2.llvm.org/ce/z/YgubD_ `nuw` flag: - https://alive2.llvm.org/ce/z/-4mpyV - https://alive2.llvm.org/ce/z/a6ik6r Prove `Y != 0`: - https://alive2.llvm.org/ce/z/ced4su - https://alive2.llvm.org/ce/z/X-JJHb Reviewed By: nikic Differential Revision: https://reviews.llvm.org/D148609
1 parent ff9dc9c commit 93e555f

File tree

1 file changed

+27
-14
lines changed

1 file changed

+27
-14
lines changed

llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1121,7 +1121,7 @@ static const unsigned MaxDepth = 6;
11211121
// actual instructions, otherwise return a non-null dummy value. Return nullptr
11221122
// on failure.
11231123
static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth,
1124-
bool DoFold) {
1124+
bool AssumeNonZero, bool DoFold) {
11251125
auto IfFold = [DoFold](function_ref<Value *()> Fn) {
11261126
if (!DoFold)
11271127
return reinterpret_cast<Value *>(-1);
@@ -1147,37 +1147,48 @@ static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth,
11471147
// FIXME: Require one use?
11481148
Value *X, *Y;
11491149
if (match(Op, m_ZExt(m_Value(X))))
1150-
if (Value *LogX = takeLog2(Builder, X, Depth, DoFold))
1150+
if (Value *LogX = takeLog2(Builder, X, Depth, AssumeNonZero, DoFold))
11511151
return IfFold([&]() { return Builder.CreateZExt(LogX, Op->getType()); });
11521152

11531153
// log2(X << Y) -> log2(X) + Y
11541154
// FIXME: Require one use unless X is 1?
1155-
if (match(Op, m_Shl(m_Value(X), m_Value(Y))))
1156-
if (Value *LogX = takeLog2(Builder, X, Depth, DoFold))
1157-
return IfFold([&]() { return Builder.CreateAdd(LogX, Y); });
1155+
if (match(Op, m_Shl(m_Value(X), m_Value(Y)))) {
1156+
auto *BO = cast<OverflowingBinaryOperator>(Op);
1157+
// nuw will be set if the `shl` is trivially non-zero.
1158+
if (AssumeNonZero || BO->hasNoUnsignedWrap() || BO->hasNoSignedWrap())
1159+
if (Value *LogX = takeLog2(Builder, X, Depth, AssumeNonZero, DoFold))
1160+
return IfFold([&]() { return Builder.CreateAdd(LogX, Y); });
1161+
}
11581162

11591163
// log2(Cond ? X : Y) -> Cond ? log2(X) : log2(Y)
11601164
// FIXME: missed optimization: if one of the hands of select is/contains
11611165
// undef, just directly pick the other one.
11621166
// FIXME: can both hands contain undef?
11631167
// FIXME: Require one use?
11641168
if (SelectInst *SI = dyn_cast<SelectInst>(Op))
1165-
if (Value *LogX = takeLog2(Builder, SI->getOperand(1), Depth, DoFold))
1166-
if (Value *LogY = takeLog2(Builder, SI->getOperand(2), Depth, DoFold))
1169+
if (Value *LogX = takeLog2(Builder, SI->getOperand(1), Depth,
1170+
AssumeNonZero, DoFold))
1171+
if (Value *LogY = takeLog2(Builder, SI->getOperand(2), Depth,
1172+
AssumeNonZero, DoFold))
11671173
return IfFold([&]() {
11681174
return Builder.CreateSelect(SI->getOperand(0), LogX, LogY);
11691175
});
11701176

11711177
// log2(umin(X, Y)) -> umin(log2(X), log2(Y))
11721178
// log2(umax(X, Y)) -> umax(log2(X), log2(Y))
11731179
auto *MinMax = dyn_cast<MinMaxIntrinsic>(Op);
1174-
if (MinMax && MinMax->hasOneUse() && !MinMax->isSigned())
1175-
if (Value *LogX = takeLog2(Builder, MinMax->getLHS(), Depth, DoFold))
1176-
if (Value *LogY = takeLog2(Builder, MinMax->getRHS(), Depth, DoFold))
1180+
if (MinMax && MinMax->hasOneUse() && !MinMax->isSigned()) {
1181+
// Use AssumeNonZero as false here. Otherwise we can hit case where
1182+
// log2(umax(X, Y)) != umax(log2(X), log2(Y)) (because overflow).
1183+
if (Value *LogX = takeLog2(Builder, MinMax->getLHS(), Depth,
1184+
/*AssumeNonZero*/ false, DoFold))
1185+
if (Value *LogY = takeLog2(Builder, MinMax->getRHS(), Depth,
1186+
/*AssumeNonZero*/ false, DoFold))
11771187
return IfFold([&]() {
1178-
return Builder.CreateBinaryIntrinsic(
1179-
MinMax->getIntrinsicID(), LogX, LogY);
1188+
return Builder.CreateBinaryIntrinsic(MinMax->getIntrinsicID(), LogX,
1189+
LogY);
11801190
});
1191+
}
11811192

11821193
return nullptr;
11831194
}
@@ -1297,8 +1308,10 @@ Instruction *InstCombinerImpl::visitUDiv(BinaryOperator &I) {
12971308
}
12981309

12991310
// Op1 udiv Op2 -> Op1 lshr log2(Op2), if log2() folds away.
1300-
if (takeLog2(Builder, Op1, /*Depth*/0, /*DoFold*/false)) {
1301-
Value *Res = takeLog2(Builder, Op1, /*Depth*/0, /*DoFold*/true);
1311+
if (takeLog2(Builder, Op1, /*Depth*/ 0, /*AssumeNonZero*/ true,
1312+
/*DoFold*/ false)) {
1313+
Value *Res = takeLog2(Builder, Op1, /*Depth*/ 0,
1314+
/*AssumeNonZero*/ true, /*DoFold*/ true);
13021315
return replaceInstUsesWith(
13031316
I, Builder.CreateLShr(Op0, Res, I.getName(), I.isExact()));
13041317
}

0 commit comments

Comments
 (0)