Skip to content

Commit be31896

Browse files
nikiczmodem
authored andcommitted
Fix incorrect SimplifyWithOpReplaced transform (PR47322)
This is a followup to D86834, which partially fixed this issue in InstSimplify. However, InstCombine repeats the same transform while dropping poison flags -- which does not cover cases where poison is introduced in some other way. The fix here is a bit more comprehensive, because things are quite entangled, and it's hard to only partially address it without regressing optimization. There are really two changes here: * Export the SimplifyWithOpReplaced API from InstSimplify, with an added AllowRefinement flag. For replacements inside the TrueVal we don't actually care whether refinement occurs or not, the replacement is always legal. This part of the transform is now done in InstSimplify only. (It should be noted that the current AllowRefinement check is not sufficient -- that's an issue we need to address separately.) * Change the InstCombine fold to work by temporarily dropping poison generating flags, running the fold and then restoring the flags if it didn't work out. This will ensure that the InstCombine fold is correct as long as the InstSimplify fold is correct. Differential Revision: https://reviews.llvm.org/D87445
1 parent d720e58 commit be31896

File tree

4 files changed

+72
-45
lines changed

4 files changed

+72
-45
lines changed

llvm/include/llvm/Analysis/InstructionSimplify.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,12 @@ Value *SimplifyFreezeInst(Value *Op, const SimplifyQuery &Q);
268268
Value *SimplifyInstruction(Instruction *I, const SimplifyQuery &Q,
269269
OptimizationRemarkEmitter *ORE = nullptr);
270270

271+
/// See if V simplifies when its operand Op is replaced with RepOp.
272+
/// AllowRefinement specifies whether the simplification can be a refinement,
273+
/// or whether it needs to be strictly identical.
274+
Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
275+
const SimplifyQuery &Q, bool AllowRefinement);
276+
271277
/// Replace all uses of 'I' with 'SimpleV' and simplify the uses recursively.
272278
///
273279
/// This first performs a normal RAUW of I with SimpleV. It then recursively

llvm/lib/Analysis/InstructionSimplify.cpp

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3810,10 +3810,10 @@ Value *llvm::SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS,
38103810
return ::SimplifyFCmpInst(Predicate, LHS, RHS, FMF, Q, RecursionLimit);
38113811
}
38123812

3813-
/// See if V simplifies when its operand Op is replaced with RepOp.
3814-
static const Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
3815-
const SimplifyQuery &Q,
3816-
unsigned MaxRecurse) {
3813+
static Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
3814+
const SimplifyQuery &Q,
3815+
bool AllowRefinement,
3816+
unsigned MaxRecurse) {
38173817
// Trivial replacement.
38183818
if (V == Op)
38193819
return RepOp;
@@ -3826,20 +3826,19 @@ static const Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
38263826
if (!I)
38273827
return nullptr;
38283828

3829+
// Consider:
3830+
// %cmp = icmp eq i32 %x, 2147483647
3831+
// %add = add nsw i32 %x, 1
3832+
// %sel = select i1 %cmp, i32 -2147483648, i32 %add
3833+
//
3834+
// We can't replace %sel with %add unless we strip away the flags (which will
3835+
// be done in InstCombine).
3836+
// TODO: This is unsound, because it only catches some forms of refinement.
3837+
if (!AllowRefinement && canCreatePoison(I))
3838+
return nullptr;
3839+
38293840
// If this is a binary operator, try to simplify it with the replaced op.
38303841
if (auto *B = dyn_cast<BinaryOperator>(I)) {
3831-
// Consider:
3832-
// %cmp = icmp eq i32 %x, 2147483647
3833-
// %add = add nsw i32 %x, 1
3834-
// %sel = select i1 %cmp, i32 -2147483648, i32 %add
3835-
//
3836-
// We can't replace %sel with %add unless we strip away the flags.
3837-
// TODO: This is an unusual limitation because better analysis results in
3838-
// worse simplification. InstCombine can do this fold more generally
3839-
// by dropping the flags. Remove this fold to save compile-time?
3840-
if (canCreatePoison(I))
3841-
return nullptr;
3842-
38433842
if (MaxRecurse) {
38443843
if (B->getOperand(0) == Op)
38453844
return SimplifyBinOp(B->getOpcode(), RepOp, B->getOperand(1), Q,
@@ -3906,6 +3905,13 @@ static const Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
39063905
return nullptr;
39073906
}
39083907

3908+
Value *llvm::SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
3909+
const SimplifyQuery &Q,
3910+
bool AllowRefinement) {
3911+
return ::SimplifyWithOpReplaced(V, Op, RepOp, Q, AllowRefinement,
3912+
RecursionLimit);
3913+
}
3914+
39093915
/// Try to simplify a select instruction when its condition operand is an
39103916
/// integer comparison where one operand of the compare is a constant.
39113917
static Value *simplifySelectBitTest(Value *TrueVal, Value *FalseVal, Value *X,
@@ -4017,14 +4023,18 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal,
40174023
// arms of the select. See if substituting this value into the arm and
40184024
// simplifying the result yields the same value as the other arm.
40194025
if (Pred == ICmpInst::ICMP_EQ) {
4020-
if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q, MaxRecurse) ==
4026+
if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q,
4027+
/* AllowRefinement */ false, MaxRecurse) ==
40214028
TrueVal ||
4022-
SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q, MaxRecurse) ==
4029+
SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q,
4030+
/* AllowRefinement */ false, MaxRecurse) ==
40234031
TrueVal)
40244032
return FalseVal;
4025-
if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q, MaxRecurse) ==
4033+
if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q,
4034+
/* AllowRefinement */ true, MaxRecurse) ==
40264035
FalseVal ||
4027-
SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q, MaxRecurse) ==
4036+
SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q,
4037+
/* AllowRefinement */ true, MaxRecurse) ==
40284038
FalseVal)
40294039
return FalseVal;
40304040
}

llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp

Lines changed: 32 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1148,22 +1148,6 @@ static Instruction *canonicalizeAbsNabs(SelectInst &Sel, ICmpInst &Cmp,
11481148
return &Sel;
11491149
}
11501150

1151-
static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *ReplaceOp,
1152-
const SimplifyQuery &Q) {
1153-
// If this is a binary operator, try to simplify it with the replaced op
1154-
// because we know Op and ReplaceOp are equivalant.
1155-
// For example: V = X + 1, Op = X, ReplaceOp = 42
1156-
// Simplifies as: add(42, 1) --> 43
1157-
if (auto *BO = dyn_cast<BinaryOperator>(V)) {
1158-
if (BO->getOperand(0) == Op)
1159-
return SimplifyBinOp(BO->getOpcode(), ReplaceOp, BO->getOperand(1), Q);
1160-
if (BO->getOperand(1) == Op)
1161-
return SimplifyBinOp(BO->getOpcode(), BO->getOperand(0), ReplaceOp, Q);
1162-
}
1163-
1164-
return nullptr;
1165-
}
1166-
11671151
/// If we have a select with an equality comparison, then we know the value in
11681152
/// one of the arms of the select. See if substituting this value into an arm
11691153
/// and simplifying the result yields the same value as the other arm.
@@ -1190,20 +1174,45 @@ static Value *foldSelectValueEquivalence(SelectInst &Sel, ICmpInst &Cmp,
11901174
if (Cmp.getPredicate() == ICmpInst::ICMP_NE)
11911175
std::swap(TrueVal, FalseVal);
11921176

1177+
auto *FalseInst = dyn_cast<Instruction>(FalseVal);
1178+
if (!FalseInst)
1179+
return nullptr;
1180+
1181+
// InstSimplify already performed this fold if it was possible subject to
1182+
// current poison-generating flags. Try the transform again with
1183+
// poison-generating flags temporarily dropped.
1184+
bool WasNUW = false, WasNSW = false, WasExact = false;
1185+
if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(FalseVal)) {
1186+
WasNUW = OBO->hasNoUnsignedWrap();
1187+
WasNSW = OBO->hasNoSignedWrap();
1188+
FalseInst->setHasNoUnsignedWrap(false);
1189+
FalseInst->setHasNoSignedWrap(false);
1190+
}
1191+
if (auto *PEO = dyn_cast<PossiblyExactOperator>(FalseVal)) {
1192+
WasExact = PEO->isExact();
1193+
FalseInst->setIsExact(false);
1194+
}
1195+
11931196
// Try each equivalence substitution possibility.
11941197
// We have an 'EQ' comparison, so the select's false value will propagate.
11951198
// Example:
11961199
// (X == 42) ? 43 : (X + 1) --> (X == 42) ? (X + 1) : (X + 1) --> X + 1
1197-
// (X == 42) ? (X + 1) : 43 --> (X == 42) ? (42 + 1) : 43 --> 43
11981200
Value *CmpLHS = Cmp.getOperand(0), *CmpRHS = Cmp.getOperand(1);
1199-
if (simplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q) == TrueVal ||
1200-
simplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q) == TrueVal ||
1201-
simplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q) == FalseVal ||
1202-
simplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q) == FalseVal) {
1203-
if (auto *FalseInst = dyn_cast<Instruction>(FalseVal))
1204-
FalseInst->dropPoisonGeneratingFlags();
1201+
if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q,
1202+
/* AllowRefinement */ false) == TrueVal ||
1203+
SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q,
1204+
/* AllowRefinement */ false) == TrueVal) {
12051205
return FalseVal;
12061206
}
1207+
1208+
// Restore poison-generating flags if the transform did not apply.
1209+
if (WasNUW)
1210+
FalseInst->setHasNoUnsignedWrap();
1211+
if (WasNSW)
1212+
FalseInst->setHasNoSignedWrap();
1213+
if (WasExact)
1214+
FalseInst->setIsExact();
1215+
12071216
return nullptr;
12081217
}
12091218

llvm/test/Transforms/InstCombine/select.ll

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2491,9 +2491,11 @@ define <2 x i32> @true_undef_vec(i1 %cond, <2 x i32> %x) {
24912491
; FIXME: This is a miscompile!
24922492
define i32 @pr47322_more_poisonous_replacement(i32 %arg) {
24932493
; CHECK-LABEL: @pr47322_more_poisonous_replacement(
2494-
; CHECK-NEXT: [[TRAILING:%.*]] = call i32 @llvm.cttz.i32(i32 [[ARG:%.*]], i1 immarg true), [[RNG0:!range !.*]]
2494+
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[ARG:%.*]], 0
2495+
; CHECK-NEXT: [[TRAILING:%.*]] = call i32 @llvm.cttz.i32(i32 [[ARG]], i1 immarg true), [[RNG0:!range !.*]]
24952496
; CHECK-NEXT: [[SHIFTED:%.*]] = lshr i32 [[ARG]], [[TRAILING]]
2496-
; CHECK-NEXT: ret i32 [[SHIFTED]]
2497+
; CHECK-NEXT: [[R1_SROA_0_1:%.*]] = select i1 [[CMP]], i32 0, i32 [[SHIFTED]]
2498+
; CHECK-NEXT: ret i32 [[R1_SROA_0_1]]
24972499
;
24982500
%cmp = icmp eq i32 %arg, 0
24992501
%trailing = call i32 @llvm.cttz.i32(i32 %arg, i1 immarg true)

0 commit comments

Comments
 (0)