diff --git a/llvm/include/llvm/Transforms/Scalar/Reassociate.h b/llvm/include/llvm/Transforms/Scalar/Reassociate.h index 23b70164d96a4..a5d137661e11e 100644 --- a/llvm/include/llvm/Transforms/Scalar/Reassociate.h +++ b/llvm/include/llvm/Transforms/Scalar/Reassociate.h @@ -39,6 +39,7 @@ class Function; class Instruction; class IRBuilderBase; class Value; +struct OverflowTracking; /// A private "module" namespace for types and utilities used by Reassociate. /// These are implementation details and should not be used by clients. @@ -64,17 +65,6 @@ struct Factor { Factor(Value *Base, unsigned Power) : Base(Base), Power(Power) {} }; -struct OverflowTracking { - bool HasNUW = true; - bool HasNSW = true; - bool AllKnownNonNegative = true; - bool AllKnownNonZero = true; - // Note: AllKnownNonNegative can be true in a case where one of the operands - // is negative, but one the operators is not NSW. AllKnownNonNegative should - // not be used independently of HasNSW - OverflowTracking() = default; -}; - class XorOpnd; } // end namespace reassociate @@ -115,7 +105,7 @@ class ReassociatePass : public PassInfoMixin { void ReassociateExpression(BinaryOperator *I); void RewriteExprTree(BinaryOperator *I, SmallVectorImpl &Ops, - reassociate::OverflowTracking Flags); + OverflowTracking Flags); Value *OptimizeExpression(BinaryOperator *I, SmallVectorImpl &Ops); Value *OptimizeAdd(Instruction *I, diff --git a/llvm/include/llvm/Transforms/Utils/Local.h b/llvm/include/llvm/Transforms/Utils/Local.h index db064e1f41f02..fa26446f9492a 100644 --- a/llvm/include/llvm/Transforms/Utils/Local.h +++ b/llvm/include/llvm/Transforms/Utils/Local.h @@ -556,6 +556,30 @@ Value *invertCondition(Value *Condition); /// function, explicitly materialize the maximal set in the IR. bool inferAttributesFromOthers(Function &F); +//===----------------------------------------------------------------------===// +// Helpers to track and update flags on instructions. +// + +struct OverflowTracking { + bool HasNUW = true; + bool HasNSW = true; + + // Note: At the moment, users are responsible to manage AllKnownNonNegative + // and AllKnownNonZero manually. AllKnownNonNegative can be true in a case + // where one of the operands is negative, but one the operators is not NSW. + // AllKnownNonNegative should not be used independently of HasNSW + bool AllKnownNonNegative = true; + bool AllKnownNonZero = true; + + OverflowTracking() = default; + + /// Merge in the no-wrap flags from \p I. + void mergeFlags(Instruction &I); + + /// Apply the no-wrap flags to \p I if applicable. + void applyFlags(Instruction &I); +}; + } // end namespace llvm #endif // LLVM_TRANSFORMS_UTILS_LOCAL_H diff --git a/llvm/lib/Transforms/Scalar/Reassociate.cpp b/llvm/lib/Transforms/Scalar/Reassociate.cpp index cb7a9ef9b6711..778a6a012556b 100644 --- a/llvm/lib/Transforms/Scalar/Reassociate.cpp +++ b/llvm/lib/Transforms/Scalar/Reassociate.cpp @@ -382,7 +382,7 @@ using RepeatedValue = std::pair; static bool LinearizeExprTree(Instruction *I, SmallVectorImpl &Ops, ReassociatePass::OrderedSet &ToRedo, - reassociate::OverflowTracking &Flags) { + OverflowTracking &Flags) { assert((isa(I) || isa(I)) && "Expected a UnaryOperator or BinaryOperator!"); LLVM_DEBUG(dbgs() << "LINEARIZE: " << *I << '\n'); @@ -431,10 +431,7 @@ static bool LinearizeExprTree(Instruction *I, // We examine the operands of this binary operator. auto [I, Weight] = Worklist.pop_back_val(); - if (isa(I)) { - Flags.HasNUW &= I->hasNoUnsignedWrap(); - Flags.HasNSW &= I->hasNoSignedWrap(); - } + Flags.mergeFlags(*I); for (unsigned OpIdx = 0; OpIdx < I->getNumOperands(); ++OpIdx) { // Visit operands. Value *Op = I->getOperand(OpIdx); @@ -734,15 +731,7 @@ void ReassociatePass::RewriteExprTree(BinaryOperator *I, ExpressionChangedStart->clearSubclassOptionalData(); ExpressionChangedStart->setFastMathFlags(Flags); } else { - ExpressionChangedStart->clearSubclassOptionalData(); - if (ExpressionChangedStart->getOpcode() == Instruction::Add || - (ExpressionChangedStart->getOpcode() == Instruction::Mul && - Flags.AllKnownNonZero)) { - if (Flags.HasNUW) - ExpressionChangedStart->setHasNoUnsignedWrap(); - if (Flags.HasNSW && (Flags.AllKnownNonNegative || Flags.HasNUW)) - ExpressionChangedStart->setHasNoSignedWrap(); - } + Flags.applyFlags(*ExpressionChangedStart); } } diff --git a/llvm/lib/Transforms/Utils/Local.cpp b/llvm/lib/Transforms/Utils/Local.cpp index 3dbd605e19c3a..4d168ce7cf591 100644 --- a/llvm/lib/Transforms/Utils/Local.cpp +++ b/llvm/lib/Transforms/Utils/Local.cpp @@ -4362,3 +4362,21 @@ bool llvm::inferAttributesFromOthers(Function &F) { return Changed; } + +void OverflowTracking::mergeFlags(Instruction &I) { + if (isa(&I)) { + HasNUW &= I.hasNoUnsignedWrap(); + HasNSW &= I.hasNoSignedWrap(); + } +} + +void OverflowTracking::applyFlags(Instruction &I) { + I.clearSubclassOptionalData(); + if (I.getOpcode() == Instruction::Add || + (I.getOpcode() == Instruction::Mul && AllKnownNonZero)) { + if (HasNUW) + I.setHasNoUnsignedWrap(); + if (HasNSW && (AllKnownNonNegative || HasNUW)) + I.setHasNoSignedWrap(); + } +}