diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp index 078825f2a9a22..12c435ad1bc85 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -4940,17 +4940,18 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts, // If we have a clamp pattern, we know that the number of sign bits will be // the minimum of the clamp min/max range. bool IsMax = (Opcode == ISD::SMAX); - ConstantSDNode *CstLow = nullptr, *CstHigh = nullptr; - if ((CstLow = isConstOrConstSplat(Op.getOperand(1), DemandedElts))) + KnownBits KnownLow, KnownHigh; + KnownLow = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1); + if (KnownLow.isConstant()) if (Op.getOperand(0).getOpcode() == (IsMax ? ISD::SMIN : ISD::SMAX)) - CstHigh = - isConstOrConstSplat(Op.getOperand(0).getOperand(1), DemandedElts); - if (CstLow && CstHigh) { + KnownHigh = computeKnownBits(Op.getOperand(0).getOperand(1), + DemandedElts, Depth + 2); + if (KnownLow.isConstant() && KnownHigh.isConstant()) { if (!IsMax) - std::swap(CstLow, CstHigh); - if (CstLow->getAPIntValue().sle(CstHigh->getAPIntValue())) { - Tmp = CstLow->getAPIntValue().getNumSignBits(); - Tmp2 = CstHigh->getAPIntValue().getNumSignBits(); + std::swap(KnownLow, KnownHigh); + if (KnownLow.getConstant().sle(KnownHigh.getConstant())) { + Tmp = KnownLow.getConstant().getNumSignBits(); + Tmp2 = KnownHigh.getConstant().getNumSignBits(); return std::min(Tmp, Tmp2); } }