
Commit d765fb9

[InstCombine] Detect uadd with overflow idiom
Change processUMulZExtIdiom to also support adds, since the idiom is the same except that it uses an add instead of a mul. Alive2: https://alive2.llvm.org/ce/z/SsB4AK
1 parent d3d4956 commit d765fb9
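
For context, a minimal C++ sketch (not part of this commit; names are illustrative) of the kind of source that produces the idiom being recognized: both operands are zero-extended, added in a wider type, and the sum is compared against the maximum of the narrow type.

#include <cstdint>

// Illustration only: with optimizations this typically lowers to
//   add(zext a, zext b) followed by icmp ugt <sum>, 4294967295
// which is the pattern the renamed processUZExtIdiom now also matches for adds.
bool add_would_overflow(uint32_t a, uint32_t b) {
  return (uint64_t)a + (uint64_t)b > UINT32_MAX;
}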

File tree

4 files changed: +102 -96 lines changed


llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp

Lines changed: 89 additions & 70 deletions
@@ -6533,72 +6533,77 @@ bool InstCombinerImpl::OptimizeOverflowCheck(Instruction::BinaryOps BinaryOp,
     llvm_unreachable("Unexpected overflow result");
 }
 
-/// Recognize and process idiom involving test for multiplication
+/// Recognize and process idiom involving test for unsigned
 /// overflow.
 ///
 /// The caller has matched a pattern of the form:
+///   I = cmp u (add(zext A, zext B), V
 ///   I = cmp u (mul(zext A, zext B), V
 /// The function checks if this is a test for overflow and if so replaces
-/// multiplication with call to 'mul.with.overflow' intrinsic.
+/// addition/multiplication with call to the umul intrinsic or the canonical
+/// form of uadd overflow.
 ///
 /// \param I Compare instruction.
-/// \param MulVal Result of 'mult' instruction. It is one of the arguments of
-///               the compare instruction. Must be of integer type.
+/// \param Val Result of add/mul instruction. It is one of the arguments of
+///            the compare instruction. Must be of integer type.
 /// \param OtherVal The other argument of compare instruction.
 /// \returns Instruction which must replace the compare instruction, NULL if no
 ///          replacement required.
-static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
-                                         const APInt *OtherVal,
-                                         InstCombinerImpl &IC) {
+static Instruction *processUZExtIdiom(ICmpInst &I, Value *Val,
+                                      const APInt *OtherVal,
+                                      InstCombinerImpl &IC) {
   // Don't bother doing this transformation for pointers, don't do it for
   // vectors.
-  if (!isa<IntegerType>(MulVal->getType()))
+  if (!isa<IntegerType>(Val->getType()))
     return nullptr;
 
-  auto *MulInstr = dyn_cast<Instruction>(MulVal);
-  if (!MulInstr)
+  auto *Instr = dyn_cast<Instruction>(Val);
+  if (!Instr)
     return nullptr;
-  assert(MulInstr->getOpcode() == Instruction::Mul);
 
-  auto *LHS = cast<ZExtInst>(MulInstr->getOperand(0)),
-       *RHS = cast<ZExtInst>(MulInstr->getOperand(1));
+  unsigned Opcode = Instr->getOpcode();
+  assert(Opcode == Instruction::Add || Opcode == Instruction::Mul);
+
+  auto *LHS = cast<ZExtInst>(Instr->getOperand(0)),
+       *RHS = cast<ZExtInst>(Instr->getOperand(1));
   assert(LHS->getOpcode() == Instruction::ZExt);
   assert(RHS->getOpcode() == Instruction::ZExt);
   Value *A = LHS->getOperand(0), *B = RHS->getOperand(0);
 
-  // Calculate type and width of the result produced by mul.with.overflow.
+  // Calculate type and width of the result produced by add/mul.with.overflow.
   Type *TyA = A->getType(), *TyB = B->getType();
   unsigned WidthA = TyA->getPrimitiveSizeInBits(),
            WidthB = TyB->getPrimitiveSizeInBits();
-  unsigned MulWidth;
-  Type *MulType;
+  unsigned ResultWidth;
+  Type *ResultType;
   if (WidthB > WidthA) {
-    MulWidth = WidthB;
-    MulType = TyB;
+    ResultWidth = WidthB;
+    ResultType = TyB;
   } else {
-    MulWidth = WidthA;
-    MulType = TyA;
+    ResultWidth = WidthA;
+    ResultType = TyA;
   }
 
-  // In order to replace the original mul with a narrower mul.with.overflow,
-  // all uses must ignore upper bits of the product. The number of used low
-  // bits must be not greater than the width of mul.with.overflow.
-  if (MulVal->hasNUsesOrMore(2))
-    for (User *U : MulVal->users()) {
+  // In order to replace the original result with a narrower
+  // add/mul.with.overflow intrinsic, all uses must ignore upper bits of the
+  // result. The number of used low bits must be not greater than the width of
+  // add or mul.with.overflow.
+  if (Val->hasNUsesOrMore(2))
+    for (User *U : Val->users()) {
       if (U == &I)
         continue;
       if (TruncInst *TI = dyn_cast<TruncInst>(U)) {
-        // Check if truncation ignores bits above MulWidth.
+        // Check if truncation ignores bits above ResultWidth.
         unsigned TruncWidth = TI->getType()->getPrimitiveSizeInBits();
-        if (TruncWidth > MulWidth)
+        if (TruncWidth > ResultWidth)
           return nullptr;
       } else if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U)) {
-        // Check if AND ignores bits above MulWidth.
+        // Check if AND ignores bits above ResultWidth.
        if (BO->getOpcode() != Instruction::And)
          return nullptr;
        if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->getOperand(1))) {
          const APInt &CVal = CI->getValue();
-          if (CVal.getBitWidth() - CVal.countl_zero() > MulWidth)
+          if (CVal.getBitWidth() - CVal.countl_zero() > ResultWidth)
            return nullptr;
        } else {
          // In this case we could have the operand of the binary operation
@@ -6616,9 +6621,9 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
   switch (I.getPredicate()) {
   case ICmpInst::ICMP_UGT: {
     // Recognize pattern:
-    //   mulval = mul(zext A, zext B)
-    //   cmp ugt mulval, max
-    APInt MaxVal = APInt::getMaxValue(MulWidth);
+    //   val = add/mul(zext A, zext B)
+    //   cmp ugt val, max
+    APInt MaxVal = APInt::getMaxValue(ResultWidth);
     MaxVal = MaxVal.zext(OtherVal->getBitWidth());
     if (MaxVal.eq(*OtherVal))
       break; // Recognized
@@ -6627,9 +6632,9 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
 
   case ICmpInst::ICMP_ULT: {
     // Recognize pattern:
-    //   mulval = mul(zext A, zext B)
-    //   cmp ule mulval, max + 1
-    APInt MaxVal = APInt::getOneBitSet(OtherVal->getBitWidth(), MulWidth);
+    //   val = add/mul(zext A, zext B)
+    //   cmp ule val, max + 1
+    APInt MaxVal = APInt::getOneBitSet(OtherVal->getBitWidth(), ResultWidth);
     if (MaxVal.eq(*OtherVal))
       break; // Recognized
     return nullptr;
@@ -6640,38 +6645,57 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
   }
 
   InstCombiner::BuilderTy &Builder = IC.Builder;
-  Builder.SetInsertPoint(MulInstr);
-
-  // Replace: mul(zext A, zext B) --> mul.with.overflow(A, B)
-  Value *MulA = A, *MulB = B;
-  if (WidthA < MulWidth)
-    MulA = Builder.CreateZExt(A, MulType);
-  if (WidthB < MulWidth)
-    MulB = Builder.CreateZExt(B, MulType);
-  CallInst *Call =
-      Builder.CreateIntrinsic(Intrinsic::umul_with_overflow, MulType,
-                              {MulA, MulB}, /*FMFSource=*/nullptr, "umul");
-  IC.addToWorklist(MulInstr);
-
-  // If there are uses of mul result other than the comparison, we know that
-  // they are truncation or binary AND. Change them to use result of
-  // mul.with.overflow and adjust properly mask/size.
-  if (MulVal->hasNUsesOrMore(2)) {
-    Value *Mul = Builder.CreateExtractValue(Call, 0, "umul.value");
-    for (User *U : make_early_inc_range(MulVal->users())) {
+  Builder.SetInsertPoint(Instr);
+
+  // Replace: add/mul(zext A, zext B) --> canonical add/mul + overflow check
+  Value *ResultA = A, *ResultB = B;
+  if (WidthA < ResultWidth)
+    ResultA = Builder.CreateZExt(A, ResultType);
+  if (WidthB < ResultWidth)
+    ResultB = Builder.CreateZExt(B, ResultType);
+
+  Value *ArithResult;
+  Value *OverflowCheck;
+
+  if (Opcode == Instruction::Add) {
+    // Canonical add overflow check: add + compare
+    ArithResult = Builder.CreateAdd(ResultA, ResultB, "add");
+    // Overflow if result < either operand (for unsigned add)
+    if (I.getPredicate() == ICmpInst::ICMP_ULT)
+      OverflowCheck =
+          Builder.CreateICmpUGE(ArithResult, ResultA, "not.add.overflow");
+    else
+      OverflowCheck =
+          Builder.CreateICmpULT(ArithResult, ResultA, "add.overflow");
+  } else {
+    // For multiplication, the intrinsic is actually the canonical form
+    CallInst *Call = Builder.CreateIntrinsic(Intrinsic::umul_with_overflow,
+                                             ResultType, {ResultA, ResultB},
+                                             /*FMFSource=*/nullptr, "umul");
+    ArithResult = Builder.CreateExtractValue(Call, 0, "umul.value");
+    OverflowCheck = Builder.CreateExtractValue(Call, 1, "umul.overflow");
+    if (I.getPredicate() == ICmpInst::ICMP_ULT)
+      OverflowCheck = Builder.CreateNot(OverflowCheck);
+  }
+
+  IC.addToWorklist(Instr);
+
+  // Replace uses of the original add/mul result with the new arithmetic result
+  if (Val->hasNUsesOrMore(2)) {
+    for (User *U : make_early_inc_range(Val->users())) {
       if (U == &I)
         continue;
       if (TruncInst *TI = dyn_cast<TruncInst>(U)) {
-        if (TI->getType()->getPrimitiveSizeInBits() == MulWidth)
-          IC.replaceInstUsesWith(*TI, Mul);
+        if (TI->getType()->getPrimitiveSizeInBits() == ResultWidth)
+          IC.replaceInstUsesWith(*TI, ArithResult);
         else
-          TI->setOperand(0, Mul);
+          TI->setOperand(0, ArithResult);
       } else if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U)) {
         assert(BO->getOpcode() == Instruction::And);
-        // Replace (mul & mask) --> zext (mul.with.overflow & short_mask)
+        // Replace (ArithResult & mask) --> zext (ArithResult & short_mask)
         ConstantInt *CI = cast<ConstantInt>(BO->getOperand(1));
-        APInt ShortMask = CI->getValue().trunc(MulWidth);
-        Value *ShortAnd = Builder.CreateAnd(Mul, ShortMask);
+        APInt ShortMask = CI->getValue().trunc(ResultWidth);
+        Value *ShortAnd = Builder.CreateAnd(ArithResult, ShortMask);
         Value *Zext = Builder.CreateZExt(ShortAnd, BO->getType());
         IC.replaceInstUsesWith(*BO, Zext);
       } else {
@@ -6681,14 +6705,7 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
     }
   }
 
-  // The original icmp gets replaced with the overflow value, maybe inverted
-  // depending on predicate.
-  if (I.getPredicate() == ICmpInst::ICMP_ULT) {
-    Value *Res = Builder.CreateExtractValue(Call, 1);
-    return BinaryOperator::CreateNot(Res);
-  }
-
-  return ExtractValueInst::Create(Call, 1);
+  return IC.replaceInstUsesWith(I, OverflowCheck);
 }
 
 /// When performing a comparison against a constant, it is possible that not all
@@ -7832,10 +7849,12 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) {
     }
   }
 
+  // (zext X) + (zext Y) --> add + overflow check.
   // (zext X) * (zext Y) --> llvm.umul.with.overflow.
-  if (match(Op0, m_NUWMul(m_ZExt(m_Value(X)), m_ZExt(m_Value(Y)))) &&
+  if ((match(Op0, m_NUWAdd(m_ZExt(m_Value(X)), m_ZExt(m_Value(Y)))) ||
+       match(Op0, m_NUWMul(m_ZExt(m_Value(X)), m_ZExt(m_Value(Y))))) &&
       match(Op1, m_APInt(C))) {
-    if (Instruction *R = processUMulZExtIdiom(I, Op0, C, *this))
+    if (Instruction *R = processUZExtIdiom(I, Op0, C, *this))
       return R;
   }
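
Unlike the mul case, the add branch above does not emit a with.overflow intrinsic: it builds a wrapping add in the narrow type plus an unsigned compare against one operand. A minimal C++ sketch of that canonical check (illustrative only; the function name is hypothetical):

#include <cstdint>

bool uadd_overflow_canonical(uint32_t a, uint32_t b) {
  uint32_t sum = a + b;  // wrapping add in the narrow type
  return sum < a;        // unsigned a + b overflowed iff the sum wrapped below an operand
}

This mirrors what CreateAdd followed by CreateICmpULT produces in the Instruction::Add branch, with CreateICmpUGE used instead when the original predicate was ICMP_ULT (the inverted check).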

llvm/test/Transforms/InstCombine/overflow-mul.ll

Lines changed: 2 additions & 2 deletions
@@ -286,8 +286,8 @@ define i32 @extra_and_use(i32 %x, i32 %y) {
 ; CHECK-LABEL: @extra_and_use(
 ; CHECK-NEXT: [[UMUL:%.*]] = call { i32, i1 } @llvm.umul.with.overflow.i32(i32 [[X:%.*]], i32 [[Y:%.*]])
 ; CHECK-NEXT: [[UMUL_VALUE:%.*]] = extractvalue { i32, i1 } [[UMUL]], 0
-; CHECK-NEXT: [[AND:%.*]] = zext i32 [[UMUL_VALUE]] to i64
 ; CHECK-NEXT: [[OVERFLOW:%.*]] = extractvalue { i32, i1 } [[UMUL]], 1
+; CHECK-NEXT: [[AND:%.*]] = zext i32 [[UMUL_VALUE]] to i64
 ; CHECK-NEXT: call void @use.i64(i64 [[AND]])
 ; CHECK-NEXT: [[RETVAL:%.*]] = zext i1 [[OVERFLOW]] to i32
 ; CHECK-NEXT: ret i32 [[RETVAL]]
@@ -306,9 +306,9 @@ define i32 @extra_and_use_small_mask(i32 %x, i32 %y) {
 ; CHECK-LABEL: @extra_and_use_small_mask(
 ; CHECK-NEXT: [[UMUL:%.*]] = call { i32, i1 } @llvm.umul.with.overflow.i32(i32 [[X:%.*]], i32 [[Y:%.*]])
 ; CHECK-NEXT: [[UMUL_VALUE:%.*]] = extractvalue { i32, i1 } [[UMUL]], 0
+; CHECK-NEXT: [[OVERFLOW:%.*]] = extractvalue { i32, i1 } [[UMUL]], 1
 ; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[UMUL_VALUE]], 268435455
 ; CHECK-NEXT: [[AND:%.*]] = zext nneg i32 [[TMP1]] to i64
-; CHECK-NEXT: [[OVERFLOW:%.*]] = extractvalue { i32, i1 } [[UMUL]], 1
 ; CHECK-NEXT: call void @use.i64(i64 [[AND]])
 ; CHECK-NEXT: [[RETVAL:%.*]] = zext i1 [[OVERFLOW]] to i32
 ; CHECK-NEXT: ret i32 [[RETVAL]]

llvm/test/Transforms/InstCombine/saturating-add-sub.ll

Lines changed: 3 additions & 11 deletions
@@ -2352,11 +2352,7 @@ define i8 @fold_add_umax_to_usub_multiuse(i8 %a) {
 
 define i32 @uadd_with_zext(i32 %x, i32 %y) {
 ; CHECK-LABEL: @uadd_with_zext(
-; CHECK-NEXT: [[CONV:%.*]] = zext i32 [[X:%.*]] to i64
-; CHECK-NEXT: [[CONV1:%.*]] = zext i32 [[Y:%.*]] to i64
-; CHECK-NEXT: [[ADD:%.*]] = add nuw nsw i64 [[CONV]], [[CONV1]]
-; CHECK-NEXT: [[COND1:%.*]] = call i64 @llvm.umin.i64(i64 [[ADD]], i64 4294967295)
-; CHECK-NEXT: [[COND:%.*]] = trunc nuw i64 [[COND1]] to i32
+; CHECK-NEXT: [[COND:%.*]] = call i32 @llvm.uadd.sat.i32(i32 [[X:%.*]], i32 [[Y:%.*]])
 ; CHECK-NEXT: ret i32 [[COND]]
 ;
   %conv = zext i32 %x to i64
@@ -2370,13 +2366,9 @@ define i32 @uadd_with_zext(i32 %x, i32 %y) {
 
 define i32 @uadd_with_zext_multi_use(i32 %x, i32 %y) {
 ; CHECK-LABEL: @uadd_with_zext_multi_use(
-; CHECK-NEXT: [[CONV:%.*]] = zext i32 [[X:%.*]] to i64
-; CHECK-NEXT: [[CONV1:%.*]] = zext i32 [[Y:%.*]] to i64
-; CHECK-NEXT: [[ADD:%.*]] = add nuw nsw i64 [[CONV]], [[CONV1]]
-; CHECK-NEXT: [[TRUNCADD:%.*]] = trunc i64 [[ADD]] to i32
+; CHECK-NEXT: [[TRUNCADD:%.*]] = add i32 [[X:%.*]], [[Y:%.*]]
 ; CHECK-NEXT: call void @usei32(i32 [[TRUNCADD]])
-; CHECK-NEXT: [[COND1:%.*]] = call i64 @llvm.umin.i64(i64 [[ADD]], i64 4294967295)
-; CHECK-NEXT: [[COND:%.*]] = trunc nuw i64 [[COND1]] to i32
+; CHECK-NEXT: [[COND:%.*]] = call i32 @llvm.uadd.sat.i32(i32 [[X]], i32 [[Y]])
 ; CHECK-NEXT: ret i32 [[COND]]
 ;
   %conv = zext i32 %x to i64
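
The new CHECK lines here expect @llvm.uadd.sat.i32 because the original pattern, a widened add clamped with umin to 4294967295 and then truncated, computes exactly a saturating unsigned add. A C++ sketch of that equivalence (illustrative only; the helper is hypothetical):

#include <algorithm>
#include <cstdint>

uint32_t uadd_sat_by_hand(uint32_t x, uint32_t y) {
  uint64_t wide = (uint64_t)x + (uint64_t)y;              // add without wrapping
  return (uint32_t)std::min<uint64_t>(wide, UINT32_MAX);  // clamp, then truncate
}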

llvm/test/Transforms/InstCombine/uadd-with-overflow.ll

Lines changed: 8 additions & 13 deletions
@@ -150,10 +150,8 @@ define { <2 x i32>, <2 x i1> } @fold_simple_splat_constant_with_or_fail(<2 x i32
 
 define i32 @uadd_with_zext(i32 %x, i32 %y) {
 ; CHECK-LABEL: @uadd_with_zext(
-; CHECK-NEXT: [[CONV:%.*]] = zext i32 [[X:%.*]] to i64
-; CHECK-NEXT: [[CONV1:%.*]] = zext i32 [[Y:%.*]] to i64
-; CHECK-NEXT: [[ADD:%.*]] = add nuw nsw i64 [[CONV]], [[CONV1]]
-; CHECK-NEXT: [[CMP:%.*]] = icmp samesign ugt i64 [[ADD]], 4294967295
+; CHECK-NEXT: [[TMP1:%.*]] = xor i32 [[X:%.*]], -1
+; CHECK-NEXT: [[CMP:%.*]] = icmp ugt i32 [[Y:%.*]], [[TMP1]]
 ; CHECK-NEXT: [[COND:%.*]] = zext i1 [[CMP]] to i32
 ; CHECK-NEXT: ret i32 [[COND]]
 ;
@@ -167,12 +165,11 @@ define i32 @uadd_with_zext(i32 %x, i32 %y) {
 
 define i32 @uadd_with_zext_use_and(i32 %x, i32 %y) {
 ; CHECK-LABEL: @uadd_with_zext_use_and(
-; CHECK-NEXT: [[CONV:%.*]] = zext i32 [[X:%.*]] to i64
-; CHECK-NEXT: [[CONV1:%.*]] = zext i32 [[Y:%.*]] to i64
-; CHECK-NEXT: [[ADD:%.*]] = add nuw nsw i64 [[CONV]], [[CONV1]]
-; CHECK-NEXT: [[AND:%.*]] = and i64 [[ADD]], 65535
+; CHECK-NEXT: [[UADD_VALUE:%.*]] = add i32 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT: [[CMP:%.*]] = icmp ult i32 [[UADD_VALUE]], [[X]]
+; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[UADD_VALUE]], 65535
+; CHECK-NEXT: [[AND:%.*]] = zext nneg i32 [[TMP1]] to i64
 ; CHECK-NEXT: call void @usei64(i64 [[AND]])
-; CHECK-NEXT: [[CMP:%.*]] = icmp samesign ugt i64 [[ADD]], 4294967295
 ; CHECK-NEXT: [[COND:%.*]] = zext i1 [[CMP]] to i32
 ; CHECK-NEXT: ret i32 [[COND]]
 ;
@@ -188,10 +185,8 @@ define i32 @uadd_with_zext_use_and(i32 %x, i32 %y) {
 
 define i32 @uadd_with_zext_inverse(i32 %x, i32 %y) {
 ; CHECK-LABEL: @uadd_with_zext_inverse(
-; CHECK-NEXT: [[CONV:%.*]] = zext i32 [[X:%.*]] to i64
-; CHECK-NEXT: [[CONV1:%.*]] = zext i32 [[Y:%.*]] to i64
-; CHECK-NEXT: [[ADD:%.*]] = add nuw nsw i64 [[CONV]], [[CONV1]]
-; CHECK-NEXT: [[CMP:%.*]] = icmp samesign ult i64 [[ADD]], 4294967296
+; CHECK-NEXT: [[TMP1:%.*]] = xor i32 [[X:%.*]], -1
+; CHECK-NEXT: [[CMP:%.*]] = icmp ule i32 [[Y:%.*]], [[TMP1]]
 ; CHECK-NEXT: [[COND:%.*]] = zext i1 [[CMP]] to i32
 ; CHECK-NEXT: ret i32 [[COND]]
 ;
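
The xor in the new CHECK lines is a further canonicalized form of the same check: for a 32-bit unsigned x, UINT32_MAX - x equals ~x, so testing whether the widened sum exceeds UINT32_MAX is the same as testing y > ~x (and y <= ~x for the inverse case). A C++ sketch of the equivalence (illustrative only; the function name is hypothetical):

#include <cstdint>

bool uadd_overflow_xor_form(uint32_t x, uint32_t y) {
  uint32_t not_x = ~x;  // equals UINT32_MAX - x
  return y > not_x;     // same truth value as (uint64_t)x + (uint64_t)y > UINT32_MAX
}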
