diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp index 5b9338b500ecc..b7084cf3619f3 100644 --- a/llvm/lib/Analysis/ValueTracking.cpp +++ b/llvm/lib/Analysis/ValueTracking.cpp @@ -9165,15 +9165,13 @@ isImpliedCondCommonOperandWithCR(CmpPredicate LPred, const ConstantRange &LCR, /// is true. Return false if LHS implies RHS is false. Otherwise, return /// std::nullopt if we can't infer anything. static std::optional -isImpliedCondICmps(const ICmpInst *LHS, CmpPredicate RPred, const Value *R0, - const Value *R1, const DataLayout &DL, bool LHSIsTrue) { - Value *L0 = LHS->getOperand(0); - Value *L1 = LHS->getOperand(1); - +isImpliedCondICmps(CmpPredicate LPred, const Value *L0, const Value *L1, + CmpPredicate RPred, const Value *R0, const Value *R1, + const DataLayout &DL, bool LHSIsTrue) { // The rest of the logic assumes the LHS condition is true. If that's not the // case, invert the predicate to make it so. - CmpPredicate LPred = - LHSIsTrue ? LHS->getCmpPredicate() : LHS->getInverseCmpPredicate(); + if (!LHSIsTrue) + LPred = ICmpInst::getInverseCmpPredicate(LPred); // We can have non-canonical operands, so try to normalize any common operand // to L0/R0. @@ -9314,9 +9312,15 @@ llvm::isImpliedCondition(const Value *LHS, CmpPredicate RHSPred, LHSIsTrue = !LHSIsTrue; // Both LHS and RHS are icmps. - const ICmpInst *LHSCmp = dyn_cast(LHS); - if (LHSCmp) - return isImpliedCondICmps(LHSCmp, RHSPred, RHSOp0, RHSOp1, DL, LHSIsTrue); + if (const auto *LHSCmp = dyn_cast(LHS)) + return isImpliedCondICmps(LHSCmp->getCmpPredicate(), LHSCmp->getOperand(0), + LHSCmp->getOperand(1), RHSPred, RHSOp0, RHSOp1, + DL, LHSIsTrue); + const Value *V; + if (match(LHS, m_NUWTrunc(m_Value(V)))) + return isImpliedCondICmps(CmpInst::ICMP_NE, V, + ConstantInt::get(V->getType(), 0), RHSPred, + RHSOp0, RHSOp1, DL, LHSIsTrue); /// The LHS should be an 'or', 'and', or a 'select' instruction. We expect /// the RHS to be an icmp. @@ -9354,6 +9358,15 @@ std::optional llvm::isImpliedCondition(const Value *LHS, const Value *RHS, return std::nullopt; } + const Value *V; + if (match(RHS, m_NUWTrunc(m_Value(V)))) { + if (auto Implied = isImpliedCondition(LHS, CmpInst::ICMP_NE, V, + ConstantInt::get(V->getType(), 0), DL, + LHSIsTrue, Depth)) + return InvertRHS ? !*Implied : *Implied; + return std::nullopt; + } + if (Depth == MaxAnalysisRecursionDepth) return std::nullopt; diff --git a/llvm/test/Transforms/InstCombine/select-safe-impliedcond-transforms.ll b/llvm/test/Transforms/InstCombine/select-safe-impliedcond-transforms.ll index 4b66b71773683..ba34930dc14c1 100644 --- a/llvm/test/Transforms/InstCombine/select-safe-impliedcond-transforms.ll +++ b/llvm/test/Transforms/InstCombine/select-safe-impliedcond-transforms.ll @@ -176,3 +176,85 @@ define i1 @a_false_implies_b_false2_comm(i8 %z, i1 %X, i1 %Y) { %res = or i1 %sel, %a ret i1 %res } + +define i1 @trunc_nuw_implies_icmp_eq(i8 %x, i1 %c) { +; CHECK-LABEL: @trunc_nuw_implies_icmp_eq( +; CHECK-NEXT: [[TRUNC:%.*]] = trunc nuw i8 [[X:%.*]] to i1 +; CHECK-NEXT: [[SEL2:%.*]] = select i1 [[TRUNC]], i1 true, i1 [[C:%.*]] +; CHECK-NEXT: ret i1 [[SEL2]] +; + %trunc = trunc nuw i8 %x to i1 + %cmp = icmp eq i8 %x, 0 + %sel1 = select i1 %cmp, i1 %c, i1 false + %sel2 = select i1 %trunc, i1 true, i1 %sel1 + ret i1 %sel2 +} + +define i1 @icmp_eq_implies_trunc_nuw(i8 %x, i1 %c) { +; CHECK-LABEL: @icmp_eq_implies_trunc_nuw( +; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[X:%.*]], 0 +; CHECK-NEXT: [[SEL2:%.*]] = select i1 [[CMP]], i1 true, i1 [[C:%.*]] +; CHECK-NEXT: ret i1 [[SEL2]] +; + %trunc = trunc nuw i8 %x to i1 + %cmp = icmp eq i8 %x, 0 + %sel1 = select i1 %trunc, i1 %c, i1 false + %sel2 = select i1 %cmp, i1 true, i1 %sel1 + ret i1 %sel2 +} + +define <2 x i1> @trunc_nuw_implies_icmp_eq_vec(<2 x i8> %x, <2 x i1> %c) { +; CHECK-LABEL: @trunc_nuw_implies_icmp_eq_vec( +; CHECK-NEXT: [[TRUNC:%.*]] = trunc nuw <2 x i8> [[X:%.*]] to <2 x i1> +; CHECK-NEXT: [[SEL2:%.*]] = select <2 x i1> [[TRUNC]], <2 x i1> splat (i1 true), <2 x i1> [[C:%.*]] +; CHECK-NEXT: ret <2 x i1> [[SEL2]] +; + %trunc = trunc nuw <2 x i8> %x to <2 x i1> + %cmp = icmp eq <2 x i8> %x, splat (i8 0) + %sel1 = select <2 x i1> %cmp, <2 x i1> %c, <2 x i1> splat (i1 false) + %sel2 = select <2 x i1> %trunc, <2 x i1> splat (i1 true), <2 x i1> %sel1 + ret <2 x i1> %sel2 +} + +define <2 x i1> @icmp_eq_implies_trunc_nuw_vec(<2 x i8> %x, <2 x i1> %c) { +; CHECK-LABEL: @icmp_eq_implies_trunc_nuw_vec( +; CHECK-NEXT: [[CMP:%.*]] = icmp eq <2 x i8> [[X:%.*]], zeroinitializer +; CHECK-NEXT: [[SEL2:%.*]] = select <2 x i1> [[CMP]], <2 x i1> splat (i1 true), <2 x i1> [[C:%.*]] +; CHECK-NEXT: ret <2 x i1> [[SEL2]] +; + %trunc = trunc nuw <2 x i8> %x to <2 x i1> + %cmp = icmp eq <2 x i8> %x, splat (i8 0) + %sel1 = select <2 x i1> %trunc, <2 x i1> %c, <2 x i1> splat (i1 false) + %sel2 = select <2 x i1> %cmp, <2 x i1> splat (i1 true), <2 x i1> %sel1 + ret <2 x i1> %sel2 +} + +define i1 @neg_trunc_implies_icmp_eq(i8 %x, i1 %c) { +; CHECK-LABEL: @neg_trunc_implies_icmp_eq( +; CHECK-NEXT: [[TRUNC:%.*]] = trunc i8 [[X:%.*]] to i1 +; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[X]], 0 +; CHECK-NEXT: [[SEL1:%.*]] = select i1 [[CMP]], i1 [[C:%.*]], i1 false +; CHECK-NEXT: [[SEL2:%.*]] = select i1 [[TRUNC]], i1 true, i1 [[SEL1]] +; CHECK-NEXT: ret i1 [[SEL2]] +; + %trunc = trunc i8 %x to i1 + %cmp = icmp eq i8 %x, 0 + %sel1 = select i1 %cmp, i1 %c, i1 false + %sel2 = select i1 %trunc, i1 true, i1 %sel1 + ret i1 %sel2 +} + +define i1 @neg_icmp_eq_implies_trunc(i8 %x, i1 %c) { +; CHECK-LABEL: @neg_icmp_eq_implies_trunc( +; CHECK-NEXT: [[TRUNC:%.*]] = trunc i8 [[X:%.*]] to i1 +; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[X]], 0 +; CHECK-NEXT: [[SEL1:%.*]] = select i1 [[TRUNC]], i1 [[C:%.*]], i1 false +; CHECK-NEXT: [[SEL2:%.*]] = select i1 [[CMP]], i1 true, i1 [[SEL1]] +; CHECK-NEXT: ret i1 [[SEL2]] +; + %trunc = trunc i8 %x to i1 + %cmp = icmp eq i8 %x, 0 + %sel1 = select i1 %trunc, i1 %c, i1 false + %sel2 = select i1 %cmp, i1 true, i1 %sel1 + ret i1 %sel2 +}