diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp index c8bdf029dd71c..c7023eb79b04e 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -255,6 +255,33 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) { } } + // mul (shr exact X, N), (2^N + 1) -> add (X, shr exact (X, N)) + { + Value *NewOp; + const APInt *ShiftC; + const APInt *MulAP; + if (BitWidth > 2 && + match(&I, m_Mul(m_Exact(m_Shr(m_Value(NewOp), m_APInt(ShiftC))), + m_APInt(MulAP))) && + (*MulAP - 1).isPowerOf2() && *ShiftC == MulAP->logBase2()) { + Value *BinOp = Op0; + BinaryOperator *OpBO = cast(Op0); + + // mul nuw (ashr exact X, N) -> add nuw (X, lshr exact (X, N)) + if (HasNUW && OpBO->getOpcode() == Instruction::AShr && OpBO->hasOneUse()) + BinOp = Builder.CreateLShr(NewOp, ConstantInt::get(Ty, *ShiftC), "", + /*isExact=*/true); + + auto *NewAdd = BinaryOperator::CreateAdd(NewOp, BinOp); + if (HasNSW && (HasNUW || OpBO->getOpcode() == Instruction::LShr || + ShiftC->getZExtValue() < BitWidth - 1)) + NewAdd->setHasNoSignedWrap(true); + + NewAdd->setHasNoUnsignedWrap(HasNUW); + return NewAdd; + } + } + if (Op0->hasOneUse() && match(Op1, m_NegatedPower2())) { // Interpret X * (-1<