@@ -6533,72 +6533,77 @@ bool InstCombinerImpl::OptimizeOverflowCheck(Instruction::BinaryOps BinaryOp,
65336533 llvm_unreachable (" Unexpected overflow result" );
65346534}
65356535
6536- // / Recognize and process idiom involving test for multiplication
6536+ // / Recognize and process idiom involving test for unsigned
65376537// / overflow.
65386538// /
65396539// / The caller has matched a pattern of the form:
6540+ // / I = cmp u (add(zext A, zext B)), V
65406541// / I = cmp u (mul(zext A, zext B)), V
65416542// / The function checks if this is a test for overflow and if so replaces
6542- // / multiplication with call to 'mul.with.overflow' intrinsic.
6543+ // / multiplication with a call to the 'umul.with.overflow' intrinsic, or the
6544+ // / addition with a narrower add and the canonical unsigned-overflow compare.
65436545// /
65446546// / \param I Compare instruction.
6545- // / \param MulVal Result of 'mult' instruction. It is one of the arguments of
6546- // / the compare instruction. Must be of integer type.
6547+ // / \param Val Result of add/mul instruction. It is one of the arguments of
6548+ // / the compare instruction. Must be of integer type.
65476549// / \param OtherVal The other argument of compare instruction.
65486550// / \returns Instruction which must replace the compare instruction, NULL if no
65496551// / replacement required.
6550- static Instruction *processUMulZExtIdiom (ICmpInst &I, Value *MulVal ,
6551- const APInt *OtherVal,
6552- InstCombinerImpl &IC) {
6552+ static Instruction *processUZExtIdiom (ICmpInst &I, Value *Val ,
6553+ const APInt *OtherVal,
6554+ InstCombinerImpl &IC) {
65536555 // Don't bother doing this transformation for pointers, don't do it for
65546556 // vectors.
6555- if (!isa<IntegerType>(MulVal ->getType ()))
6557+ if (!isa<IntegerType>(Val ->getType ()))
65566558 return nullptr ;
65576559
6558- auto *MulInstr = dyn_cast<Instruction>(MulVal );
6559- if (!MulInstr )
6560+ auto *Instr = dyn_cast<Instruction>(Val );
6561+ if (!Instr )
65606562 return nullptr ;
6561- assert (MulInstr->getOpcode () == Instruction::Mul);
65626563
6563- auto *LHS = cast<ZExtInst>(MulInstr->getOperand (0 )),
6564- *RHS = cast<ZExtInst>(MulInstr->getOperand (1 ));
6564+ unsigned Opcode = Instr->getOpcode ();
6565+ assert (Opcode == Instruction::Add || Opcode == Instruction::Mul);
6566+
6567+ auto *LHS = cast<ZExtInst>(Instr->getOperand (0 )),
6568+ *RHS = cast<ZExtInst>(Instr->getOperand (1 ));
65656569 assert (LHS->getOpcode () == Instruction::ZExt);
65666570 assert (RHS->getOpcode () == Instruction::ZExt);
65676571 Value *A = LHS->getOperand (0 ), *B = RHS->getOperand (0 );
65686572
6569- // Calculate type and width of the result produced by mul.with.overflow.
6573+ // Calculate type and width of the narrower add / umul.with.overflow result.
65706574 Type *TyA = A->getType (), *TyB = B->getType ();
65716575 unsigned WidthA = TyA->getPrimitiveSizeInBits (),
65726576 WidthB = TyB->getPrimitiveSizeInBits ();
6573- unsigned MulWidth ;
6574- Type *MulType ;
6577+ unsigned ResultWidth ;
6578+ Type *ResultType ;
65756579 if (WidthB > WidthA) {
6576- MulWidth = WidthB;
6577- MulType = TyB;
6580+ ResultWidth = WidthB;
6581+ ResultType = TyB;
65786582 } else {
6579- MulWidth = WidthA;
6580- MulType = TyA;
6583+ ResultWidth = WidthA;
6584+ ResultType = TyA;
65816585 }
65826586
6583- // In order to replace the original mul with a narrower mul.with.overflow,
6584- // all uses must ignore upper bits of the product. The number of used low
6585- // bits must be not greater than the width of mul.with.overflow.
6586- if (MulVal->hasNUsesOrMore (2 ))
6587- for (User *U : MulVal->users ()) {
6587+ // In order to replace the original add/mul with a narrower add or
6588+ // umul.with.overflow, all uses must ignore the upper bits of the result.
6589+ // The number of used low bits must not be greater than the width of the
6590+ // narrower operation.
6591+ if (Val->hasNUsesOrMore (2 ))
6592+ for (User *U : Val->users ()) {
65886593 if (U == &I)
65896594 continue ;
65906595 if (TruncInst *TI = dyn_cast<TruncInst>(U)) {
6591- // Check if truncation ignores bits above MulWidth .
6596+ // Check if truncation ignores bits above ResultWidth .
65926597 unsigned TruncWidth = TI->getType ()->getPrimitiveSizeInBits ();
6593- if (TruncWidth > MulWidth )
6598+ if (TruncWidth > ResultWidth )
65946599 return nullptr ;
65956600 } else if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U)) {
6596- // Check if AND ignores bits above MulWidth .
6601+ // Check if AND ignores bits above ResultWidth .
65976602 if (BO->getOpcode () != Instruction::And)
65986603 return nullptr ;
65996604 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->getOperand (1 ))) {
66006605 const APInt &CVal = CI->getValue ();
6601- if (CVal.getBitWidth () - CVal.countl_zero () > MulWidth )
6606+ if (CVal.getBitWidth () - CVal.countl_zero () > ResultWidth )
66026607 return nullptr ;
66036608 } else {
66046609 // In this case we could have the operand of the binary operation
@@ -6616,9 +6621,9 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
66166621 switch (I.getPredicate ()) {
66176622 case ICmpInst::ICMP_UGT: {
66186623 // Recognize pattern:
6619- // mulval = mul(zext A, zext B)
6620- // cmp ugt mulval , max
6621- APInt MaxVal = APInt::getMaxValue (MulWidth );
6624+ // val = add/ mul(zext A, zext B)
6625+ // cmp ugt val , max
6626+ APInt MaxVal = APInt::getMaxValue (ResultWidth );
66226627 MaxVal = MaxVal.zext (OtherVal->getBitWidth ());
66236628 if (MaxVal.eq (*OtherVal))
66246629 break ; // Recognized
@@ -6627,9 +6632,9 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
66276632
66286633 case ICmpInst::ICMP_ULT: {
66296634 // Recognize pattern:
6630- // mulval = mul(zext A, zext B)
6631- // cmp ule mulval , max + 1
6632- APInt MaxVal = APInt::getOneBitSet (OtherVal->getBitWidth (), MulWidth );
6635+ // val = add/ mul(zext A, zext B)
6636+ // cmp ult val , max + 1
6637+ APInt MaxVal = APInt::getOneBitSet (OtherVal->getBitWidth (), ResultWidth );
66336638 if (MaxVal.eq (*OtherVal))
66346639 break ; // Recognized
66356640 return nullptr ;
@@ -6640,38 +6645,57 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
66406645 }
66416646
66426647 InstCombiner::BuilderTy &Builder = IC.Builder ;
6643- Builder.SetInsertPoint (MulInstr);
6644-
6645- // Replace: mul(zext A, zext B) --> mul.with.overflow(A, B)
6646- Value *MulA = A, *MulB = B;
6647- if (WidthA < MulWidth)
6648- MulA = Builder.CreateZExt (A, MulType);
6649- if (WidthB < MulWidth)
6650- MulB = Builder.CreateZExt (B, MulType);
6651- CallInst *Call =
6652- Builder.CreateIntrinsic (Intrinsic::umul_with_overflow, MulType,
6653- {MulA, MulB}, /* FMFSource=*/ nullptr , " umul" );
6654- IC.addToWorklist (MulInstr);
6655-
6656- // If there are uses of mul result other than the comparison, we know that
6657- // they are truncation or binary AND. Change them to use result of
6658- // mul.with.overflow and adjust properly mask/size.
6659- if (MulVal->hasNUsesOrMore (2 )) {
6660- Value *Mul = Builder.CreateExtractValue (Call, 0 , " umul.value" );
6661- for (User *U : make_early_inc_range (MulVal->users ())) {
6648+ Builder.SetInsertPoint (Instr);
6649+
6650+ // Replace: add/mul(zext A, zext B) --> canonical add/mul + overflow check
6651+ Value *ResultA = A, *ResultB = B;
6652+ if (WidthA < ResultWidth)
6653+ ResultA = Builder.CreateZExt (A, ResultType);
6654+ if (WidthB < ResultWidth)
6655+ ResultB = Builder.CreateZExt (B, ResultType);
6656+
6657+ Value *ArithResult;
6658+ Value *OverflowCheck;
6659+
6660+ if (Opcode == Instruction::Add) {
6661+ // Canonical add overflow check: add + compare
6662+ ArithResult = Builder.CreateAdd (ResultA, ResultB, " add" );
6663+ // Overflow if result < either operand (for unsigned add)
6664+ if (I.getPredicate () == ICmpInst::ICMP_ULT)
6665+ OverflowCheck =
6666+ Builder.CreateICmpUGE (ArithResult, ResultA, " not.add.overflow" );
6667+ else
6668+ OverflowCheck =
6669+ Builder.CreateICmpULT (ArithResult, ResultA, " add.overflow" );
6670+ } else {
6671+ // For multiplication, the intrinsic is actually the canonical form
6672+ CallInst *Call = Builder.CreateIntrinsic (Intrinsic::umul_with_overflow,
6673+ ResultType, {ResultA, ResultB},
6674+ /* FMFSource=*/ nullptr , " umul" );
6675+ ArithResult = Builder.CreateExtractValue (Call, 0 , " umul.value" );
6676+ OverflowCheck = Builder.CreateExtractValue (Call, 1 , " umul.overflow" );
6677+ if (I.getPredicate () == ICmpInst::ICMP_ULT)
6678+ OverflowCheck = Builder.CreateNot (OverflowCheck);
6679+ }
6680+
6681+ IC.addToWorklist (Instr);
6682+
6683+ // Replace uses of the original add/mul result with the new arithmetic result
6684+ if (Val->hasNUsesOrMore (2 )) {
6685+ for (User *U : make_early_inc_range (Val->users ())) {
66626686 if (U == &I)
66636687 continue ;
66646688 if (TruncInst *TI = dyn_cast<TruncInst>(U)) {
6665- if (TI->getType ()->getPrimitiveSizeInBits () == MulWidth )
6666- IC.replaceInstUsesWith (*TI, Mul );
6689+ if (TI->getType ()->getPrimitiveSizeInBits () == ResultWidth )
6690+ IC.replaceInstUsesWith (*TI, ArithResult );
66676691 else
6668- TI->setOperand (0 , Mul );
6692+ TI->setOperand (0 , ArithResult );
66696693 } else if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U)) {
66706694 assert (BO->getOpcode () == Instruction::And);
6671- // Replace (mul & mask) --> zext (mul.with.overflow & short_mask)
6695+ // Replace (Val & mask) --> zext (ArithResult & short_mask)
66726696 ConstantInt *CI = cast<ConstantInt>(BO->getOperand (1 ));
6673- APInt ShortMask = CI->getValue ().trunc (MulWidth );
6674- Value *ShortAnd = Builder.CreateAnd (Mul , ShortMask);
6697+ APInt ShortMask = CI->getValue ().trunc (ResultWidth );
6698+ Value *ShortAnd = Builder.CreateAnd (ArithResult , ShortMask);
66756699 Value *Zext = Builder.CreateZExt (ShortAnd, BO->getType ());
66766700 IC.replaceInstUsesWith (*BO, Zext);
66776701 } else {
@@ -6681,14 +6705,7 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
66816705 }
66826706 }
66836707
6684- // The original icmp gets replaced with the overflow value, maybe inverted
6685- // depending on predicate.
6686- if (I.getPredicate () == ICmpInst::ICMP_ULT) {
6687- Value *Res = Builder.CreateExtractValue (Call, 1 );
6688- return BinaryOperator::CreateNot (Res);
6689- }
6690-
6691- return ExtractValueInst::Create (Call, 1 );
6708+ return IC.replaceInstUsesWith (I, OverflowCheck);
66926709}
66936710
66946711// / When performing a comparison against a constant, it is possible that not all
@@ -7832,10 +7849,12 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) {
78327849 }
78337850 }
78347851
7852+ // (zext X) + (zext Y) --> add + overflow check.
78357853 // (zext X) * (zext Y) --> llvm.umul.with.overflow.
7836- if (match (Op0, m_NUWMul (m_ZExt (m_Value (X)), m_ZExt (m_Value (Y)))) &&
7854+ if ((match (Op0, m_NUWAdd (m_ZExt (m_Value (X)), m_ZExt (m_Value (Y)))) ||
7855+ match (Op0, m_NUWMul (m_ZExt (m_Value (X)), m_ZExt (m_Value (Y))))) &&
78377856 match (Op1, m_APInt (C))) {
7838- if (Instruction *R = processUMulZExtIdiom (I, Op0, C, *this ))
7857+ if (Instruction *R = processUZExtIdiom (I, Op0, C, *this ))
78397858 return R;
78407859 }
78417860
0 commit comments