diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index c6f317a668cfe..f132daafa0873 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -6533,72 +6533,76 @@ bool InstCombinerImpl::OptimizeOverflowCheck(Instruction::BinaryOps BinaryOp,
   llvm_unreachable("Unexpected overflow result");
 }
 
-/// Recognize and process idiom involving test for multiplication
+/// Recognize and process idiom involving test for unsigned
 /// overflow.
 ///
 /// The caller has matched a pattern of the form:
+///   I = cmp u (add(zext A, zext B), V
 ///   I = cmp u (mul(zext A, zext B), V
 /// The function checks if this is a test for overflow and if so replaces
-/// multiplication with call to 'mul.with.overflow' intrinsic.
+/// addition/multiplication with call to the umul intrinsic or the canonical
+/// form of uadd overflow.
 ///
 /// \param I Compare instruction.
-/// \param MulVal Result of 'mult' instruction. It is one of the arguments of
-///               the compare instruction. Must be of integer type.
+/// \param Val Result of add/mul instruction. It is one of the arguments of
+///            the compare instruction. Must be of integer type.
 /// \param OtherVal The other argument of compare instruction.
 /// \returns Instruction which must replace the compare instruction, NULL if no
 ///          replacement required.
-static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
-                                         const APInt *OtherVal,
-                                         InstCombinerImpl &IC) {
+static Instruction *processUZExtIdiom(ICmpInst &I, Value *Val,
+                                      const APInt *OtherVal,
+                                      InstCombinerImpl &IC) {
   // Don't bother doing this transformation for pointers, don't do it for
   // vectors.
-  if (!isa<IntegerType>(MulVal->getType()))
+  if (!isa<IntegerType>(Val->getType()))
     return nullptr;
 
-  auto *MulInstr = dyn_cast<Instruction>(MulVal);
-  if (!MulInstr)
+  auto *Instr = dyn_cast<Instruction>(Val);
+  if (!Instr)
     return nullptr;
-  assert(MulInstr->getOpcode() == Instruction::Mul);
 
-  auto *LHS = cast<ZExtInst>(MulInstr->getOperand(0)),
-       *RHS = cast<ZExtInst>(MulInstr->getOperand(1));
+  unsigned Opcode = Instr->getOpcode();
+  assert(Opcode == Instruction::Add || Opcode == Instruction::Mul);
+
+  auto *LHS = cast<ZExtInst>(Instr->getOperand(0)),
+       *RHS = cast<ZExtInst>(Instr->getOperand(1));
   assert(LHS->getOpcode() == Instruction::ZExt);
   assert(RHS->getOpcode() == Instruction::ZExt);
   Value *A = LHS->getOperand(0), *B = RHS->getOperand(0);
 
-  // Calculate type and width of the result produced by mul.with.overflow.
+  // Calculate type and width of the result produced by add/mul.with.overflow.
   Type *TyA = A->getType(), *TyB = B->getType();
   unsigned WidthA = TyA->getPrimitiveSizeInBits(),
            WidthB = TyB->getPrimitiveSizeInBits();
-  unsigned MulWidth;
-  Type *MulType;
+  unsigned ResultWidth;
+  Type *ResultType;
   if (WidthB > WidthA) {
-    MulWidth = WidthB;
-    MulType = TyB;
+    ResultWidth = WidthB;
+    ResultType = TyB;
   } else {
-    MulWidth = WidthA;
-    MulType = TyA;
+    ResultWidth = WidthA;
+    ResultType = TyA;
   }
 
-  // In order to replace the original mul with a narrower mul.with.overflow,
-  // all uses must ignore upper bits of the product. The number of used low
-  // bits must be not greater than the width of mul.with.overflow.
-  if (MulVal->hasNUsesOrMore(2))
-    for (User *U : MulVal->users()) {
+  // In order to replace the original result with a narrower one, all uses must
+  // ignore upper bits of the result. The number of used low bits must not be
+  // greater than the width of add or mul.with.overflow.
+  if (Val->hasNUsesOrMore(2))
+    for (User *U : Val->users()) {
       if (U == &I)
         continue;
       if (TruncInst *TI = dyn_cast<TruncInst>(U)) {
-        // Check if truncation ignores bits above MulWidth.
+        // Check if truncation ignores bits above ResultWidth.
         unsigned TruncWidth = TI->getType()->getPrimitiveSizeInBits();
-        if (TruncWidth > MulWidth)
+        if (TruncWidth > ResultWidth)
           return nullptr;
       } else if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U)) {
-        // Check if AND ignores bits above MulWidth.
+        // Check if AND ignores bits above ResultWidth.
         if (BO->getOpcode() != Instruction::And)
           return nullptr;
         if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->getOperand(1))) {
           const APInt &CVal = CI->getValue();
-          if (CVal.getBitWidth() - CVal.countl_zero() > MulWidth)
+          if (CVal.getBitWidth() - CVal.countl_zero() > ResultWidth)
             return nullptr;
         } else {
           // In this case we could have the operand of the binary operation
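For orientation, here is the shape this function now matches, sketched in IR for the new add case (the i32/i64 widths and value names are illustrative, not taken from the patch):

```llvm
  %a64 = zext i32 %a to i64
  %b64 = zext i32 %b to i64
  %val = add nuw i64 %a64, %b64          ; Val: the add/mul of two zexts
  %low = trunc i64 %val to i32           ; extra use keeping only the low bits
  %ovf = icmp ugt i64 %val, 4294967295   ; I: overflow test, OtherVal = UINT32_MAX
```

Any additional user of %val must, like the trunc here or an and with a small mask, discard the bits above ResultWidth; that is exactly what the use-screening loop in the hunk above verifies.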
@@ -6616,9 +6620,9 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
   switch (I.getPredicate()) {
   case ICmpInst::ICMP_UGT: {
     // Recognize pattern:
-    //   mulval = mul(zext A, zext B)
-    //   cmp ugt mulval, max
-    APInt MaxVal = APInt::getMaxValue(MulWidth);
+    //   val = add/mul(zext A, zext B)
+    //   cmp ugt val, max
+    APInt MaxVal = APInt::getMaxValue(ResultWidth);
     MaxVal = MaxVal.zext(OtherVal->getBitWidth());
     if (MaxVal.eq(*OtherVal))
       break; // Recognized
@@ -6627,9 +6631,9 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
   case ICmpInst::ICMP_ULT: {
     // Recognize pattern:
-    //   mulval = mul(zext A, zext B)
-    //   cmp ule mulval, max + 1
-    APInt MaxVal = APInt::getOneBitSet(OtherVal->getBitWidth(), MulWidth);
+    //   val = add/mul(zext A, zext B)
+    //   cmp ult val, max + 1
+    APInt MaxVal = APInt::getOneBitSet(OtherVal->getBitWidth(), ResultWidth);
     if (MaxVal.eq(*OtherVal))
       break; // Recognized
     return nullptr;
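The two compare shapes the switch accepts, again sketched for 32-bit operands widened to i64 (the constants are UINT32_MAX and UINT32_MAX + 1; APInt::getOneBitSet at ResultWidth yields the latter):

```llvm
  ; ICMP_UGT form: true exactly on overflow.
  %ovf = icmp ugt i64 %val, 4294967295
  ; ICMP_ULT form: the inverted test, true when no overflow occurs.
  %noovf = icmp ult i64 %val, 4294967296
```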
@@ -6640,38 +6644,57 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
   }
 
   InstCombiner::BuilderTy &Builder = IC.Builder;
-  Builder.SetInsertPoint(MulInstr);
-
-  // Replace: mul(zext A, zext B) --> mul.with.overflow(A, B)
-  Value *MulA = A, *MulB = B;
-  if (WidthA < MulWidth)
-    MulA = Builder.CreateZExt(A, MulType);
-  if (WidthB < MulWidth)
-    MulB = Builder.CreateZExt(B, MulType);
-  CallInst *Call =
-      Builder.CreateIntrinsic(Intrinsic::umul_with_overflow, MulType,
-                              {MulA, MulB}, /*FMFSource=*/nullptr, "umul");
-  IC.addToWorklist(MulInstr);
-
-  // If there are uses of mul result other than the comparison, we know that
-  // they are truncation or binary AND. Change them to use result of
-  // mul.with.overflow and adjust properly mask/size.
-  if (MulVal->hasNUsesOrMore(2)) {
-    Value *Mul = Builder.CreateExtractValue(Call, 0, "umul.value");
-    for (User *U : make_early_inc_range(MulVal->users())) {
+  Builder.SetInsertPoint(Instr);
+
+  // Replace: add/mul(zext A, zext B) --> canonical add/mul + overflow check
+  Value *ResultA = A, *ResultB = B;
+  if (WidthA < ResultWidth)
+    ResultA = Builder.CreateZExt(A, ResultType);
+  if (WidthB < ResultWidth)
+    ResultB = Builder.CreateZExt(B, ResultType);
+
+  Value *ArithResult;
+  Value *OverflowCheck;
+
+  if (Opcode == Instruction::Add) {
+    // Canonical add overflow check: add + compare
+    ArithResult = Builder.CreateAdd(ResultA, ResultB, "add");
+    // Overflow if result < either operand (for unsigned add)
+    if (I.getPredicate() == ICmpInst::ICMP_ULT)
+      OverflowCheck =
+          Builder.CreateICmpUGE(ArithResult, ResultA, "not.add.overflow");
+    else
+      OverflowCheck =
+          Builder.CreateICmpULT(ArithResult, ResultA, "add.overflow");
+  } else {
+    // For multiplication, the intrinsic is actually the canonical form
+    CallInst *Call = Builder.CreateIntrinsic(Intrinsic::umul_with_overflow,
+                                             ResultType, {ResultA, ResultB},
+                                             /*FMFSource=*/nullptr, "umul");
+    ArithResult = Builder.CreateExtractValue(Call, 0, "umul.value");
+    OverflowCheck = Builder.CreateExtractValue(Call, 1, "umul.overflow");
+    if (I.getPredicate() == ICmpInst::ICMP_ULT)
+      OverflowCheck = Builder.CreateNot(OverflowCheck);
+  }
+
+  IC.addToWorklist(Instr);
+
+  // Replace uses of the original add/mul result with the new arithmetic result.
+  if (Val->hasNUsesOrMore(2)) {
+    for (User *U : make_early_inc_range(Val->users())) {
       if (U == &I)
         continue;
       if (TruncInst *TI = dyn_cast<TruncInst>(U)) {
-        if (TI->getType()->getPrimitiveSizeInBits() == MulWidth)
-          IC.replaceInstUsesWith(*TI, Mul);
+        if (TI->getType()->getPrimitiveSizeInBits() == ResultWidth)
+          IC.replaceInstUsesWith(*TI, ArithResult);
         else
-          TI->setOperand(0, Mul);
+          TI->setOperand(0, ArithResult);
       } else if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U)) {
         assert(BO->getOpcode() == Instruction::And);
-        // Replace (mul & mask) --> zext (mul.with.overflow & short_mask)
+        // Replace (ArithResult & mask) --> zext (ArithResult & short_mask)
         ConstantInt *CI = cast<ConstantInt>(BO->getOperand(1));
-        APInt ShortMask = CI->getValue().trunc(MulWidth);
-        Value *ShortAnd = Builder.CreateAnd(Mul, ShortMask);
+        APInt ShortMask = CI->getValue().trunc(ResultWidth);
+        Value *ShortAnd = Builder.CreateAnd(ArithResult, ShortMask);
         Value *Zext = Builder.CreateZExt(ShortAnd, BO->getType());
         IC.replaceInstUsesWith(*BO, Zext);
       } else {
@@ -6681,14 +6704,7 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
     }
   }
 
-  // The original icmp gets replaced with the overflow value, maybe inverted
-  // depending on predicate.
-  if (I.getPredicate() == ICmpInst::ICMP_ULT) {
-    Value *Res = Builder.CreateExtractValue(Call, 1);
-    return BinaryOperator::CreateNot(Res);
-  }
-
-  return ExtractValueInst::Create(Call, 1);
+  return IC.replaceInstUsesWith(I, OverflowCheck);
 }
 
 /// When performing a comparison against a constant, it is possible that not all
@@ -7832,10 +7848,12 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) {
     }
   }
 
+  // (zext X) + (zext Y) --> add + overflow check.
   // (zext X) * (zext Y) --> llvm.umul.with.overflow.
-  if (match(Op0, m_NUWMul(m_ZExt(m_Value(X)), m_ZExt(m_Value(Y)))) &&
+  if ((match(Op0, m_NUWAdd(m_ZExt(m_Value(X)), m_ZExt(m_Value(Y)))) ||
+       match(Op0, m_NUWMul(m_ZExt(m_Value(X)), m_ZExt(m_Value(Y))))) &&
       match(Op1, m_APInt(C))) {
-    if (Instruction *R = processUMulZExtIdiom(I, Op0, C, *this))
+    if (Instruction *R = processUZExtIdiom(I, Op0, C, *this))
       return R;
   }
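Net effect of the new add path, as a before/after sketch (the mul path still emits llvm.umul.with.overflow as before). The compare against one addend is sound because an unsigned add overflows exactly when the narrow result wraps below either operand; for example in i8, 200 + 100 = 300 wraps to 44, and 44 ult 200:

```llvm
; before
  %a64 = zext i32 %x to i64
  %b64 = zext i32 %y to i64
  %val = add nuw i64 %a64, %b64
  %ovf = icmp ugt i64 %val, 4294967295
; after (canonical form, before any follow-on folds)
  %add = add i32 %x, %y
  %ovf = icmp ult i32 %add, %x
```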
diff --git a/llvm/test/Transforms/InstCombine/overflow-mul.ll b/llvm/test/Transforms/InstCombine/overflow-mul.ll
index 1d18d9ffd46d2..4aa59ac755e8d 100644
--- a/llvm/test/Transforms/InstCombine/overflow-mul.ll
+++ b/llvm/test/Transforms/InstCombine/overflow-mul.ll
@@ -286,8 +286,8 @@ define i32 @extra_and_use(i32 %x, i32 %y) {
 ; CHECK-LABEL: @extra_and_use(
 ; CHECK-NEXT:    [[UMUL:%.*]] = call { i32, i1 } @llvm.umul.with.overflow.i32(i32 [[X:%.*]], i32 [[Y:%.*]])
 ; CHECK-NEXT:    [[UMUL_VALUE:%.*]] = extractvalue { i32, i1 } [[UMUL]], 0
-; CHECK-NEXT:    [[AND:%.*]] = zext i32 [[UMUL_VALUE]] to i64
 ; CHECK-NEXT:    [[OVERFLOW:%.*]] = extractvalue { i32, i1 } [[UMUL]], 1
+; CHECK-NEXT:    [[AND:%.*]] = zext i32 [[UMUL_VALUE]] to i64
 ; CHECK-NEXT:    call void @use.i64(i64 [[AND]])
 ; CHECK-NEXT:    [[RETVAL:%.*]] = zext i1 [[OVERFLOW]] to i32
 ; CHECK-NEXT:    ret i32 [[RETVAL]]
@@ -306,9 +306,9 @@ define i32 @extra_and_use_small_mask(i32 %x, i32 %y) {
 ; CHECK-LABEL: @extra_and_use_small_mask(
 ; CHECK-NEXT:    [[UMUL:%.*]] = call { i32, i1 } @llvm.umul.with.overflow.i32(i32 [[X:%.*]], i32 [[Y:%.*]])
 ; CHECK-NEXT:    [[UMUL_VALUE:%.*]] = extractvalue { i32, i1 } [[UMUL]], 0
+; CHECK-NEXT:    [[OVERFLOW:%.*]] = extractvalue { i32, i1 } [[UMUL]], 1
 ; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[UMUL_VALUE]], 268435455
 ; CHECK-NEXT:    [[AND:%.*]] = zext nneg i32 [[TMP1]] to i64
-; CHECK-NEXT:    [[OVERFLOW:%.*]] = extractvalue { i32, i1 } [[UMUL]], 1
 ; CHECK-NEXT:    call void @use.i64(i64 [[AND]])
 ; CHECK-NEXT:    [[RETVAL:%.*]] = zext i1 [[OVERFLOW]] to i32
 ; CHECK-NEXT:    ret i32 [[RETVAL]]
diff --git a/llvm/test/Transforms/InstCombine/saturating-add-sub.ll b/llvm/test/Transforms/InstCombine/saturating-add-sub.ll
index cfd679c0cc592..d19515c638c81 100644
--- a/llvm/test/Transforms/InstCombine/saturating-add-sub.ll
+++ b/llvm/test/Transforms/InstCombine/saturating-add-sub.ll
@@ -2350,4 +2350,57 @@ define i8 @fold_add_umax_to_usub_multiuse(i8 %a) {
   ret i8 %sel
 }
 
+define i32 @uadd_with_zext(i32 %x, i32 %y) {
+; CHECK-LABEL: @uadd_with_zext(
+; CHECK-NEXT:    [[COND:%.*]] = call i32 @llvm.uadd.sat.i32(i32 [[X:%.*]], i32 [[Y:%.*]])
+; CHECK-NEXT:    ret i32 [[COND]]
+;
+  %conv = zext i32 %x to i64
+  %conv1 = zext i32 %y to i64
+  %add = add i64 %conv, %conv1
+  %cmp = icmp ugt i64 %add, 4294967295
+  %conv4 = trunc i64 %add to i32
+  %cond = select i1 %cmp, i32 -1, i32 %conv4
+  ret i32 %cond
+}
+
+define i32 @uadd_with_zext_multi_use(i32 %x, i32 %y) {
+; CHECK-LABEL: @uadd_with_zext_multi_use(
+; CHECK-NEXT:    [[TRUNCADD:%.*]] = add i32 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    call void @usei32(i32 [[TRUNCADD]])
+; CHECK-NEXT:    [[COND:%.*]] = call i32 @llvm.uadd.sat.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT:    ret i32 [[COND]]
+;
+  %conv = zext i32 %x to i64
+  %conv1 = zext i32 %y to i64
+  %add = add i64 %conv, %conv1
+  %truncAdd = trunc i64 %add to i32
+  call void @usei32(i32 %truncAdd)
+  %cmp = icmp ugt i64 %add, 4294967295
+  %cond = select i1 %cmp, i32 -1, i32 %truncAdd
+  ret i32 %cond
+}
+
+define i32 @uadd_with_zext_neg_use(i32 %x, i32 %y) {
+; CHECK-LABEL: @uadd_with_zext_neg_use(
+; CHECK-NEXT:    [[CONV:%.*]] = zext i32 [[X:%.*]] to i64
+; CHECK-NEXT:    [[CONV1:%.*]] = zext i32 [[Y:%.*]] to i64
+; CHECK-NEXT:    [[ADD:%.*]] = add nuw nsw i64 [[CONV]], [[CONV1]]
+; CHECK-NEXT:    call void @usei64(i64 [[ADD]])
+; CHECK-NEXT:    [[COND1:%.*]] = call i64 @llvm.umin.i64(i64 [[ADD]], i64 4294967295)
+; CHECK-NEXT:    [[COND:%.*]] = trunc nuw i64 [[COND1]] to i32
+; CHECK-NEXT:    ret i32 [[COND]]
+;
+  %conv = zext i32 %x to i64
+  %conv1 = zext i32 %y to i64
+  %add = add i64 %conv, %conv1
+  call void @usei64(i64 %add)
+  %cmp = icmp ugt i64 %add, 4294967295
+  %conv4 = trunc i64 %add to i32
+  %cond = select i1 %cmp, i32 -1, i32 %conv4
+  ret i32 %cond
+}
+
+declare void @usei64(i64)
+declare void @usei32(i32)
 declare void @usei8(i8)
diff --git a/llvm/test/Transforms/InstCombine/uadd-with-overflow.ll b/llvm/test/Transforms/InstCombine/uadd-with-overflow.ll
index eb021a0fd2c89..a2d709ee17821 100644
--- a/llvm/test/Transforms/InstCombine/uadd-with-overflow.ll
+++ b/llvm/test/Transforms/InstCombine/uadd-with-overflow.ll
@@ -147,3 +147,74 @@ define { <2 x i32>, <2 x i1> } @fold_simple_splat_constant_with_or_fail(<2 x i32
   %b = tail call { <2 x i32>, <2 x i1> } @llvm.uadd.with.overflow.v2i32(<2 x i32> %a, <2 x i32> )
   ret { <2 x i32>, <2 x i1> } %b
 }
+
+define i32 @uadd_with_zext(i32 %x, i32 %y) {
+; CHECK-LABEL: @uadd_with_zext(
+; CHECK-NEXT:    [[TMP1:%.*]] = xor i32 [[X:%.*]], -1
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ugt i32 [[Y:%.*]], [[TMP1]]
+; CHECK-NEXT:    [[COND:%.*]] = zext i1 [[CMP]] to i32
+; CHECK-NEXT:    ret i32 [[COND]]
+;
+  %conv = zext i32 %x to i64
+  %conv1 = zext i32 %y to i64
+  %add = add i64 %conv, %conv1
+  %cmp = icmp ugt i64 %add, 4294967295
+  %cond = zext i1 %cmp to i32
+  ret i32 %cond
+}
+
+define i32 @uadd_with_zext_use_and(i32 %x, i32 %y) {
+; CHECK-LABEL: @uadd_with_zext_use_and(
+; CHECK-NEXT:    [[UADD_VALUE:%.*]] = add i32 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i32 [[UADD_VALUE]], [[X]]
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[UADD_VALUE]], 65535
+; CHECK-NEXT:    [[AND:%.*]] = zext nneg i32 [[TMP1]] to i64
+; CHECK-NEXT:    call void @usei64(i64 [[AND]])
+; CHECK-NEXT:    [[COND:%.*]] = zext i1 [[CMP]] to i32
+; CHECK-NEXT:    ret i32 [[COND]]
+;
+  %conv = zext i32 %x to i64
+  %conv1 = zext i32 %y to i64
+  %add = add i64 %conv, %conv1
+  %and = and i64 %add, 65535
+  call void @usei64(i64 %and)
+  %cmp = icmp ugt i64 %add, 4294967295
+  %cond = zext i1 %cmp to i32
+  ret i32 %cond
+}
+
+define i32 @uadd_with_zext_inverse(i32 %x, i32 %y) {
+; CHECK-LABEL: @uadd_with_zext_inverse(
+; CHECK-NEXT:    [[TMP1:%.*]] = xor i32 [[X:%.*]], -1
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ule i32 [[Y:%.*]], [[TMP1]]
+; CHECK-NEXT:    [[COND:%.*]] = zext i1 [[CMP]] to i32
+; CHECK-NEXT:    ret i32 [[COND]]
+;
+  %conv = zext i32 %x to i64
+  %conv1 = zext i32 %y to i64
+  %add = add i64 %conv, %conv1
+  %cmp = icmp ule i64 %add, 4294967295
+  %cond = zext i1 %cmp to i32
+  ret i32 %cond
+}
+
+define i32 @uadd_with_zext_neg_use(i32 %x, i32 %y) {
+; CHECK-LABEL: @uadd_with_zext_neg_use(
+; CHECK-NEXT:    [[CONV:%.*]] = zext i32 [[X:%.*]] to i64
+; CHECK-NEXT:    [[CONV1:%.*]] = zext i32 [[Y:%.*]] to i64
+; CHECK-NEXT:    [[ADD:%.*]] = add nuw nsw i64 [[CONV]], [[CONV1]]
+; CHECK-NEXT:    call void @usei64(i64 [[ADD]])
+; CHECK-NEXT:    [[CMP:%.*]] = icmp samesign ugt i64 [[ADD]], 4294967295
+; CHECK-NEXT:    [[COND:%.*]] = zext i1 [[CMP]] to i32
+; CHECK-NEXT:    ret i32 [[COND]]
+;
+  %conv = zext i32 %x to i64
+  %conv1 = zext i32 %y to i64
+  %add = add i64 %conv, %conv1
+  call void @usei64(i64 %add)
+  %cmp = icmp ugt i64 %add, 4294967295
+  %cond = zext i1 %cmp to i32
+  ret i32 %cond
+}
+
+declare void @usei64(i64)
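One payoff of emitting the canonical add-plus-icmp rather than llvm.uadd.with.overflow is visible in the saturating-add-sub.ll tests above: the existing saturation fold can now fire. A sketch of the chain (the intermediate step is hypothetical; InstCombine may reach the final form in one pass):

```llvm
  ; canonical form produced by processUZExtIdiom
  %add = add i32 %x, %y
  %ovf = icmp ult i32 %add, %x
  %cond = select i1 %ovf, i32 -1, i32 %add
  ; which the existing select fold collapses to
  %sat = call i32 @llvm.uadd.sat.i32(i32 %x, i32 %y)
```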