diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h index f927838c843ac..faaa6d319f437 100644 --- a/llvm/include/llvm/Analysis/ValueTracking.h +++ b/llvm/include/llvm/Analysis/ValueTracking.h @@ -719,6 +719,88 @@ bool isGuaranteedToExecuteForEveryIteration(const Instruction *I, /// getGuaranteedNonPoisonOp. bool propagatesPoison(const Use &PoisonOp); +/// Enumerates all operands of \p I that are guaranteed to not be undef or +/// poison. If the callback \p Handle returns true, stop processing and return +/// true. Otherwise, return false. +template +bool handleGuaranteedWellDefinedOps(const Instruction *I, + const CallableT &Handle) { + switch (I->getOpcode()) { + case Instruction::Store: + if (Handle(cast(I)->getPointerOperand())) + return true; + break; + + case Instruction::Load: + if (Handle(cast(I)->getPointerOperand())) + return true; + break; + + // Since dereferenceable attribute imply noundef, atomic operations + // also implicitly have noundef pointers too + case Instruction::AtomicCmpXchg: + if (Handle(cast(I)->getPointerOperand())) + return true; + break; + + case Instruction::AtomicRMW: + if (Handle(cast(I)->getPointerOperand())) + return true; + break; + + case Instruction::Call: + case Instruction::Invoke: { + const CallBase *CB = cast(I); + if (CB->isIndirectCall() && Handle(CB->getCalledOperand())) + return true; + for (unsigned i = 0; i < CB->arg_size(); ++i) + if ((CB->paramHasAttr(i, Attribute::NoUndef) || + CB->paramHasAttr(i, Attribute::Dereferenceable) || + CB->paramHasAttr(i, Attribute::DereferenceableOrNull)) && + Handle(CB->getArgOperand(i))) + return true; + break; + } + case Instruction::Ret: + if (I->getFunction()->hasRetAttribute(Attribute::NoUndef) && + Handle(I->getOperand(0))) + return true; + break; + case Instruction::Switch: + if (Handle(cast(I)->getCondition())) + return true; + break; + case Instruction::Br: { + auto *BR = cast(I); + if (BR->isConditional() && Handle(BR->getCondition())) + return true; + break; + } + default: + break; + } + + return false; +} + +/// Enumerates all operands of \p I that are guaranteed to not be poison. +template +bool handleGuaranteedNonPoisonOps(const Instruction *I, + const CallableT &Handle) { + if (handleGuaranteedWellDefinedOps(I, Handle)) + return true; + switch (I->getOpcode()) { + // Divisors of these operations are allowed to be partially undef. + case Instruction::UDiv: + case Instruction::SDiv: + case Instruction::URem: + case Instruction::SRem: + return Handle(I->getOperand(1)); + default: + return false; + } +} + /// Return true if the given instruction must trigger undefined behavior /// when I is executed with any operands which appear in KnownPoison holding /// a poison value at the point of execution. diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp index cdf7f052943c8..b3d4164984b0d 100644 --- a/llvm/lib/Analysis/ValueTracking.cpp +++ b/llvm/lib/Analysis/ValueTracking.cpp @@ -8213,88 +8213,6 @@ bool llvm::propagatesPoison(const Use &PoisonOp) { } } -/// Enumerates all operands of \p I that are guaranteed to not be undef or -/// poison. If the callback \p Handle returns true, stop processing and return -/// true. Otherwise, return false. -template -static bool handleGuaranteedWellDefinedOps(const Instruction *I, - const CallableT &Handle) { - switch (I->getOpcode()) { - case Instruction::Store: - if (Handle(cast(I)->getPointerOperand())) - return true; - break; - - case Instruction::Load: - if (Handle(cast(I)->getPointerOperand())) - return true; - break; - - // Since dereferenceable attribute imply noundef, atomic operations - // also implicitly have noundef pointers too - case Instruction::AtomicCmpXchg: - if (Handle(cast(I)->getPointerOperand())) - return true; - break; - - case Instruction::AtomicRMW: - if (Handle(cast(I)->getPointerOperand())) - return true; - break; - - case Instruction::Call: - case Instruction::Invoke: { - const CallBase *CB = cast(I); - if (CB->isIndirectCall() && Handle(CB->getCalledOperand())) - return true; - for (unsigned i = 0; i < CB->arg_size(); ++i) - if ((CB->paramHasAttr(i, Attribute::NoUndef) || - CB->paramHasAttr(i, Attribute::Dereferenceable) || - CB->paramHasAttr(i, Attribute::DereferenceableOrNull)) && - Handle(CB->getArgOperand(i))) - return true; - break; - } - case Instruction::Ret: - if (I->getFunction()->hasRetAttribute(Attribute::NoUndef) && - Handle(I->getOperand(0))) - return true; - break; - case Instruction::Switch: - if (Handle(cast(I)->getCondition())) - return true; - break; - case Instruction::Br: { - auto *BR = cast(I); - if (BR->isConditional() && Handle(BR->getCondition())) - return true; - break; - } - default: - break; - } - - return false; -} - -/// Enumerates all operands of \p I that are guaranteed to not be poison. -template -static bool handleGuaranteedNonPoisonOps(const Instruction *I, - const CallableT &Handle) { - if (handleGuaranteedWellDefinedOps(I, Handle)) - return true; - switch (I->getOpcode()) { - // Divisors of these operations are allowed to be partially undef. - case Instruction::UDiv: - case Instruction::SDiv: - case Instruction::URem: - case Instruction::SRem: - return Handle(I->getOperand(1)); - default: - return false; - } -} - bool llvm::mustTriggerUB(const Instruction *I, const SmallPtrSetImpl &KnownPoison) { return handleGuaranteedNonPoisonOps( diff --git a/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp b/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp index 58d705eb6aa96..f60df1d836c39 100644 --- a/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp +++ b/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp @@ -1083,11 +1083,52 @@ void State::addInfoForInductions(BasicBlock &BB) { } } +static void addNonPoisonValueFactRecursive( + Value *V, SmallPtrSet &Visited, + function_ref AddFact) { + auto *I = dyn_cast(V); + if (!I) + return; + + if (Visited.contains(I)) + return; + Visited.insert(I); + + if (auto *II = dyn_cast(V)) { + Intrinsic::ID IID = II->getIntrinsicID(); + switch (IID) { + case Intrinsic::umin: + case Intrinsic::umax: + case Intrinsic::smin: + case Intrinsic::smax: { + ICmpInst::Predicate Pred = + ICmpInst::getNonStrictPredicate(MinMaxIntrinsic::getPredicate(IID)); + AddFact(Pred, II, II->getArgOperand(0)); + AddFact(Pred, II, II->getArgOperand(1)); + break; + } + case Intrinsic::abs: { + if (cast(II->getArgOperand(1))->isOne()) + AddFact(CmpInst::ICMP_SGE, II, ConstantInt::get(II->getType(), 0)); + AddFact(CmpInst::ICMP_SGE, II, II->getArgOperand(0)); + break; + } + default: + break; + } + } + + for (auto &Op : I->operands()) + if (isa(Op) && propagatesPoison(Op)) + addNonPoisonValueFactRecursive(Op.get(), Visited, AddFact); +} + void State::addInfoFor(BasicBlock &BB) { addInfoForInductions(BB); // True as long as long as the current instruction is guaranteed to execute. bool GuaranteedToExecute = true; + SmallPtrSet Visited; // Queue conditions and assumes. for (Instruction &I : BB) { if (auto *Cmp = dyn_cast(&I)) { @@ -1120,14 +1161,10 @@ void State::addInfoFor(BasicBlock &BB) { } break; } - // Enqueue ssub_with_overflow for simplification. + // Enqueue intrinsic for simplification. case Intrinsic::ssub_with_overflow: case Intrinsic::ucmp: case Intrinsic::scmp: - WorkList.push_back( - FactOrCheck::getCheck(DT.getNode(&BB), cast(&I))); - break; - // Enqueue the intrinsics to add extra info. case Intrinsic::umin: case Intrinsic::umax: case Intrinsic::smin: @@ -1135,16 +1172,22 @@ void State::addInfoFor(BasicBlock &BB) { // TODO: handle llvm.abs as well WorkList.push_back( FactOrCheck::getCheck(DT.getNode(&BB), cast(&I))); - // TODO: Check if it is possible to instead only added the min/max facts - // when simplifying uses of the min/max intrinsics. - if (!isGuaranteedNotToBePoison(&I)) - break; - [[fallthrough]]; - case Intrinsic::abs: - WorkList.push_back(FactOrCheck::getInstFact(DT.getNode(&BB), &I)); break; } + if (GuaranteedToExecute) { + auto AddFact = [&](CmpPredicate Pred, Value *A, Value *B) { + WorkList.emplace_back( + FactOrCheck::getConditionFact(DT.getNode(&BB), Pred, A, B)); + }; + + handleGuaranteedWellDefinedOps(&I, [&](const Value *Op) { + addNonPoisonValueFactRecursive(const_cast(Op), Visited, + AddFact); + return false; + }); + } + GuaranteedToExecute &= isGuaranteedToTransferExecutionToSuccessor(&I); } @@ -1385,10 +1428,39 @@ static void generateReproducer(CmpInst *Cond, Module *M, assert(!verifyFunction(*F, &dbgs())); } +static void +removeEntryFromStack(const StackEntry &E, ConstraintInfo &Info, + SmallVectorImpl &DFSInStack, + SmallVectorImpl *ReproducerCondStack) { + Info.popLastConstraint(E.IsSigned); + // Remove variables in the system that went out of scope. + auto &Mapping = Info.getValue2Index(E.IsSigned); + for (Value *V : E.ValuesToRelease) + Mapping.erase(V); + Info.popLastNVariables(E.IsSigned, E.ValuesToRelease.size()); + DFSInStack.pop_back(); + if (ReproducerCondStack) + ReproducerCondStack->pop_back(); +} + static std::optional checkCondition(CmpInst::Predicate Pred, Value *A, Value *B, Instruction *CheckInst, ConstraintInfo &Info) { LLVM_DEBUG(dbgs() << "Checking " << *CheckInst << "\n"); + SmallVector DFSInStack; + SmallPtrSet Visited; + auto StackRestorer = make_scope_exit([&]() { + while (!DFSInStack.empty()) + removeEntryFromStack(DFSInStack.back(), Info, DFSInStack, nullptr); + }); + auto AddFact = [&](CmpPredicate Pred, Value *A, Value *B) { + Info.addFact(Pred, A, B, 0, 0, DFSInStack); + if (ICmpInst::isRelational(Pred)) + Info.transferToOtherSystem(Pred, A, B, 0, 0, DFSInStack); + }; + + addNonPoisonValueFactRecursive(A, Visited, AddFact); + addNonPoisonValueFactRecursive(B, Visited, AddFact); auto R = Info.getConstraintForSolving(Pred, A, B); if (R.empty() || !R.isValid(Info)){ @@ -1517,22 +1589,6 @@ static bool checkAndReplaceCmp(CmpIntrinsic *I, ConstraintInfo &Info, return false; } -static void -removeEntryFromStack(const StackEntry &E, ConstraintInfo &Info, - Module *ReproducerModule, - SmallVectorImpl &ReproducerCondStack, - SmallVectorImpl &DFSInStack) { - Info.popLastConstraint(E.IsSigned); - // Remove variables in the system that went out of scope. - auto &Mapping = Info.getValue2Index(E.IsSigned); - for (Value *V : E.ValuesToRelease) - Mapping.erase(V); - Info.popLastNVariables(E.IsSigned, E.ValuesToRelease.size()); - DFSInStack.pop_back(); - if (ReproducerModule) - ReproducerCondStack.pop_back(); -} - /// Check if either the first condition of an AND or OR is implied by the /// (negated in case of OR) second condition or vice versa. static bool checkOrAndOpImpliedByOther( @@ -1554,12 +1610,13 @@ static bool checkOrAndOpImpliedByOther( // Remove entries again. while (OldSize < DFSInStack.size()) { StackEntry E = DFSInStack.back(); - removeEntryFromStack(E, Info, ReproducerModule, ReproducerCondStack, - DFSInStack); + removeEntryFromStack(E, Info, DFSInStack, + ReproducerModule ? &ReproducerCondStack : nullptr); } }); bool IsOr = match(JoinOp, m_LogicalOr()); SmallVector Worklist({JoinOp->getOperand(OtherOpIdx)}); + SmallPtrSet Visited; // Do a traversal of the AND/OR tree to add facts from leaf compares. while (!Worklist.empty()) { Value *Val = Worklist.pop_back_val(); @@ -1571,6 +1628,15 @@ static bool checkOrAndOpImpliedByOther( Pred = CmpInst::getInversePredicate(Pred); // Optimistically add fact from the other compares in the AND/OR. Info.addFact(Pred, LHS, RHS, CB.NumIn, CB.NumOut, DFSInStack); + auto AddFact = [&](CmpPredicate Pred, Value *A, Value *B) { + Info.addFact(Pred, A, B, CB.NumIn, CB.NumOut, DFSInStack); + if (ICmpInst::isRelational(Pred)) + Info.transferToOtherSystem(Pred, A, B, CB.NumIn, CB.NumOut, + DFSInStack); + }; + + addNonPoisonValueFactRecursive(LHS, Visited, AddFact); + addNonPoisonValueFactRecursive(RHS, Visited, AddFact); continue; } if (IsOr ? match(Val, m_LogicalOr(m_Value(LHS), m_Value(RHS))) @@ -1807,8 +1873,8 @@ static bool eliminateConstraints(Function &F, DominatorTree &DT, LoopInfo &LI, Info.getValue2Index(E.IsSigned)); dbgs() << "\n"; }); - removeEntryFromStack(E, Info, ReproducerModule.get(), ReproducerCondStack, - DFSInStack); + removeEntryFromStack(E, Info, DFSInStack, + ReproducerModule ? &ReproducerCondStack : nullptr); } // For a block, check if any CmpInsts become known based on the current set @@ -1879,25 +1945,6 @@ static bool eliminateConstraints(Function &F, DominatorTree &DT, LoopInfo &LI, }; CmpPredicate Pred; - if (!CB.isConditionFact()) { - Value *X; - if (match(CB.Inst, m_Intrinsic(m_Value(X)))) { - // If is_int_min_poison is true then we may assume llvm.abs >= 0. - if (cast(CB.Inst->getOperand(1))->isOne()) - AddFact(CmpInst::ICMP_SGE, CB.Inst, - ConstantInt::get(CB.Inst->getType(), 0)); - AddFact(CmpInst::ICMP_SGE, CB.Inst, X); - continue; - } - - if (auto *MinMax = dyn_cast(CB.Inst)) { - Pred = ICmpInst::getNonStrictPredicate(MinMax->getPredicate()); - AddFact(Pred, MinMax, MinMax->getLHS()); - AddFact(Pred, MinMax, MinMax->getRHS()); - continue; - } - } - Value *A = nullptr, *B = nullptr; if (CB.isConditionFact()) { Pred = CB.Cond.Pred; @@ -1922,6 +1969,10 @@ static bool eliminateConstraints(Function &F, DominatorTree &DT, LoopInfo &LI, assert(Matched && "Must have an assume intrinsic with a icmp operand"); } AddFact(Pred, A, B); + // Now both A and B is guaranteed not to be poison. + SmallPtrSet Visited; + addNonPoisonValueFactRecursive(A, Visited, AddFact); + addNonPoisonValueFactRecursive(B, Visited, AddFact); } if (ReproducerModule && !ReproducerModule->functions().empty()) { diff --git a/llvm/test/Transforms/ConstraintElimination/minmax.ll b/llvm/test/Transforms/ConstraintElimination/minmax.ll index 029b6508a2106..a079e092cb72a 100644 --- a/llvm/test/Transforms/ConstraintElimination/minmax.ll +++ b/llvm/test/Transforms/ConstraintElimination/minmax.ll @@ -306,9 +306,7 @@ define i1 @smin_branchless(i32 %x, i32 %y) { ; CHECK-SAME: (i32 [[X:%.*]], i32 [[Y:%.*]]) { ; CHECK-NEXT: entry: ; CHECK-NEXT: [[MIN:%.*]] = call i32 @llvm.smin.i32(i32 [[X]], i32 [[Y]]) -; CHECK-NEXT: [[CMP1:%.*]] = icmp sle i32 [[MIN]], [[X]] -; CHECK-NEXT: [[CMP2:%.*]] = icmp sgt i32 [[MIN]], [[X]] -; CHECK-NEXT: [[RET:%.*]] = xor i1 [[CMP1]], [[CMP2]] +; CHECK-NEXT: [[RET:%.*]] = xor i1 true, false ; CHECK-NEXT: ret i1 [[RET]] ; entry: