-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[InstCombine] Fold (x < 2^32) & (trunc(x to i32) == 0) into x == 0
#171195
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
@llvm/pr-subscribers-llvm-analysis @llvm/pr-subscribers-llvm-transforms Author: Tirthankar Mazumder (wermos) Changes: Addresses #170020. I'm not exactly sure what kind of Alive2 proof is required when the optimization has to do with `KnownBits` stuff, so I'm copying over the Alive2 proof for the specific case discussed in the issue: https://alive2.llvm.org/ce/z/K59kAt I followed the suggestion given here: To do this, I had to make `computeKnownBitsFromICmpCond` a part of the `ValueTracking.h` header. I'm also not sure if more tests are required or not. Full diff: https://github.com/llvm/llvm-project/pull/171195.diff 4 Files Affected:
diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h
index b730a36488780..48cc85e719421 100644
--- a/llvm/include/llvm/Analysis/ValueTracking.h
+++ b/llvm/include/llvm/Analysis/ValueTracking.h
@@ -102,6 +102,15 @@ LLVM_ABI void computeKnownBitsFromContext(const Value *V, KnownBits &Known,
const SimplifyQuery &Q,
unsigned Depth = 0);
+/// Update \p Known with bits of \p V that are implied by \p Cmp.
+/// Comparisons involving `trunc V` are handled specially: known
+/// bits are computed for the truncated value and then extended to the bitwidth
+/// of \p V.
+LLVM_ABI void computeKnownBitsFromICmpCond(const Value *V, ICmpInst *Cmp,
+ KnownBits &Known,
+ const SimplifyQuery &SQ,
+ bool Invert);
+
/// Using KnownBits LHS/RHS produce the known bits for logic op (and/xor/or).
LLVM_ABI KnownBits analyzeKnownBitsFromAndXorOr(const Operator *I,
const KnownBits &KnownLHS,
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 9cb6f19b9340c..5ab5f8cfccc7f 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -968,9 +968,9 @@ static void computeKnownBitsFromCmp(const Value *V, CmpInst::Predicate Pred,
}
}
-static void computeKnownBitsFromICmpCond(const Value *V, ICmpInst *Cmp,
- KnownBits &Known,
- const SimplifyQuery &SQ, bool Invert) {
+void llvm::computeKnownBitsFromICmpCond(const Value *V, ICmpInst *Cmp,
+ KnownBits &Known,
+ const SimplifyQuery &SQ, bool Invert) {
ICmpInst::Predicate Pred =
Invert ? Cmp->getInversePredicate() : Cmp->getPredicate();
Value *LHS = Cmp->getOperand(0);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index ba5568b00441b..fa7c66d736c28 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -15,11 +15,13 @@
#include "llvm/Analysis/CmpInstAnalysis.h"
#include "llvm/Analysis/FloatingPointPredicateUtils.h"
#include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/PatternMatch.h"
+#include "llvm/Support/KnownBits.h"
#include "llvm/Transforms/InstCombine/InstCombiner.h"
#include "llvm/Transforms/Utils/Local.h"
@@ -3376,9 +3378,13 @@ Value *InstCombinerImpl::foldAndOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
Value *LHS0 = LHS->getOperand(0), *RHS0 = RHS->getOperand(0);
Value *LHS1 = LHS->getOperand(1), *RHS1 = RHS->getOperand(1);
+ // dbgs() << "LHS0 = " << *LHS0 << "\nLHS1 = " << *LHS1 << '\n';
+ // dbgs() << "RHS0 = " << *RHS0 << "\nRHS1 = " << *RHS1 << '\n';
+
const APInt *LHSC = nullptr, *RHSC = nullptr;
match(LHS1, m_APInt(LHSC));
match(RHS1, m_APInt(RHSC));
+ // dbgs() << "LHSC = " << *LHSC << "\nRHSC = " << *RHSC << '\n';
// (icmp1 A, B) | (icmp2 A, B) --> (icmp3 A, B)
// (icmp1 A, B) & (icmp2 A, B) --> (icmp3 A, B)
@@ -3575,6 +3581,40 @@ Value *InstCombinerImpl::foldAndOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
return Builder.createIsFPClass(X, IsAnd ? FPClassTest::fcNormal
: ~FPClassTest::fcNormal);
+ if (!IsLogical && IsAnd) {
+ auto TryCandidate = [&](Value *X) -> Value * {
+ if (!X->getType()->isIntegerTy())
+ return nullptr;
+
+ Type *Ty = X->getType();
+ unsigned BitWidth = Ty->getScalarSizeInBits();
+
+ // KnownL and KnownR hold information deduced from the LHS icmp and RHS
+ // icmps, respectively
+ KnownBits KnownL(BitWidth), KnownR(BitWidth);
+
+ computeKnownBitsFromICmpCond(X, LHS, KnownL, Q, /*Invert=*/false);
+ computeKnownBitsFromICmpCond(X, RHS, KnownR, Q, /*Invert=*/false);
+
+ KnownBits Combined = KnownL.unionWith(KnownR);
+
+ // Avoid stomping on cases where one icmp alone determines X. Those are handled by more specific InstCombine folds.
+ if (KnownL.isConstant() || KnownR.isConstant())
+ return nullptr;
+
+ if (!Combined.isConstant())
+ return nullptr;
+
+ APInt ConstVal = Combined.getConstant();
+ return Builder.CreateICmpEQ(X, ConstantInt::get(Ty, ConstVal));
+ };
+
+ if (Value *Res = TryCandidate(LHS0))
+ return Res;
+ if (Value *Res = TryCandidate(RHS0))
+ return Res;
+ }
+
return foldAndOrOfICmpsUsingRanges(LHS, RHS, IsAnd);
}
diff --git a/llvm/test/Transforms/InstCombine/and-or-icmps.ll b/llvm/test/Transforms/InstCombine/and-or-icmps.ll
index 290e344acb980..9d69fadfa9627 100644
--- a/llvm/test/Transforms/InstCombine/and-or-icmps.ll
+++ b/llvm/test/Transforms/InstCombine/and-or-icmps.ll
@@ -702,9 +702,9 @@ define i1 @PR42691_10_logical(i32 %x) {
define i1 @substitute_constant_and_eq_eq(i8 %x, i8 %y) {
; CHECK-LABEL: @substitute_constant_and_eq_eq(
-; CHECK-NEXT: [[C1:%.*]] = icmp eq i8 [[X:%.*]], 42
; CHECK-NEXT: [[TMP1:%.*]] = icmp eq i8 [[Y:%.*]], 42
-; CHECK-NEXT: [[R:%.*]] = and i1 [[C1]], [[TMP1]]
+; CHECK-NEXT: [[TMP2:%.*]] = icmp eq i8 [[Y1:%.*]], 42
+; CHECK-NEXT: [[R:%.*]] = and i1 [[TMP1]], [[TMP2]]
; CHECK-NEXT: ret i1 [[R]]
;
%c1 = icmp eq i8 %x, 42
@@ -728,9 +728,9 @@ define i1 @substitute_constant_and_eq_eq_logical(i8 %x, i8 %y) {
define i1 @substitute_constant_and_eq_eq_commute(i8 %x, i8 %y) {
; CHECK-LABEL: @substitute_constant_and_eq_eq_commute(
-; CHECK-NEXT: [[C1:%.*]] = icmp eq i8 [[X:%.*]], 42
; CHECK-NEXT: [[TMP1:%.*]] = icmp eq i8 [[Y:%.*]], 42
-; CHECK-NEXT: [[R:%.*]] = and i1 [[C1]], [[TMP1]]
+; CHECK-NEXT: [[TMP2:%.*]] = icmp eq i8 [[Y1:%.*]], 42
+; CHECK-NEXT: [[R:%.*]] = and i1 [[TMP1]], [[TMP2]]
; CHECK-NEXT: ret i1 [[R]]
;
%c1 = icmp eq i8 %x, 42
@@ -741,9 +741,9 @@ define i1 @substitute_constant_and_eq_eq_commute(i8 %x, i8 %y) {
define i1 @substitute_constant_and_eq_eq_commute_logical(i8 %x, i8 %y) {
; CHECK-LABEL: @substitute_constant_and_eq_eq_commute_logical(
-; CHECK-NEXT: [[C1:%.*]] = icmp eq i8 [[X:%.*]], 42
; CHECK-NEXT: [[TMP1:%.*]] = icmp eq i8 [[Y:%.*]], 42
-; CHECK-NEXT: [[R:%.*]] = and i1 [[C1]], [[TMP1]]
+; CHECK-NEXT: [[TMP2:%.*]] = icmp eq i8 [[Y1:%.*]], 42
+; CHECK-NEXT: [[R:%.*]] = and i1 [[TMP1]], [[TMP2]]
; CHECK-NEXT: ret i1 [[R]]
;
%c1 = icmp eq i8 %x, 42
@@ -1392,12 +1392,12 @@ define i1 @bitwise_and_bitwise_and_icmps(i8 %x, i8 %y, i8 %z) {
define i1 @bitwise_and_bitwise_and_icmps_comm1(i8 %x, i8 %y, i8 %z) {
; CHECK-LABEL: @bitwise_and_bitwise_and_icmps_comm1(
-; CHECK-NEXT: [[C1:%.*]] = icmp eq i8 [[Y:%.*]], 42
+; CHECK-NEXT: [[TMP3:%.*]] = icmp eq i8 [[Y:%.*]], 42
; CHECK-NEXT: [[Z_SHIFT:%.*]] = shl nuw i8 1, [[Z:%.*]]
; CHECK-NEXT: [[TMP1:%.*]] = or i8 [[Z_SHIFT]], 1
; CHECK-NEXT: [[TMP2:%.*]] = and i8 [[X:%.*]], [[TMP1]]
-; CHECK-NEXT: [[TMP3:%.*]] = icmp eq i8 [[TMP2]], [[TMP1]]
-; CHECK-NEXT: [[AND2:%.*]] = and i1 [[C1]], [[TMP3]]
+; CHECK-NEXT: [[TMP4:%.*]] = icmp eq i8 [[TMP2]], [[TMP1]]
+; CHECK-NEXT: [[AND2:%.*]] = and i1 [[TMP3]], [[TMP4]]
; CHECK-NEXT: ret i1 [[AND2]]
;
%c1 = icmp eq i8 %y, 42
@@ -3721,3 +3721,28 @@ define i1 @merge_range_check_or(i8 %a) {
%and = or i1 %cmp1, %cmp2
ret i1 %and
}
+
+; Just a very complicated way of checking if v1 == 0.
+define i1 @complicated_zero_equality_test(i64 %v1) {
+; CHECK-LABEL: @complicated_zero_equality_test(
+; CHECK-NEXT: [[V5:%.*]] = icmp eq i64 [[V1:%.*]], 0
+; CHECK-NEXT: ret i1 [[V5]]
+;
+ %v2 = trunc i64 %v1 to i32
+ %v3 = icmp eq i32 %v2, 0
+ %v4 = icmp ult i64 %v1, 4294967296 ; 2 ^ 32
+ %v5 = and i1 %v4, %v3
+ ret i1 %v5
+}
+
+define i1 @commuted_complicated_zero_equality_test(i64 %v1) {
+; CHECK-LABEL: @commuted_complicated_zero_equality_test(
+; CHECK-NEXT: [[V5:%.*]] = icmp eq i64 [[V1:%.*]], 0
+; CHECK-NEXT: ret i1 [[V5]]
+;
+ %v2 = trunc i64 %v1 to i32
+ %v3 = icmp ult i64 %v1, 4294967296 ; 2 ^ 32
+ %v4 = icmp eq i32 %v2, 0
+ %v5 = and i1 %v4, %v3
+ ret i1 %v5
+}
|
|
|
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
|
I've addressed the email thing as well. |
|
Ping @dtcxzyw for review. |
| auto TryCandidate = [&](Value *X) -> Value * { | ||
| if (!X->getType()->isIntegerTy()) | ||
| return nullptr; | ||
|
|
||
| Type *Ty = X->getType(); | ||
| unsigned BitWidth = Ty->getScalarSizeInBits(); | ||
|
|
||
| // KnownL and KnownR hold information deduced from the LHS icmp and RHS | ||
| // icmps, respectively | ||
| KnownBits KnownL(BitWidth), KnownR(BitWidth); | ||
|
|
||
| computeKnownBitsFromICmpCond(X, LHS, KnownL, Q, /*Invert=*/false); | ||
| computeKnownBitsFromICmpCond(X, RHS, KnownR, Q, /*Invert=*/false); | ||
|
|
||
| KnownBits Combined = KnownL.unionWith(KnownR); | ||
|
|
||
| // Avoid stomping on cases where one icmp alone determines X. Those are | ||
| // handled by more specific InstCombine folds. | ||
| if (KnownL.isConstant() || KnownR.isConstant()) | ||
| return nullptr; | ||
|
|
||
| if (!Combined.isConstant()) | ||
| return nullptr; | ||
|
|
||
| APInt ConstVal = Combined.getConstant(); | ||
| return Builder.CreateICmpEQ(X, ConstantInt::get(Ty, ConstVal)); | ||
| }; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this part fine as a lambda, or should it be its own helper function? If it should be its own helper function, then what should it be called?
|
If the trunc is replaced by an `and`, this is already folded; see https://alive2.llvm.org/ce/z/Whfa65 and llvm-project/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp, lines 204 to 212 in f803e46.
|
You could use this format. |
Oh yes, we can simply handle this in |
|
Alright, I'll work on moving the patch you shared into |
Resolves #170020.
I'm not exactly sure what kind of Alive2 proof is required when the optimization has to do with `KnownBits` stuff, so I'm copying over the Alive2 proof for the specific case discussed in the issue: https://alive2.llvm.org/ce/z/K59kAt
I followed the suggestion given here:
To do this, I had to make `computeKnownBitsFromICmpCond` a part of the `ValueTracking.h` header.
I'm also not sure if more tests are required or not.