[InstCombine] Add missing patterns for scmp and ucmp #149225
Conversation
@llvm/pr-subscribers-llvm-transforms

Author: AZero13 (AZero13)

Changes

Fixes: #146178

Full diff: https://github.com/llvm/llvm-project/pull/149225.diff

2 Files Affected:
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 73ba0f78e8053..f483631f14076 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -3635,6 +3635,12 @@ static Instruction *foldBitCeil(SelectInst &SI, IRBuilderBase &Builder,
// (x < y) ? -1 : zext(x > y)
// (x > y) ? 1 : sext(x != y)
// (x > y) ? 1 : sext(x < y)
+// (x == y) ? 0 : (x > y ? 1 : -1)
+// (x == y) ? 0 : (x < y ? -1 : 1)
+// Special cases: x == C ? 0 : (x > C - 1 ? 1 : -1) and
+//                x == C ? 0 : (x < C + 1 ? -1 : 1)
// Into ucmp/scmp(x, y), where signedness is determined by the signedness
// of the comparison in the original sequence.
Instruction *InstCombinerImpl::foldSelectToCmp(SelectInst &SI) {
@@ -3680,10 +3686,12 @@ Instruction *InstCombinerImpl::foldSelectToCmp(SelectInst &SI) {
Pred = ICmpInst::getSwappedPredicate(Pred);
std::swap(LHS, RHS);
}
+
bool IsSigned = ICmpInst::isSigned(Pred);
bool Replace = false;
CmpPredicate ExtendedCmpPredicate;
+
// (x < y) ? -1 : zext(x != y)
// (x < y) ? -1 : zext(x > y)
if (ICmpInst::isLT(Pred) && match(TV, m_AllOnes()) &&
@@ -3703,34 +3711,134 @@ Instruction *InstCombinerImpl::foldSelectToCmp(SelectInst &SI) {
Replace = true;
// (x == y) ? 0 : (x > y ? 1 : -1)
+ // (x == y) ? 0 : (x < y ? -1 : 1)
CmpPredicate FalseBranchSelectPredicate;
const APInt *InnerTV, *InnerFV;
if (Pred == ICmpInst::ICMP_EQ && match(TV, m_Zero()) &&
match(FV, m_Select(m_c_ICmp(FalseBranchSelectPredicate, m_Specific(LHS),
m_Specific(RHS)),
m_APInt(InnerTV), m_APInt(InnerFV)))) {
- if (!ICmpInst::isGT(FalseBranchSelectPredicate)) {
- FalseBranchSelectPredicate =
- ICmpInst::getSwappedPredicate(FalseBranchSelectPredicate);
- std::swap(LHS, RHS);
+
+    // Canonicalize the inner comparison to GT; swapping an LT predicate
+    // always yields the corresponding GT predicate.
+    bool PredicateSwapped = false;
+    if (ICmpInst::isLT(FalseBranchSelectPredicate)) {
+      FalseBranchSelectPredicate =
+          ICmpInst::getSwappedPredicate(FalseBranchSelectPredicate);
+      PredicateSwapped = true;
+    }
- if (!InnerTV->isOne()) {
+ // Check if we need to canonicalize the select values
+ bool ValuesSwapped = false;
+ if (!InnerTV->isOne() && InnerFV->isOne()) {
std::swap(InnerTV, InnerFV);
- std::swap(LHS, RHS);
+ ValuesSwapped = true;
}
+ // Handle (x == y) ? 0 : (x > y ? 1 : -1) or its equivalent forms
if (ICmpInst::isGT(FalseBranchSelectPredicate) && InnerTV->isOne() &&
InnerFV->isAllOnes()) {
IsSigned = ICmpInst::isSigned(FalseBranchSelectPredicate);
+      // If we swapped the predicate or the values (but not both), the
+      // operands must be swapped for the cmp intrinsic.
+ if (PredicateSwapped != ValuesSwapped) {
+ std::swap(LHS, RHS);
+ }
+ Replace = true;
+ }
+ // Handle (x == y) ? 0 : (x < y ? -1 : 1) or its equivalent forms
+ else if (ICmpInst::isLT(FalseBranchSelectPredicate) && InnerTV->isAllOnes() &&
+ InnerFV->isOne()) {
+ IsSigned = ICmpInst::isSigned(FalseBranchSelectPredicate);
+ // For LT pattern, operand order is already correct
Replace = true;
}
}
+
+  // Special cases where the inner compare uses an adjacent constant:
+  //   x == C ? 0 : (x > C - 1 ? 1 : -1)
+  //   x == C ? 0 : (x < C + 1 ? -1 : 1)
+ if (Pred == ICmpInst::ICMP_EQ && match(TV, m_Zero())) {
+ Value *X;
+ const APInt *C;
+ if (match(LHS, m_Value(X)) && match(RHS, m_APInt(C))) {
+
+ // Match the nested select - no canonicalization, match each pattern
+ // directly
+ CmpPredicate InnerPred;
+ Value *InnerLHS, *InnerRHS;
+ const APInt *InnerTV, *InnerFV;
+ if (match(FV, m_Select(
+ m_ICmp(InnerPred, m_Value(InnerLHS), m_Value(InnerRHS)),
+ m_APInt(InnerTV), m_APInt(InnerFV)))) {
+
+ // x == C ? 0 : (x > C-1 ? 1 : -1)
+ if (ICmpInst::isGT(InnerPred) && InnerTV->isOne() &&
+ InnerFV->isAllOnes()) {
+ IsSigned = ICmpInst::isSigned(InnerPred);
+ bool CanSubOne = IsSigned ? !C->isMinSignedValue() : !C->isMinValue();
+ if (CanSubOne) {
+ APInt Cminus1 = *C - 1;
+ if ((InnerLHS == X && match(InnerRHS, m_SpecificInt(Cminus1))) ||
+ (InnerRHS == X && match(InnerLHS, m_SpecificInt(Cminus1)))) {
+ Replace = true;
+ }
+ }
+ }
+
+ // x == C ? 0 : (x < C+1 ? -1 : 1)
+ if (ICmpInst::isLT(InnerPred) && InnerTV->isAllOnes() &&
+ InnerFV->isOne()) {
+ IsSigned = ICmpInst::isSigned(InnerPred);
+ bool CanAddOne = IsSigned ? !C->isMaxSignedValue() : !C->isMaxValue();
+ if (CanAddOne) {
+ APInt Cplus1 = *C + 1;
+ if ((InnerLHS == X && match(InnerRHS, m_SpecificInt(Cplus1))) ||
+ (InnerRHS == X && match(InnerLHS, m_SpecificInt(Cplus1)))) {
+ Replace = true;
+ }
+ }
+ }
+ }
+ }
+ }
+
Intrinsic::ID IID = IsSigned ? Intrinsic::scmp : Intrinsic::ucmp;
if (Replace)
return replaceInstUsesWith(
SI, Builder.CreateIntrinsic(SI.getType(), IID, {LHS, RHS}));
+
return nullptr;
}
@@ -4496,5 +4604,21 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
return replaceOperand(SI, 2, ConstantInt::get(FalseVal->getType(), 0));
}
+  // Canonicalize sign function ashr pattern:
+  //   select (icmp slt X, 1), (ashr X, BitWidth - 1), 1 -> scmp(X, 0)
+ Value *X;
+ unsigned BitWidth = SI.getType()->getScalarSizeInBits();
+ CmpPredicate Pred;
+ if (match(&SI,
+ m_Select(
+ m_ICmp(Pred, m_Value(X), m_One()),
+ m_AShr(m_Deferred(X), m_SpecificInt(BitWidth - 1)),
+ m_One())) &&
+ Pred == ICmpInst::ICMP_SLT) {
+ Function *Scmp = Intrinsic::getOrInsertDeclaration(
+ SI.getModule(), Intrinsic::scmp, {SI.getType(), SI.getType()});
+ return CallInst::Create(Scmp, {X, ConstantInt::get(SI.getType(), 0)});
+ }
+
return nullptr;
}
diff --git a/llvm/test/Transforms/InstCombine/scmp.ll b/llvm/test/Transforms/InstCombine/scmp.ll
index 2bf22aeb7a6e9..218e2f004acc4 100644
--- a/llvm/test/Transforms/InstCombine/scmp.ll
+++ b/llvm/test/Transforms/InstCombine/scmp.ll
@@ -436,6 +436,139 @@ define <3 x i2> @scmp_unary_shuffle_ops(<3 x i8> %x, <3 x i8> %y) {
ret <3 x i2> %r
}
+define i32 @scmp_ashr(i32 %a) {
+; CHECK-LABEL: define i32 @scmp_ashr(
+; CHECK-SAME: i32 [[A:%.*]]) {
+; CHECK-NEXT: [[RETVAL_0:%.*]] = call i32 @llvm.scmp.i32.i32(i32 [[A]], i32 0)
+; CHECK-NEXT: ret i32 [[RETVAL_0]]
+;
+ %a.lobit = ashr i32 %a, 31
+ %cmp.inv = icmp slt i32 %a, 1
+ %retval.0 = select i1 %cmp.inv, i32 %a.lobit, i32 1
+ ret i32 %retval.0
+}
+
+define i32 @scmp_sgt_slt(i32 %a) {
+; CHECK-LABEL: define i32 @scmp_sgt_slt(
+; CHECK-SAME: i32 [[A:%.*]]) {
+; CHECK-NEXT: [[RETVAL_0:%.*]] = call i32 @llvm.scmp.i32.i32(i32 [[A]], i32 0)
+; CHECK-NEXT: ret i32 [[RETVAL_0]]
+;
+ %cmp = icmp sgt i32 %a, 0
+ %cmp1 = icmp slt i32 %a, 0
+ %. = select i1 %cmp1, i32 -1, i32 0
+ %retval.0 = select i1 %cmp, i32 1, i32 %.
+ ret i32 %retval.0
+}
+
+define i32 @scmp_zero_slt(i32 %a) {
+; CHECK-LABEL: define i32 @scmp_zero_slt(
+; CHECK-SAME: i32 [[A:%.*]]) {
+; CHECK-NEXT: [[RETVAL_0:%.*]] = call i32 @llvm.scmp.i32.i32(i32 [[A]], i32 0)
+; CHECK-NEXT: ret i32 [[RETVAL_0]]
+;
+ %cmp = icmp eq i32 %a, 0
+ %cmp1.inv = icmp slt i32 %a, 1
+ %. = select i1 %cmp1.inv, i32 -1, i32 1
+ %retval.0 = select i1 %cmp, i32 0, i32 %.
+ ret i32 %retval.0
+}
+
+define i32 @scmp_zero_sgt(i32 %a) {
+; CHECK-LABEL: define i32 @scmp_zero_sgt(
+; CHECK-SAME: i32 [[A:%.*]]) {
+; CHECK-NEXT: [[RETVAL_0:%.*]] = call i32 @llvm.scmp.i32.i32(i32 [[A]], i32 0)
+; CHECK-NEXT: ret i32 [[RETVAL_0]]
+;
+ %cmp = icmp eq i32 %a, 0
+ %cmp1.inv = icmp sgt i32 %a, -1
+ %. = select i1 %cmp1.inv, i32 1, i32 -1
+ %retval.0 = select i1 %cmp, i32 0, i32 %.
+ ret i32 %retval.0
+}
+
+define i32 @ucmp_sgt_slt_neg(i32 %a) {
+; CHECK-LABEL: define i32 @ucmp_sgt_slt_neg(
+; CHECK-SAME: i32 [[A:%.*]]) {
+; CHECK-NEXT: [[CMP_NOT:%.*]] = icmp ne i32 [[A]], 0
+; CHECK-NEXT: [[RETVAL_0:%.*]] = zext i1 [[CMP_NOT]] to i32
+; CHECK-NEXT: ret i32 [[RETVAL_0]]
+;
+ %cmp = icmp ugt i32 %a, 0
+ %cmp1 = icmp ult i32 %a, 0
+ %. = select i1 %cmp1, i32 -1, i32 0
+ %retval.0 = select i1 %cmp, i32 1, i32 %.
+ ret i32 %retval.0
+}
+
+define i32 @ucmp_zero_slt_neg(i32 %a) {
+; CHECK-LABEL: define i32 @ucmp_zero_slt_neg(
+; CHECK-SAME: i32 [[A:%.*]]) {
+; CHECK-NEXT: [[CMP:%.*]] = icmp ne i32 [[A]], 0
+; CHECK-NEXT: [[RETVAL_0:%.*]] = zext i1 [[CMP]] to i32
+; CHECK-NEXT: ret i32 [[RETVAL_0]]
+;
+ %cmp = icmp eq i32 %a, 0
+ %cmp1.inv = icmp ult i32 %a, 1
+ %. = select i1 %cmp1.inv, i32 -1, i32 1
+ %retval.0 = select i1 %cmp, i32 0, i32 %.
+ ret i32 %retval.0
+}
+
+define i32 @ucmp_zero_sgt_neg(i32 %a) {
+; CHECK-LABEL: define i32 @ucmp_zero_sgt_neg(
+; CHECK-SAME: i32 [[A:%.*]]) {
+; CHECK-NEXT: [[CMP:%.*]] = icmp ne i32 [[A]], 0
+; CHECK-NEXT: [[RETVAL_0:%.*]] = sext i1 [[CMP]] to i32
+; CHECK-NEXT: ret i32 [[RETVAL_0]]
+;
+ %cmp = icmp eq i32 %a, 0
+ %cmp1.inv = icmp ugt i32 %a, -1
+ %. = select i1 %cmp1.inv, i32 1, i32 -1
+ %retval.0 = select i1 %cmp, i32 0, i32 %.
+ ret i32 %retval.0
+}
+
+define i32 @scmp_sgt_slt_ab(i32 %a, i32 %b) {
+; CHECK-LABEL: define i32 @scmp_sgt_slt_ab(
+; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]]) {
+; CHECK-NEXT: [[RETVAL_0:%.*]] = call i32 @llvm.scmp.i32.i32(i32 [[A]], i32 [[B]])
+; CHECK-NEXT: ret i32 [[RETVAL_0]]
+;
+ %cmp = icmp sgt i32 %a, %b
+ %cmp1 = icmp slt i32 %a, %b
+ %. = select i1 %cmp1, i32 -1, i32 0
+ %retval.0 = select i1 %cmp, i32 1, i32 %.
+ ret i32 %retval.0
+}
+
+define i32 @scmp_zero_slt_ab(i32 %a, i32 %b) {
+; CHECK-LABEL: define i32 @scmp_zero_slt_ab(
+; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]]) {
+; CHECK-NEXT: [[RETVAL_0:%.*]] = call i32 @llvm.scmp.i32.i32(i32 [[A]], i32 [[B]])
+; CHECK-NEXT: ret i32 [[RETVAL_0]]
+;
+ %cmp = icmp eq i32 %a, %b
+ %cmp1.inv = icmp slt i32 %a, %b
+ %. = select i1 %cmp1.inv, i32 -1, i32 1
+ %retval.0 = select i1 %cmp, i32 0, i32 %.
+ ret i32 %retval.0
+}
+
+define i32 @scmp_zero_sgt_ab(i32 %a, i32 %b) {
+; CHECK-LABEL: define i32 @scmp_zero_sgt_ab(
+; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]]) {
+; CHECK-NEXT: [[RETVAL_0:%.*]] = call i32 @llvm.scmp.i32.i32(i32 [[A]], i32 [[B]])
+; CHECK-NEXT: ret i32 [[RETVAL_0]]
+;
+ %cmp = icmp eq i32 %a, %b
+ %cmp1.inv = icmp sgt i32 %a, %b
+ %. = select i1 %cmp1.inv, i32 1, i32 -1
+ %retval.0 = select i1 %cmp, i32 0, i32 %.
+ ret i32 %retval.0
+}
+
; Negative test: true value of outer select is not zero
define i8 @scmp_from_select_eq_and_gt_neg1(i32 %x, i32 %y) {
; CHECK-LABEL: define i8 @scmp_from_select_eq_and_gt_neg1(
✅ With the latest revision this PR passed the C/C++ code formatter.
Force-pushed from 100f431 to d4a2a2b
@dtcxzyw Can we please merge this now?
// Canonicalize sign function ashr pattern: select (icmp slt X, 1), ashr X,
Please file a separate patch for this change.
We can also fold the commuted pattern:
select (icmp sgt X, 0), 1, (ashr X, BW - 1) -> scmp(X, 0)
They fold fine, @dtcxzyw.
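For reference, a hypothetical test for that commuted form (not part of this PR; the function name and the expected result in the trailing comment are assumptions based on the pattern suggested above):

define i32 @scmp_ashr_commuted(i32 %a) {
  %a.lobit = ashr i32 %a, 31
  %cmp = icmp sgt i32 %a, 0
  %retval.0 = select i1 %cmp, i32 1, i32 %a.lobit
  ret i32 %retval.0
}
; Expected fold, if handled: %retval.0 = call i32 @llvm.scmp.i32.i32(i32 %a, i32 0)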
Force-pushed from bbd90a3 to 11f022b
LG
Force-pushed from 2b577ed to 0da344e
@dtcxzyw Fixed!
Fixes: #146178
Alive2: https://alive2.llvm.org/ce/z/ZitMnX
Alive2: https://alive2.llvm.org/ce/z/aJZ2BQ
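For quick context, a minimal before/after sketch of the core fold this patch adds; it mirrors the scmp_zero_slt_ab test above, and the function and value names are illustrative only:

define i32 @sign3(i32 %a, i32 %b) {
  %cmp = icmp eq i32 %a, %b
  %cmp1 = icmp slt i32 %a, %b
  %sel = select i1 %cmp1, i32 -1, i32 1
  %r = select i1 %cmp, i32 0, i32 %sel
  ret i32 %r
}
; After this patch, InstCombine folds the select chain to:
;   %r = call i32 @llvm.scmp.i32.i32(i32 %a, i32 %b)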