Skip to content

Commit c3ca5fc

Browse files
committed
[SCEV] Preserve divisibility info when creating UMax/SMax expressions.
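Currently we generate (S|U)Max(1, Op) for Op >= 1, which may discard divisibility information about Op. This patch rewrites such SMax/UMax expressions so that, instead of the constant 1, they use the largest power of two known to divide all non-constant operands (computed from the minimum number of known trailing zero bits across those operands).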
1 parent 857961d commit c3ca5fc

File tree

2 files changed

+24
-4
lines changed

2 files changed

+24
-4
lines changed

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15869,12 +15869,17 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1586915869
To = SE.getUMaxExpr(FromRewritten, RHS);
1587015870
if (auto *UMin = dyn_cast<SCEVUMinExpr>(FromRewritten))
1587115871
EnqueueOperands(UMin);
15872+
if (RHS->isOne())
15873+
ExprsToRewrite.push_back(From);
1587215874
break;
1587315875
case CmpInst::ICMP_SGT:
1587415876
case CmpInst::ICMP_SGE:
1587515877
To = SE.getSMaxExpr(FromRewritten, RHS);
15876-
if (auto *SMin = dyn_cast<SCEVSMinExpr>(FromRewritten))
15878+
if (auto *SMin = dyn_cast<SCEVSMinExpr>(FromRewritten)) {
1587715879
EnqueueOperands(SMin);
15880+
}
15881+
if (RHS->isOne())
15882+
ExprsToRewrite.push_back(From);
1587815883
break;
1587915884
case CmpInst::ICMP_EQ:
1588015885
if (isa<SCEVConstant>(RHS))
@@ -16005,7 +16010,22 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1600516010
for (const SCEV *Expr : ExprsToRewrite) {
1600616011
const SCEV *RewriteTo = Guards.RewriteMap[Expr];
1600716012
Guards.RewriteMap.erase(Expr);
16008-
Guards.RewriteMap.insert({Expr, Guards.rewrite(RewriteTo)});
16013+
const SCEV *Rewritten = Guards.rewrite(RewriteTo);
16014+
16015+
// Try to strengthen divisibility of SMax/UMax expressions coming from >=
16016+
// 1 conditions.
16017+
if (auto *SMax = dyn_cast<SCEVSMaxExpr>(Rewritten)) {
16018+
unsigned MinTrailingZeros = SE.getMinTrailingZeros(SMax->getOperand(1));
16019+
for (const SCEV *Op : drop_begin(SMax->operands(), 2))
16020+
MinTrailingZeros =
16021+
std::min(MinTrailingZeros, SE.getMinTrailingZeros(Op));
16022+
if (MinTrailingZeros != 0)
16023+
Rewritten = SE.getSMaxExpr(
16024+
SE.getConstant(APInt(SMax->getType()->getScalarSizeInBits(), 1)
16025+
.shl(MinTrailingZeros)),
16026+
SMax);
16027+
}
16028+
Guards.RewriteMap.insert({Expr, Rewritten});
1600916029
}
1601016030
}
1601116031
}

llvm/test/Analysis/ScalarEvolution/trip-count-minmax.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ define void @umin(i32 noundef %a, i32 noundef %b) {
6161
; CHECK-NEXT: Loop %for.body: backedge-taken count is (-1 + ((2 * %a) umin (4 * %b)))
6262
; CHECK-NEXT: Loop %for.body: constant max backedge-taken count is i32 2147483646
6363
; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + ((2 * %a) umin (4 * %b)))
64-
; CHECK-NEXT: Loop %for.body: Trip multiple is 1
64+
; CHECK-NEXT: Loop %for.body: Trip multiple is 2
6565
;
6666
; void umin(unsigned a, unsigned b) {
6767
; a *= 2;
@@ -157,7 +157,7 @@ define void @smin(i32 noundef %a, i32 noundef %b) {
157157
; CHECK-NEXT: Loop %for.body: backedge-taken count is (-1 + ((2 * %a)<nsw> smin (4 * %b)<nsw>))
158158
; CHECK-NEXT: Loop %for.body: constant max backedge-taken count is i32 2147483646
159159
; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + ((2 * %a)<nsw> smin (4 * %b)<nsw>))
160-
; CHECK-NEXT: Loop %for.body: Trip multiple is 1
160+
; CHECK-NEXT: Loop %for.body: Trip multiple is 2
161161
;
162162
; void smin(signed a, signed b) {
163163
; a *= 2;

0 commit comments

Comments
 (0)