diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index 29c5cef84ccdb..382078e85a17b 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -1771,26 +1771,38 @@ static Value *foldSelectInstWithICmpConst(SelectInst &SI, ICmpInst *ICI, return Builder.CreateBinaryIntrinsic(Intrinsic::smin, V, TVal); } - BinaryOperator *BO; + // Fold icmp(X) ? f(X) : C to f(X) when f(X) is guaranteed to be equal to C + // for all X in the exact range of the inverse predicate. + Instruction *Op; const APInt *C; CmpInst::Predicate CPred; - if (match(&SI, m_Select(m_Specific(ICI), m_APInt(C), m_BinOp(BO)))) + if (match(&SI, m_Select(m_Specific(ICI), m_APInt(C), m_Instruction(Op)))) CPred = ICI->getPredicate(); - else if (match(&SI, m_Select(m_Specific(ICI), m_BinOp(BO), m_APInt(C)))) + else if (match(&SI, m_Select(m_Specific(ICI), m_Instruction(Op), m_APInt(C)))) CPred = ICI->getInversePredicate(); else return nullptr; - const APInt *BinOpC; - if (!match(BO, m_BinOp(m_Specific(V), m_APInt(BinOpC)))) - return nullptr; - - ConstantRange R = ConstantRange::makeExactICmpRegion(CPred, *CmpC) - .binaryOp(BO->getOpcode(), *BinOpC); - if (R == *C) { - BO->dropPoisonGeneratingFlags(); - return BO; + ConstantRange InvDomCR = ConstantRange::makeExactICmpRegion(CPred, *CmpC); + const APInt *OpC; + if (match(Op, m_BinOp(m_Specific(V), m_APInt(OpC)))) { + ConstantRange R = InvDomCR.binaryOp( + static_cast(Op->getOpcode()), *OpC); + if (R == *C) { + Op->dropPoisonGeneratingFlags(); + return Op; + } + } + if (auto *MMI = dyn_cast(Op); + MMI && MMI->getLHS() == V && match(MMI->getRHS(), m_APInt(OpC))) { + ConstantRange R = ConstantRange::intrinsic(MMI->getIntrinsicID(), + {InvDomCR, ConstantRange(*OpC)}); + if (R == *C) { + MMI->dropPoisonGeneratingAnnotations(); + return MMI; + } } + return nullptr; } diff --git a/llvm/test/Transforms/InstCombine/select-min-max.ll b/llvm/test/Transforms/InstCombine/select-min-max.ll index f13223284e6a6..0430fcd5ad370 100644 --- a/llvm/test/Transforms/InstCombine/select-min-max.ll +++ b/llvm/test/Transforms/InstCombine/select-min-max.ll @@ -301,3 +301,66 @@ define i8 @not_smin_swap(i8 %i41, i8 %i43) { %spec.select = select i1 %i44, i8 %i46, i8 0 ret i8 %spec.select } + +define i8 @sel_umin_constant(i8 %x) { +; CHECK-LABEL: @sel_umin_constant( +; CHECK-NEXT: [[UMIN:%.*]] = call i8 @llvm.umin.i8(i8 [[X:%.*]], i8 16) +; CHECK-NEXT: ret i8 [[UMIN]] +; + %cmp = icmp sgt i8 %x, -1 + %umin = call i8 @llvm.umin.i8(i8 %x, i8 16) + %sel = select i1 %cmp, i8 %umin, i8 16 + ret i8 %sel +} + +define i8 @sel_constant_smax_with_range_attr(i8 %x) { +; CHECK-LABEL: @sel_constant_smax_with_range_attr( +; CHECK-NEXT: [[SEL:%.*]] = call i8 @llvm.smax.i8(i8 [[X:%.*]], i8 16) +; CHECK-NEXT: ret i8 [[SEL]] +; + %cmp = icmp slt i8 %x, 0 + %smax = call range(i8 8, 16) i8 @llvm.smax.i8(i8 %x, i8 16) + %sel = select i1 %cmp, i8 16, i8 %smax + ret i8 %sel +} + +; Negative tests + +define i8 @sel_umin_constant_mismatch(i8 %x) { +; CHECK-LABEL: @sel_umin_constant_mismatch( +; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i8 [[X:%.*]], -1 +; CHECK-NEXT: [[UMIN:%.*]] = call i8 @llvm.umin.i8(i8 [[X]], i8 16) +; CHECK-NEXT: [[SEL:%.*]] = select i1 [[CMP]], i8 [[UMIN]], i8 15 +; CHECK-NEXT: ret i8 [[SEL]] +; + %cmp = icmp sgt i8 %x, -1 + %umin = call i8 @llvm.umin.i8(i8 %x, i8 16) + %sel = select i1 %cmp, i8 %umin, i8 15 + ret i8 %sel +} + +define i8 @sel_umin_constant_op_mismatch(i8 %x, i8 %y) { +; CHECK-LABEL: @sel_umin_constant_op_mismatch( +; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i8 [[X:%.*]], -1 +; CHECK-NEXT: [[UMIN:%.*]] = call i8 @llvm.umin.i8(i8 [[Y:%.*]], i8 16) +; CHECK-NEXT: [[SEL:%.*]] = select i1 [[CMP]], i8 [[UMIN]], i8 16 +; CHECK-NEXT: ret i8 [[SEL]] +; + %cmp = icmp sgt i8 %x, -1 + %umin = call i8 @llvm.umin.i8(i8 %y, i8 16) + %sel = select i1 %cmp, i8 %umin, i8 16 + ret i8 %sel +} + +define i8 @sel_umin_non_constant(i8 %x, i8 %y) { +; CHECK-LABEL: @sel_umin_non_constant( +; CHECK-NEXT: [[UMIN:%.*]] = call i8 @llvm.umin.i8(i8 [[X:%.*]], i8 16) +; CHECK-NEXT: [[CMP1:%.*]] = icmp slt i8 [[X]], 0 +; CHECK-NEXT: [[SEL:%.*]] = select i1 [[CMP1]], i8 [[Y:%.*]], i8 [[UMIN]] +; CHECK-NEXT: ret i8 [[SEL]] +; + %cmp = icmp sgt i8 %x, -1 + %umin = call i8 @llvm.umin.i8(i8 %x, i8 16) + %sel = select i1 %cmp, i8 %umin, i8 %y + ret i8 %sel +}