Skip to content

Commit 1987b50

Browse files
committed
[InstCombine] Canonicalize signed saturated additions with positive numbers only
https://alive2.llvm.org/ce/z/YGT5SN This is tricky: when the added value is positive, the sum can only move upward, so the only boundary it can ever hit is the signed maximum. That matters because the intrinsic we use also saturates in the other direction, i.e. it clamps to INT_MIN when the sum overflows downward, and the range check we perform is only valid for positive numbers. Because of this, we can also only do this for constants.
1 parent f0f8e12 commit 1987b50

File tree

3 files changed

+137
-21
lines changed

3 files changed

+137
-21
lines changed

llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp

Lines changed: 105 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1027,10 +1027,9 @@ static Value *canonicalizeSaturatedSubtract(const ICmpInst *ICI,
10271027
return Result;
10281028
}
10291029

1030-
static Value *canonicalizeSaturatedAdd(ICmpInst *Cmp, Value *TVal, Value *FVal,
1031-
InstCombiner::BuilderTy &Builder) {
1032-
if (!Cmp->hasOneUse())
1033-
return nullptr;
1030+
static Value *
1031+
canonicalizeSaturatedAddUnsigned(ICmpInst *Cmp, Value *TVal, Value *FVal,
1032+
InstCombiner::BuilderTy &Builder) {
10341033

10351034
// Match unsigned saturated add with constant.
10361035
Value *Cmp0 = Cmp->getOperand(0);
@@ -1052,8 +1051,7 @@ static Value *canonicalizeSaturatedAdd(ICmpInst *Cmp, Value *TVal, Value *FVal,
10521051
// uge -1 is canonicalized to eq -1 and requires special handling
10531052
// (a == -1) ? -1 : a + 1 -> uadd.sat(a, 1)
10541053
if (Pred == ICmpInst::ICMP_EQ) {
1055-
if (match(FVal, m_Add(m_Specific(Cmp0), m_One())) &&
1056-
match(Cmp1, m_AllOnes())) {
1054+
if (match(FVal, m_Add(m_Specific(Cmp0), m_One())) && Cmp1 == TVal) {
10571055
return Builder.CreateBinaryIntrinsic(
10581056
Intrinsic::uadd_sat, Cmp0, ConstantInt::get(Cmp0->getType(), 1));
10591057
}
@@ -1130,6 +1128,107 @@ static Value *canonicalizeSaturatedAdd(ICmpInst *Cmp, Value *TVal, Value *FVal,
11301128
return nullptr;
11311129
}
11321130

1131+
/// Try to turn a select that clamps a signed addition at the maximum signed
/// value into a call to llvm.sadd.sat.
///
/// Recognized shapes (after moving INT_MAX to the true arm of the select):
///   (A == INT_MAX) ? INT_MAX : (A + 1)        --> sadd.sat(A, 1)
///   (X >  Y) ? INT_MAX : (X + C)              --> sadd.sat(X, C)
///   (X >= Y) ? INT_MAX : (X + C)              --> sadd.sat(X, C)
///       where C > 0 and the (strict) bound is INT_MAX - C or INT_MAX - C - 1
///   (INT_MAX - X s< Y) ? INT_MAX : (X + Y)    --> sadd.sat(X, Y)
///       where the subtraction carries nsw (s<= is accepted as well)
///
/// \param Cmp     The compare feeding the select condition.
/// \param TVal    The select's true value.
/// \param FVal    The select's false value.
/// \param Builder Used to create the replacement intrinsic call.
/// \returns the new sadd.sat call, or nullptr if no pattern matched.
static Value *canonicalizeSaturatedAddSigned(ICmpInst *Cmp, Value *TVal,
                                             Value *FVal,
                                             InstCombiner::BuilderTy &Builder) {
  // Match saturated add with constant.
  Value *Cmp0 = Cmp->getOperand(0);
  Value *Cmp1 = Cmp->getOperand(1);
  ICmpInst::Predicate Pred = Cmp->getPredicate();
  Value *X;
  const APInt *C;

  // Canonicalize INT_MAX to true value of the select.
  if (match(FVal, m_MaxSignedValue())) {
    std::swap(TVal, FVal);
    Pred = CmpInst::getInversePredicate(Pred);
  }

  if (!match(TVal, m_MaxSignedValue()))
    return nullptr;

  // sge maximum signed value is canonicalized to eq maximum signed value and
  // requires special handling:
  // (a == INT_MAX) ? INT_MAX : a + 1 -> sadd.sat(a, 1)
  if (Pred == ICmpInst::ICMP_EQ) {
    if (match(FVal, m_Add(m_Specific(Cmp0), m_One())) && Cmp1 == TVal) {
      return Builder.CreateBinaryIntrinsic(
          Intrinsic::sadd_sat, Cmp0, ConstantInt::get(Cmp0->getType(), 1));
    }
    return nullptr;
  }

  // (X > Y) ? INT_MAX : (X + C) --> sadd.sat(X, C)
  // (X >= Y) ? INT_MAX : (X + C) --> sadd.sat(X, C)
  // where Y is INT_MAX - C or INT_MAX - C - 1, and C > 0
  if ((Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE) &&
      match(FVal, m_Add(m_Specific(Cmp0), m_APIntAllowPoison(C))) &&
      C->isStrictlyPositive()) {
    APInt IntMax =
        APInt::getSignedMaxValue(Cmp1->getType()->getScalarSizeInBits());

    // The two strict (sgt) bounds for which the select equals sadd.sat(X, C).
    const APInt BoundHi = IntMax - *C;
    const APInt BoundLo = BoundHi - 1;

    bool MatchesBound = false;
    if (Pred == ICmpInst::ICMP_SGT) {
      // Check the pattern: X > INT_MAX - C or X > INT_MAX - C - 1
      MatchesBound = match(Cmp1, m_SpecificIntAllowPoison(BoundHi)) ||
                     match(Cmp1, m_SpecificIntAllowPoison(BoundLo));
    } else if (auto *Cmp1C = dyn_cast<Constant>(Cmp1)) {
      // For SGE, flip to SGT to normalize the comparison constant. Cmp1 is not
      // guaranteed to be a constant on this path (e.g. `X s>= Y` with a
      // variable Y), so it must be tested with dyn_cast rather than cast.
      if (auto Flipped =
              getFlippedStrictnessPredicateAndConstant(Pred, Cmp1C)) {
        if (auto *CI = dyn_cast<ConstantInt>(Flipped->second)) {
          const APInt &FlippedC = CI->getValue();
          MatchesBound = FlippedC == BoundHi || FlippedC == BoundLo;
        }
      }
    }

    if (MatchesBound) {
      return Builder.CreateBinaryIntrinsic(
          Intrinsic::sadd_sat, Cmp0, ConstantInt::get(Cmp0->getType(), *C));
    }
  }

  // Canonicalize predicate to less-than or less-or-equal-than.
  if (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE) {
    std::swap(Cmp0, Cmp1);
    Pred = CmpInst::getSwappedPredicate(Pred);
  }

  if (Pred != ICmpInst::ICMP_SLT && Pred != ICmpInst::ICMP_SLE)
    return nullptr;

  if (match(Cmp0, m_NSWSub(m_MaxSignedValue(), m_Value(X))) &&
      match(FVal, m_c_Add(m_Specific(X), m_Specific(Cmp1)))) {
    // (INT_MAX - X s< Y) ? INT_MAX : (X + Y) --> sadd.sat(X, Y)
    // (INT_MAX - X s< Y) ? INT_MAX : (Y + X) --> sadd.sat(X, Y)
    return Builder.CreateBinaryIntrinsic(Intrinsic::sadd_sat, X, Cmp1);
  }

  return nullptr;
}
1217+
1218+
/// Entry point for the saturated-add select folds: tries the unsigned
/// matcher first and falls back to the signed matcher.
///
/// \returns the replacement value produced by whichever matcher succeeded, or
/// nullptr when the compare has more than one use or neither matcher applies.
static Value *canonicalizeSaturatedAdd(ICmpInst *Cmp, Value *TVal, Value *FVal,
                                       InstCombiner::BuilderTy &Builder) {
  // Bail out unless the compare has exactly one use.
  if (!Cmp->hasOneUse())
    return nullptr;

  Value *Folded = canonicalizeSaturatedAddUnsigned(Cmp, TVal, FVal, Builder);
  if (!Folded)
    Folded = canonicalizeSaturatedAddSigned(Cmp, TVal, FVal, Builder);
  return Folded;
}
1231+
11331232
/// Try to match patterns with select and subtract as absolute difference.
11341233
static Value *foldAbsDiff(ICmpInst *Cmp, Value *TVal, Value *FVal,
11351234
InstCombiner::BuilderTy &Builder) {

llvm/test/Transforms/InstCombine/canonicalize-const-to-bop.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,7 @@ define i8 @udiv_slt_exact(i8 %x) {
123123
define i8 @canonicalize_icmp_operands(i8 %x) {
124124
; CHECK-LABEL: define i8 @canonicalize_icmp_operands(
125125
; CHECK-SAME: i8 [[X:%.*]]) {
126-
; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.smin.i8(i8 [[X]], i8 119)
127-
; CHECK-NEXT: [[S:%.*]] = add nsw i8 [[TMP1]], 8
126+
; CHECK-NEXT: [[S:%.*]] = call i8 @llvm.sadd.sat.i8(i8 [[X]], i8 8)
128127
; CHECK-NEXT: ret i8 [[S]]
129128
;
130129
%add = add nsw i8 %x, 8

llvm/test/Transforms/InstCombine/saturating-add-sub.ll

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2378,9 +2378,7 @@ define i8 @sadd_sat_ugt_int_max(i8 %x, i8 %y) {
23782378

23792379
define i8 @sadd_sat_eq_int_max(i8 %x) {
23802380
; CHECK-LABEL: @sadd_sat_eq_int_max(
2381-
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[X:%.*]], 127
2382-
; CHECK-NEXT: [[ADD:%.*]] = add i8 [[X]], 1
2383-
; CHECK-NEXT: [[R:%.*]] = select i1 [[CMP]], i8 127, i8 [[ADD]]
2381+
; CHECK-NEXT: [[R:%.*]] = call i8 @llvm.sadd.sat.i8(i8 [[X:%.*]], i8 1)
23842382
; CHECK-NEXT: ret i8 [[R]]
23852383
;
23862384
%cmp = icmp eq i8 %x, 127
@@ -2391,8 +2389,7 @@ define i8 @sadd_sat_eq_int_max(i8 %x) {
23912389

23922390
define i8 @sadd_sat_constant(i8 %x) {
23932391
; CHECK-LABEL: @sadd_sat_constant(
2394-
; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.smin.i8(i8 [[X:%.*]], i8 117)
2395-
; CHECK-NEXT: [[R:%.*]] = add nsw i8 [[TMP1]], 10
2392+
; CHECK-NEXT: [[R:%.*]] = call i8 @llvm.sadd.sat.i8(i8 [[X:%.*]], i8 10)
23962393
; CHECK-NEXT: ret i8 [[R]]
23972394
;
23982395
%cmp = icmp sge i8 %x, 118
@@ -2611,10 +2608,7 @@ define i8 @sadd_sat_commuted_both(i8 %x, i8 %y) {
26112608

26122609
define i8 @sadd_sat_int_max_minus_x_nsw_slt(i8 %x, i8 %y) {
26132610
; CHECK-LABEL: @sadd_sat_int_max_minus_x_nsw_slt(
2614-
; CHECK-NEXT: [[SUB:%.*]] = sub nsw i8 127, [[X:%.*]]
2615-
; CHECK-NEXT: [[CMP:%.*]] = icmp slt i8 [[SUB]], [[Y:%.*]]
2616-
; CHECK-NEXT: [[ADD:%.*]] = add i8 [[X]], [[Y]]
2617-
; CHECK-NEXT: [[R:%.*]] = select i1 [[CMP]], i8 127, i8 [[ADD]]
2611+
; CHECK-NEXT: [[R:%.*]] = call i8 @llvm.sadd.sat.i8(i8 [[X:%.*]], i8 [[Y:%.*]])
26182612
; CHECK-NEXT: ret i8 [[R]]
26192613
;
26202614
%sub = sub nsw i8 127, %x
@@ -2626,10 +2620,7 @@ define i8 @sadd_sat_int_max_minus_x_nsw_slt(i8 %x, i8 %y) {
26262620

26272621
define i8 @sadd_sat_int_max_minus_x_nsw_sge_commuted(i8 %x, i8 %y) {
26282622
; CHECK-LABEL: @sadd_sat_int_max_minus_x_nsw_sge_commuted(
2629-
; CHECK-NEXT: [[SUB:%.*]] = sub nsw i8 127, [[X:%.*]]
2630-
; CHECK-NEXT: [[CMP_NOT:%.*]] = icmp slt i8 [[Y:%.*]], [[SUB]]
2631-
; CHECK-NEXT: [[ADD:%.*]] = add i8 [[X]], [[Y]]
2632-
; CHECK-NEXT: [[R:%.*]] = select i1 [[CMP_NOT]], i8 [[ADD]], i8 127
2623+
; CHECK-NEXT: [[R:%.*]] = call i8 @llvm.sadd.sat.i8(i8 [[X:%.*]], i8 [[Y:%.*]])
26332624
; CHECK-NEXT: ret i8 [[R]]
26342625
;
26352626
%sub = sub nsw i8 127, %x
@@ -2653,3 +2644,30 @@ define i8 @sadd_sat_int_max_minus_x_no_nsw_neg(i8 %x, i8 %y) {
26532644
%r = select i1 %cmp, i8 127, i8 %add
26542645
ret i8 %r
26552646
}
2647+
2648+
; Negative test: the `sub i8 127, %y` has no nsw flag, and the CHECK lines
; expect the sub/icmp/add/select sequence to remain rather than becoming a
; call to llvm.sadd.sat.
define i8 @neg_no_nsw(i8 %x, i8 %y) {
2649+
; CHECK-LABEL: @neg_no_nsw(
2650+
; CHECK-NEXT: [[ADD:%.*]] = sub i8 127, [[Y:%.*]]
2651+
; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i8 [[X:%.*]], [[ADD]]
2652+
; CHECK-NEXT: [[D:%.*]] = add i8 [[X]], [[Y]]
2653+
; CHECK-NEXT: [[S:%.*]] = select i1 [[CMP]], i8 127, i8 [[D]]
2654+
; CHECK-NEXT: ret i8 [[S]]
2655+
;
2656+
%add = sub i8 127, %y
2657+
%cmp = icmp sgt i8 %x, %add
2658+
%d = add i8 %x, %y
2659+
%s = select i1 %cmp, i8 127, i8 %d
2660+
ret i8 %s
2661+
}
2662+
2663+
; Negative test: the added constant is -128 (not strictly positive), and the
; CHECK lines expect an smin followed by an `and` instead of a call to
; llvm.sadd.sat.
define i8 @neg_neg_constant(i8 %x, i8 %y) {
2664+
; CHECK-LABEL: @neg_neg_constant(
2665+
; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.smin.i8(i8 [[X:%.*]], i8 -1)
2666+
; CHECK-NEXT: [[S:%.*]] = and i8 [[TMP1]], 127
2667+
; CHECK-NEXT: ret i8 [[S]]
2668+
;
2669+
%cmp = icmp sgt i8 %x, -2
2670+
%d = add i8 %x, -128
2671+
%s = select i1 %cmp, i8 127, i8 %d
2672+
ret i8 %s
2673+
}

0 commit comments

Comments
 (0)