-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[Reassociate] Conserve nsw/nuw flags when factoring out #125773
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
When transforming, e.g., A*A+A*B*C+D into A*(A+B*C)+D, we can set the factored-out mul to have nuw/nsw iff all of the adds and muls had this flag set.
|
Thank you for submitting a Pull Request (PR) to the LLVM Project! This PR will be automatically labeled and the relevant teams will be notified. If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using `@` followed by their GitHub username. If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers. If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums. |
|
@llvm/pr-subscribers-llvm-transforms Author: None (joe-rivos) Changes: When transforming, e.g., A*A+A*B*C+D into A*(A+B*C)+D, we can set the factored-out mul to have nuw/nsw iff all of the adds and muls had this flag set. Full diff: https://github.com/llvm/llvm-project/pull/125773.diff 3 Files Affected:
diff --git a/llvm/include/llvm/Transforms/Scalar/Reassociate.h b/llvm/include/llvm/Transforms/Scalar/Reassociate.h
index 3b2d2b83ced623..8023385c2ad903 100644
--- a/llvm/include/llvm/Transforms/Scalar/Reassociate.h
+++ b/llvm/include/llvm/Transforms/Scalar/Reassociate.h
@@ -119,9 +119,11 @@ class ReassociatePass : public PassInfoMixin<ReassociatePass> {
SmallVectorImpl<reassociate::ValueEntry> &Ops,
reassociate::OverflowTracking Flags);
Value *OptimizeExpression(BinaryOperator *I,
- SmallVectorImpl<reassociate::ValueEntry> &Ops);
+ SmallVectorImpl<reassociate::ValueEntry> &Ops,
+ reassociate::OverflowTracking Flags);
Value *OptimizeAdd(Instruction *I,
- SmallVectorImpl<reassociate::ValueEntry> &Ops);
+ SmallVectorImpl<reassociate::ValueEntry> &Ops,
+ reassociate::OverflowTracking Flags);
Value *OptimizeXor(Instruction *I,
SmallVectorImpl<reassociate::ValueEntry> &Ops);
bool CombineXorOpnd(BasicBlock::iterator It, reassociate::XorOpnd *Opnd1,
diff --git a/llvm/lib/Transforms/Scalar/Reassociate.cpp b/llvm/lib/Transforms/Scalar/Reassociate.cpp
index 7cb9bace47bf44..ceae4e1957cf27 100644
--- a/llvm/lib/Transforms/Scalar/Reassociate.cpp
+++ b/llvm/lib/Transforms/Scalar/Reassociate.cpp
@@ -1174,16 +1174,21 @@ Value *ReassociatePass::RemoveFactorFromExpression(Value *V, Value *Factor) {
///
/// Ops is the top-level list of add operands we're trying to factor.
static void FindSingleUseMultiplyFactors(Value *V,
- SmallVectorImpl<Value*> &Factors) {
+ SmallVectorImpl<Value *> &Factors,
+ bool &AllNUW, bool &AllNSW) {
BinaryOperator *BO = isReassociableOp(V, Instruction::Mul, Instruction::FMul);
if (!BO) {
Factors.push_back(V);
return;
}
+ if (isa<OverflowingBinaryOperator>(BO)) {
+ AllNUW &= BO->hasNoUnsignedWrap();
+ AllNSW &= BO->hasNoSignedWrap();
+ }
// Otherwise, add the LHS and RHS to the list of factors.
- FindSingleUseMultiplyFactors(BO->getOperand(1), Factors);
- FindSingleUseMultiplyFactors(BO->getOperand(0), Factors);
+ FindSingleUseMultiplyFactors(BO->getOperand(1), Factors, AllNUW, AllNSW);
+ FindSingleUseMultiplyFactors(BO->getOperand(0), Factors, AllNUW, AllNSW);
}
/// Optimize a series of operands to an 'and', 'or', or 'xor' instruction.
@@ -1492,7 +1497,9 @@ Value *ReassociatePass::OptimizeXor(Instruction *I,
/// optimizes based on identities. If it can be reduced to a single Value, it
/// is returned, otherwise the Ops list is mutated as necessary.
Value *ReassociatePass::OptimizeAdd(Instruction *I,
- SmallVectorImpl<ValueEntry> &Ops) {
+ SmallVectorImpl<ValueEntry> &Ops,
+ OverflowTracking Flags) {
+
// Scan the operand lists looking for X and -X pairs. If we find any, we
// can simplify expressions like X+-X == 0 and X+~X ==-1. While we're at it,
// scan for any
@@ -1586,8 +1593,11 @@ Value *ReassociatePass::OptimizeAdd(Instruction *I,
// Keep track of each multiply we see, to avoid triggering on (X*4)+(X*4)
// where they are actually the same multiply.
+ // Also track every use of this factor shares nuw/nsw. This will allow the
+ // use of these flags in the factored value.
unsigned MaxOcc = 0;
Value *MaxOccVal = nullptr;
+ bool MaxOccAllNUW, MaxOccAllNSW = false;
for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
BinaryOperator *BOp =
isReassociableOp(Ops[i].Op, Instruction::Mul, Instruction::FMul);
@@ -1596,7 +1606,9 @@ Value *ReassociatePass::OptimizeAdd(Instruction *I,
// Compute all of the factors of this added value.
SmallVector<Value*, 8> Factors;
- FindSingleUseMultiplyFactors(BOp, Factors);
+ bool AllNUW = Flags.HasNUW;
+ bool AllNSW = Flags.HasNSW;
+ FindSingleUseMultiplyFactors(BOp, Factors, AllNUW, AllNSW);
assert(Factors.size() > 1 && "Bad linearize!");
// Add one to FactorOccurrences for each unique factor in this op.
@@ -1608,7 +1620,16 @@ Value *ReassociatePass::OptimizeAdd(Instruction *I,
unsigned Occ = ++FactorOccurrences[Factor];
if (Occ > MaxOcc) {
MaxOcc = Occ;
- MaxOccVal = Factor;
+ if (MaxOccVal != Factor) {
+ MaxOccVal = Factor;
+ if (Occ == 1) {
+ MaxOccAllNUW = AllNUW;
+ MaxOccAllNSW = AllNSW;
+ } else {
+ MaxOccAllNUW &= AllNUW;
+ MaxOccAllNSW &= AllNSW;
+ }
+ }
}
// If Factor is a negative constant, add the negated value as a factor
@@ -1690,11 +1711,20 @@ Value *ReassociatePass::OptimizeAdd(Instruction *I,
// A*A*B + A*A*C --> A*(A*B+A*C) --> A*(A*(B+C))
assert(NumAddedValues > 1 && "Each occurrence should contribute a value");
(void)NumAddedValues;
- if (Instruction *VI = dyn_cast<Instruction>(V))
+ if (Instruction *VI = dyn_cast<Instruction>(V)) {
+ if (isa<OverflowingBinaryOperator>(VI)) {
+ VI->setHasNoUnsignedWrap(MaxOccAllNUW);
+ VI->setHasNoSignedWrap(MaxOccAllNSW);
+ }
RedoInsts.insert(VI);
+ }
// Create the multiply.
Instruction *V2 = CreateMul(V, MaxOccVal, "reass.mul", I->getIterator(), I);
+ if (isa<OverflowingBinaryOperator>(V2)) {
+ V2->setHasNoUnsignedWrap(MaxOccAllNUW);
+ V2->setHasNoSignedWrap(MaxOccAllNSW);
+ }
// Rerun associate on the multiply in case the inner expression turned into
// a multiply. We want to make sure that we keep things in canonical form.
@@ -1890,7 +1920,8 @@ Value *ReassociatePass::OptimizeMul(BinaryOperator *I,
}
Value *ReassociatePass::OptimizeExpression(BinaryOperator *I,
- SmallVectorImpl<ValueEntry> &Ops) {
+ SmallVectorImpl<ValueEntry> &Ops,
+ OverflowTracking Flags) {
// Now that we have the linearized expression tree, try to optimize it.
// Start by folding any constants that we found.
const DataLayout &DL = I->getDataLayout();
@@ -1944,7 +1975,7 @@ Value *ReassociatePass::OptimizeExpression(BinaryOperator *I,
case Instruction::Add:
case Instruction::FAdd:
- if (Value *Result = OptimizeAdd(I, Ops))
+ if (Value *Result = OptimizeAdd(I, Ops, Flags))
return Result;
break;
@@ -1956,7 +1987,7 @@ Value *ReassociatePass::OptimizeExpression(BinaryOperator *I,
}
if (Ops.size() != NumOps)
- return OptimizeExpression(I, Ops);
+ return OptimizeExpression(I, Ops, Flags);
return nullptr;
}
@@ -2305,7 +2336,7 @@ void ReassociatePass::ReassociateExpression(BinaryOperator *I) {
// Now that we have the expression tree in a convenient
// sorted form, optimize it globally if possible.
- if (Value *V = OptimizeExpression(I, Ops)) {
+ if (Value *V = OptimizeExpression(I, Ops, Flags)) {
if (V == I)
// Self-referential expression in unreachable code.
return;
diff --git a/llvm/test/Transforms/Reassociate/basictest.ll b/llvm/test/Transforms/Reassociate/basictest.ll
index 3f4057dd14e7e1..b5f87a8202185c 100644
--- a/llvm/test/Transforms/Reassociate/basictest.ll
+++ b/llvm/test/Transforms/Reassociate/basictest.ll
@@ -293,3 +293,40 @@ define i32 @test17(i32 %X1, i32 %X2, i32 %X3, i32 %X4) {
ret i32 %E
}
+define i32 @test18(i32 %X1, i32 %X2) {
+; CHECK-LABEL: @test18(
+; CHECK-NEXT: [[REASS_ADD:%.*]] = add nsw i32 [[X2:%.*]], [[X1:%.*]]
+; CHECK-NEXT: [[REASS_MUL:%.*]] = mul nsw i32 [[REASS_ADD]], 47
+; CHECK-NEXT: ret i32 [[REASS_MUL]]
+;
+ %B = mul nsw i32 %X1, 47
+ %C = mul nsw i32 %X2, 47
+ %D = add nsw i32 %B, %C
+ ret i32 %D
+}
+
+define i32 @test19(i32 %X1, i32 %X2) {
+; CHECK-LABEL: @test19(
+; CHECK-NEXT: [[REASS_ADD:%.*]] = add i32 [[X1:%.*]], 67
+; CHECK-NEXT: [[REASS_MUL:%.*]] = mul nsw i32 [[REASS_ADD]], [[X2:%.*]]
+; CHECK-NEXT: ret i32 [[REASS_MUL]]
+;
+ %A = add i32 %X1, 20
+ %B = mul nsw i32 %X2, 47
+ %C = mul nsw i32 %X2, %A
+ %D = add nsw i32 %B, %C
+ ret i32 %D
+}
+
+define i32 @test20(i32 %X1, i32 %X2) {
+; CHECK-LABEL: @test20(
+; CHECK-NEXT: [[REASS_ADD:%.*]] = add i32 [[X1:%.*]], 67
+; CHECK-NEXT: [[REASS_MUL:%.*]] = mul i32 [[REASS_ADD]], [[X2:%.*]]
+; CHECK-NEXT: ret i32 [[REASS_MUL]]
+;
+ %A = add i32 %X1, 20
+ %B = mul nuw i32 %X2, 47
+ %C = mul nsw i32 %X2, %A
+ %D = add nsw i32 %B, %C
+ ret i32 %D
+}
|
|
Tagging some reviewers based on history: @goldsteinn @akshayrdeodhar. Edit: never mind, I got the proof wrong. |
|
The transformation is invalid in some cases. Whoops. Will revise, since it can still be valid in other cases. |
When transforming, e.g., A*A+A*B*C+D into A*(A+B*C)+D, we can set the
factored-out mul to have nuw/nsw iff all of the adds and muls had this
flag set.
Proof: https://alive2.llvm.org/ce/z/BNb5wS