Skip to content

Commit 775ef66

Browse files
committed
[Reassociate] Conserve nsw/nuw flags when factoring out
When transforming, e.g, A*A+A*B*C+D into A*(A+B*C)+D, we can set the factored out mul to have nuw/nsw iff all of the adds and muls had this flag set.
1 parent 71dd530 commit 775ef66

File tree

3 files changed

+49
-16
lines changed

3 files changed

+49
-16
lines changed

llvm/include/llvm/Transforms/Scalar/Reassociate.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,9 +119,11 @@ class ReassociatePass : public PassInfoMixin<ReassociatePass> {
119119
SmallVectorImpl<reassociate::ValueEntry> &Ops,
120120
reassociate::OverflowTracking Flags);
121121
Value *OptimizeExpression(BinaryOperator *I,
122-
SmallVectorImpl<reassociate::ValueEntry> &Ops);
122+
SmallVectorImpl<reassociate::ValueEntry> &Ops,
123+
reassociate::OverflowTracking Flags);
123124
Value *OptimizeAdd(Instruction *I,
124-
SmallVectorImpl<reassociate::ValueEntry> &Ops);
125+
SmallVectorImpl<reassociate::ValueEntry> &Ops,
126+
reassociate::OverflowTracking Flags);
125127
Value *OptimizeXor(Instruction *I,
126128
SmallVectorImpl<reassociate::ValueEntry> &Ops);
127129
bool CombineXorOpnd(BasicBlock::iterator It, reassociate::XorOpnd *Opnd1,

llvm/lib/Transforms/Scalar/Reassociate.cpp

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1174,16 +1174,21 @@ Value *ReassociatePass::RemoveFactorFromExpression(Value *V, Value *Factor) {
11741174
///
11751175
/// Ops is the top-level list of add operands we're trying to factor.
11761176
static void FindSingleUseMultiplyFactors(Value *V,
1177-
SmallVectorImpl<Value*> &Factors) {
1177+
SmallVectorImpl<Value *> &Factors,
1178+
bool &AllNUW, bool &AllNSW) {
11781179
BinaryOperator *BO = isReassociableOp(V, Instruction::Mul, Instruction::FMul);
11791180
if (!BO) {
11801181
Factors.push_back(V);
11811182
return;
11821183
}
11831184

1185+
if (isa<OverflowingBinaryOperator>(BO)) {
1186+
AllNUW &= BO->hasNoUnsignedWrap();
1187+
AllNSW &= BO->hasNoSignedWrap();
1188+
}
11841189
// Otherwise, add the LHS and RHS to the list of factors.
1185-
FindSingleUseMultiplyFactors(BO->getOperand(1), Factors);
1186-
FindSingleUseMultiplyFactors(BO->getOperand(0), Factors);
1190+
FindSingleUseMultiplyFactors(BO->getOperand(1), Factors, AllNUW, AllNSW);
1191+
FindSingleUseMultiplyFactors(BO->getOperand(0), Factors, AllNUW, AllNSW);
11871192
}
11881193

11891194
/// Optimize a series of operands to an 'and', 'or', or 'xor' instruction.
@@ -1492,7 +1497,9 @@ Value *ReassociatePass::OptimizeXor(Instruction *I,
14921497
/// optimizes based on identities. If it can be reduced to a single Value, it
14931498
/// is returned, otherwise the Ops list is mutated as necessary.
14941499
Value *ReassociatePass::OptimizeAdd(Instruction *I,
1495-
SmallVectorImpl<ValueEntry> &Ops) {
1500+
SmallVectorImpl<ValueEntry> &Ops,
1501+
OverflowTracking Flags) {
1502+
14961503
// Scan the operand lists looking for X and -X pairs. If we find any, we
14971504
// can simplify expressions like X+-X == 0 and X+~X ==-1. While we're at it,
14981505
// scan for any
@@ -1586,8 +1593,11 @@ Value *ReassociatePass::OptimizeAdd(Instruction *I,
15861593

15871594
// Keep track of each multiply we see, to avoid triggering on (X*4)+(X*4)
15881595
// where they are actually the same multiply.
1596+
// Also track every use of this factor shares nuw/nsw. This will allow the
1597+
// use of these flags in the factored value.
15891598
unsigned MaxOcc = 0;
15901599
Value *MaxOccVal = nullptr;
1600+
bool MaxOccAllNUW, MaxOccAllNSW = false;
15911601
for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
15921602
BinaryOperator *BOp =
15931603
isReassociableOp(Ops[i].Op, Instruction::Mul, Instruction::FMul);
@@ -1596,7 +1606,9 @@ Value *ReassociatePass::OptimizeAdd(Instruction *I,
15961606

15971607
// Compute all of the factors of this added value.
15981608
SmallVector<Value*, 8> Factors;
1599-
FindSingleUseMultiplyFactors(BOp, Factors);
1609+
bool AllNUW = Flags.HasNUW;
1610+
bool AllNSW = Flags.HasNSW;
1611+
FindSingleUseMultiplyFactors(BOp, Factors, AllNUW, AllNSW);
16001612
assert(Factors.size() > 1 && "Bad linearize!");
16011613

16021614
// Add one to FactorOccurrences for each unique factor in this op.
@@ -1608,7 +1620,16 @@ Value *ReassociatePass::OptimizeAdd(Instruction *I,
16081620
unsigned Occ = ++FactorOccurrences[Factor];
16091621
if (Occ > MaxOcc) {
16101622
MaxOcc = Occ;
1611-
MaxOccVal = Factor;
1623+
if (MaxOccVal != Factor) {
1624+
MaxOccVal = Factor;
1625+
if (Occ == 1) {
1626+
MaxOccAllNUW = AllNUW;
1627+
MaxOccAllNSW = AllNSW;
1628+
} else {
1629+
MaxOccAllNUW &= AllNUW;
1630+
MaxOccAllNSW &= AllNSW;
1631+
}
1632+
}
16121633
}
16131634

16141635
// If Factor is a negative constant, add the negated value as a factor
@@ -1690,11 +1711,20 @@ Value *ReassociatePass::OptimizeAdd(Instruction *I,
16901711
// A*A*B + A*A*C --> A*(A*B+A*C) --> A*(A*(B+C))
16911712
assert(NumAddedValues > 1 && "Each occurrence should contribute a value");
16921713
(void)NumAddedValues;
1693-
if (Instruction *VI = dyn_cast<Instruction>(V))
1714+
if (Instruction *VI = dyn_cast<Instruction>(V)) {
1715+
if (isa<OverflowingBinaryOperator>(VI)) {
1716+
VI->setHasNoUnsignedWrap(MaxOccAllNUW);
1717+
VI->setHasNoSignedWrap(MaxOccAllNSW);
1718+
}
16941719
RedoInsts.insert(VI);
1720+
}
16951721

16961722
// Create the multiply.
16971723
Instruction *V2 = CreateMul(V, MaxOccVal, "reass.mul", I->getIterator(), I);
1724+
if (isa<OverflowingBinaryOperator>(V2)) {
1725+
V2->setHasNoUnsignedWrap(MaxOccAllNUW);
1726+
V2->setHasNoSignedWrap(MaxOccAllNSW);
1727+
}
16981728

16991729
// Rerun associate on the multiply in case the inner expression turned into
17001730
// a multiply. We want to make sure that we keep things in canonical form.
@@ -1890,7 +1920,8 @@ Value *ReassociatePass::OptimizeMul(BinaryOperator *I,
18901920
}
18911921

18921922
Value *ReassociatePass::OptimizeExpression(BinaryOperator *I,
1893-
SmallVectorImpl<ValueEntry> &Ops) {
1923+
SmallVectorImpl<ValueEntry> &Ops,
1924+
OverflowTracking Flags) {
18941925
// Now that we have the linearized expression tree, try to optimize it.
18951926
// Start by folding any constants that we found.
18961927
const DataLayout &DL = I->getDataLayout();
@@ -1944,7 +1975,7 @@ Value *ReassociatePass::OptimizeExpression(BinaryOperator *I,
19441975

19451976
case Instruction::Add:
19461977
case Instruction::FAdd:
1947-
if (Value *Result = OptimizeAdd(I, Ops))
1978+
if (Value *Result = OptimizeAdd(I, Ops, Flags))
19481979
return Result;
19491980
break;
19501981

@@ -1956,7 +1987,7 @@ Value *ReassociatePass::OptimizeExpression(BinaryOperator *I,
19561987
}
19571988

19581989
if (Ops.size() != NumOps)
1959-
return OptimizeExpression(I, Ops);
1990+
return OptimizeExpression(I, Ops, Flags);
19601991
return nullptr;
19611992
}
19621993

@@ -2305,7 +2336,7 @@ void ReassociatePass::ReassociateExpression(BinaryOperator *I) {
23052336

23062337
// Now that we have the expression tree in a convenient
23072338
// sorted form, optimize it globally if possible.
2308-
if (Value *V = OptimizeExpression(I, Ops)) {
2339+
if (Value *V = OptimizeExpression(I, Ops, Flags)) {
23092340
if (V == I)
23102341
// Self-referential expression in unreachable code.
23112342
return;

llvm/test/Transforms/Reassociate/basictest.ll

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -295,8 +295,8 @@ define i32 @test17(i32 %X1, i32 %X2, i32 %X3, i32 %X4) {
295295

296296
define i32 @test18(i32 %X1, i32 %X2) {
297297
; CHECK-LABEL: @test18(
298-
; CHECK-NEXT: [[REASS_ADD:%.*]] = add i32 [[X2:%.*]], [[X1:%.*]]
299-
; CHECK-NEXT: [[REASS_MUL:%.*]] = mul i32 [[REASS_ADD]], 47
298+
; CHECK-NEXT: [[REASS_ADD:%.*]] = add nsw i32 [[X2:%.*]], [[X1:%.*]]
299+
; CHECK-NEXT: [[REASS_MUL:%.*]] = mul nsw i32 [[REASS_ADD]], 47
300300
; CHECK-NEXT: ret i32 [[REASS_MUL]]
301301
;
302302
%B = mul nsw i32 %X1, 47
@@ -308,7 +308,7 @@ define i32 @test18(i32 %X1, i32 %X2) {
308308
define i32 @test19(i32 %X1, i32 %X2) {
309309
; CHECK-LABEL: @test19(
310310
; CHECK-NEXT: [[REASS_ADD:%.*]] = add i32 [[X1:%.*]], 67
311-
; CHECK-NEXT: [[REASS_MUL:%.*]] = mul i32 [[REASS_ADD]], [[X2:%.*]]
311+
; CHECK-NEXT: [[REASS_MUL:%.*]] = mul nsw i32 [[REASS_ADD]], [[X2:%.*]]
312312
; CHECK-NEXT: ret i32 [[REASS_MUL]]
313313
;
314314
%A = add i32 %X1, 20

0 commit comments

Comments
 (0)