[ValueTracking] Refine known bits for linear interpolation patterns #166378
Conversation
@llvm/pr-subscribers-llvm-ir @llvm/pr-subscribers-llvm-transforms

Author: Valeriy Savchenko (SavchenkoValeriy)

Changes

In this patch, we try to detect the lerp pattern: a * (b - c) + c * d, where a >= 0, b >= 0, c >= 0, d >= 0, and b >= c.

In that particular case, we can use the following chain of reasoning:

a * (b - c) + c * d <= a' * (b - c) + a' * c = a' * b, where a' = max(a, d)

Since that is true for arbitrary a, b, c, and d within our constraints, we can conclude that:

max(a * (b - c) + c * d) <= max(max(a), max(d)) * max(b) = U

Considering that any result of the lerp would be less than or equal to U, it would have at least as many leading zeros as U.

While this is quite a specific situation, it is fairly common in computer graphics in the shape of alpha blending.

In conjunction with #165877, this increases the vectorization factor for lerp loops.

Full diff: https://github.com/llvm/llvm-project/pull/166378.diff

2 Files Affected:
```diff
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 0a72076f51824..4c74710065371 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -350,6 +350,140 @@ unsigned llvm::ComputeMaxSignificantBits(const Value *V, const DataLayout &DL,
return V->getType()->getScalarSizeInBits() - SignBits + 1;
}
+// Try to detect the lerp pattern: a * (b - c) + c * d
+// where a >= 0, b >= 0, c >= 0, d >= 0, and b >= c.
+//
+// In that particular case, we can use the following chain of reasoning:
+//
+// a * (b - c) + c * d <= a' * (b - c) + a' * c = a' * b where a' = max(a, d)
+//
+// Since that is true for arbitrary a, b, c and d within our constraints, we can
+// conclude that:
+//
+// max(a * (b - c) + c * d) <= max(max(a), max(d)) * max(b) = U
+//
+// Considering that any result of the lerp would be less than or equal to U,
+// it would have at least as many leading zeros as U.
+//
+// While being quite a specific situation, it is fairly common in computer
+// graphics in the shape of alpha blending.
+//
+// Returns unknown bits if the pattern doesn't match or constraints don't apply
+// to the given operands.
+static KnownBits computeKnownBitsFromLerpPattern(const Value *Op0,
+ const Value *Op1,
+ const APInt &DemandedElts,
+ const SimplifyQuery &Q,
+ unsigned Depth) {
+
+ Type *Ty = Op0->getType();
+ const unsigned BitWidth = Ty->getScalarSizeInBits();
+
+ KnownBits Result(BitWidth);
+
+ // Only handle scalar types for now
+ if (Ty->isVectorTy())
+ return Result;
+
+ // Try to match: a * (b - c) + c * d.
+ // When a == 1 => A == nullptr, the same applies to d/D as well.
+ const Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr;
+
+ const auto MatchSubBC = [&]() {
+ // (b - c) can have two forms that interest us:
+ //
+ // 1. sub nuw %b, %c
+ // 2. xor %c, %b
+ //
+ // For the first case, nuw flag guarantees our requirement b >= c.
+ //
+ // The second case happens when the analysis can infer that b is a mask for
+ // c and we can transform sub operation into xor (that is usually true for
+ // constant b's). Even though xor is symmetrical, canonicalization ensures
+ // that the constant will be the RHS. xor of two positive integers is
+ // guaranteed to be non-negative as well.
+ return m_CombineOr(m_NUWSub(m_Value(B), m_Value(C)),
+ m_Xor(m_Value(C), m_Value(B)));
+ };
+
+ const auto MatchASubBC = [&]() {
+ // Cases:
+ // - a * (b - c)
+ // - (b - c) * a
+ // - (b - c) <- a implicitly equals 1
+ return m_CombineOr(m_CombineOr(m_Mul(m_Value(A), MatchSubBC()),
+ m_Mul(MatchSubBC(), m_Value(A))),
+ MatchSubBC());
+ };
+
+ const auto MatchCD = [&]() {
+ // Cases:
+ // - d * c
+ // - c * d
+ // - c <- d implicitly equals 1
+ return m_CombineOr(m_CombineOr(m_Mul(m_Value(D), m_Specific(C)),
+ m_Mul(m_Specific(C), m_Value(D))),
+ m_Specific(C));
+ };
+
+ const auto Match = [&](const Value *LHS, const Value *RHS) {
+ // We use m_Specific(C) in MatchCD, so we have to make sure that C is
+ // already bound; match(LHS, MatchASubBC()) absolutely has to evaluate
+ // first and return true.
+ //
+ // If Match returns true, it is guaranteed that B != nullptr, C != nullptr.
+ return match(LHS, MatchASubBC()) && match(RHS, MatchCD());
+ };
+
+ if (!Match(Op0, Op1) && !Match(Op1, Op0))
+ return Result;
+
+ const auto ComputeKnownBitsOrOne = [&](const Value *V) {
+ // For some of the values we use the convention of leaving
+ // it nullptr to signify an implicit constant 1.
+ return V ? computeKnownBits(V, DemandedElts, Q, Depth + 1)
+ : KnownBits::makeConstant(APInt(BitWidth, 1));
+ };
+
+ // Check that all operands are non-negative
+ const KnownBits KnownA = ComputeKnownBitsOrOne(A);
+ if (!KnownA.isNonNegative())
+ return Result;
+
+ const KnownBits KnownD = ComputeKnownBitsOrOne(D);
+ if (!KnownD.isNonNegative())
+ return Result;
+
+ const KnownBits KnownB = computeKnownBits(B, DemandedElts, Q, Depth + 1);
+ if (!KnownB.isNonNegative())
+ return Result;
+
+ const KnownBits KnownC = computeKnownBits(C, DemandedElts, Q, Depth + 1);
+ if (!KnownC.isNonNegative())
+ return Result;
+
+ // Compute max(a, d)
+ const APInt MaxA = KnownA.getMaxValue();
+ const APInt MaxD = KnownD.getMaxValue();
+ const APInt MaxAD = MaxA.ult(MaxD) ? MaxD : MaxA;
+
+ // Compute max(a, d) * max(b)
+ const APInt MaxB = KnownB.getMaxValue();
+ bool Overflow;
+ const APInt UpperBound = MaxAD.umul_ov(MaxB, Overflow);
+
+ if (Overflow)
+ return Result;
+
+ // Count leading zeros in upper bound
+ const unsigned MinimumNumberOfLeadingZeros = UpperBound.countl_zero();
+
+ // Create KnownBits with only leading zeros set
+ Result.Zero.setHighBits(MinimumNumberOfLeadingZeros);
+
+ return Result;
+}
+
static void computeKnownBitsAddSub(bool Add, const Value *Op0, const Value *Op1,
bool NSW, bool NUW,
const APInt &DemandedElts,
@@ -369,6 +503,15 @@ static void computeKnownBitsAddSub(bool Add, const Value *Op0, const Value *Op1,
isImpliedByDomCondition(ICmpInst::ICMP_SLE, Op1, Op0, Q.CxtI, Q.DL)
.value_or(false))
KnownOut.makeNonNegative();
+
+ if (Add) {
+ // Try to match lerp pattern and combine results
+ const KnownBits LerpKnown =
+ computeKnownBitsFromLerpPattern(Op0, Op1, DemandedElts, Q, Depth);
+ // Union of any two conservative estimates results in a conservative
+ // estimate that is at least as precise as each individual estimate.
+ KnownOut = KnownOut.unionWith(LerpKnown);
+ }
}
static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW,
diff --git a/llvm/test/Transforms/InstCombine/known-bits-lerp-pattern.ll b/llvm/test/Transforms/InstCombine/known-bits-lerp-pattern.ll
new file mode 100644
index 0000000000000..3018d3e99f636
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/known-bits-lerp-pattern.ll
@@ -0,0 +1,156 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+; Test known bits refinements for pattern: a * (b - c) + c * d
+; where a > 0, c > 0, b > 0, d > 0, and b > c.
+; This pattern is a generalization of lerp and it appears frequently in graphics operations.
+
+define i32 @test_clamp(i8 %a, i8 %c, i8 %d) {
+; CHECK-LABEL: define i32 @test_clamp(
+; CHECK-SAME: i8 [[A:%.*]], i8 [[C:%.*]], i8 [[D:%.*]]) {
+; CHECK-NEXT: [[A32:%.*]] = zext i8 [[A]] to i32
+; CHECK-NEXT: [[C32:%.*]] = zext i8 [[C]] to i32
+; CHECK-NEXT: [[D32:%.*]] = zext i8 [[D]] to i32
+; CHECK-NEXT: [[SUB:%.*]] = xor i32 [[C32]], 255
+; CHECK-NEXT: [[MUL1:%.*]] = mul nuw nsw i32 [[SUB]], [[A32]]
+; CHECK-NEXT: [[MUL2:%.*]] = mul nuw nsw i32 [[C32]], [[D32]]
+; CHECK-NEXT: [[ADD:%.*]] = add nuw nsw i32 [[MUL1]], [[MUL2]]
+; CHECK-NEXT: ret i32 [[ADD]]
+;
+ %a32 = zext i8 %a to i32
+ %c32 = zext i8 %c to i32
+ %d32 = zext i8 %d to i32
+ %sub = sub i32 255, %c32
+ %mul1 = mul i32 %a32, %sub
+ %mul2 = mul i32 %c32, %d32
+ %add = add i32 %mul1, %mul2
+ %cmp = icmp ugt i32 %add, 65535
+ %result = select i1 %cmp, i32 65535, i32 %add
+ ret i32 %result
+}
+
+define i1 @test_trunc_cmp(i8 %a, i8 %c, i8 %d) {
+; CHECK-LABEL: define i1 @test_trunc_cmp(
+; CHECK-SAME: i8 [[A:%.*]], i8 [[C:%.*]], i8 [[D:%.*]]) {
+; CHECK-NEXT: [[A32:%.*]] = zext i8 [[A]] to i32
+; CHECK-NEXT: [[C32:%.*]] = zext i8 [[C]] to i32
+; CHECK-NEXT: [[D32:%.*]] = zext i8 [[D]] to i32
+; CHECK-NEXT: [[SUB:%.*]] = xor i32 [[C32]], 255
+; CHECK-NEXT: [[MUL1:%.*]] = mul nuw nsw i32 [[SUB]], [[A32]]
+; CHECK-NEXT: [[MUL2:%.*]] = mul nuw nsw i32 [[C32]], [[D32]]
+; CHECK-NEXT: [[ADD:%.*]] = add nuw nsw i32 [[MUL1]], [[MUL2]]
+; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[ADD]], 1234
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+ %a32 = zext i8 %a to i32
+ %c32 = zext i8 %c to i32
+ %d32 = zext i8 %d to i32
+ %sub = sub i32 255, %c32
+ %mul1 = mul i32 %a32, %sub
+ %mul2 = mul i32 %c32, %d32
+ %add = add i32 %mul1, %mul2
+ %trunc = trunc i32 %add to i16
+ %cmp = icmp eq i16 %trunc, 1234
+ ret i1 %cmp
+}
+
+define i1 @test_trunc_cmp_xor(i8 %a, i8 %c, i8 %d) {
+; CHECK-LABEL: define i1 @test_trunc_cmp_xor(
+; CHECK-SAME: i8 [[A:%.*]], i8 [[C:%.*]], i8 [[D:%.*]]) {
+; CHECK-NEXT: [[A32:%.*]] = zext i8 [[A]] to i32
+; CHECK-NEXT: [[C32:%.*]] = zext i8 [[C]] to i32
+; CHECK-NEXT: [[D32:%.*]] = zext i8 [[D]] to i32
+; CHECK-NEXT: [[SUB:%.*]] = xor i32 [[C32]], 255
+; CHECK-NEXT: [[MUL1:%.*]] = mul nuw nsw i32 [[SUB]], [[A32]]
+; CHECK-NEXT: [[MUL2:%.*]] = mul nuw nsw i32 [[C32]], [[D32]]
+; CHECK-NEXT: [[ADD:%.*]] = add nuw nsw i32 [[MUL1]], [[MUL2]]
+; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[ADD]], 1234
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+ %a32 = zext i8 %a to i32
+ %c32 = zext i8 %c to i32
+ %d32 = zext i8 %d to i32
+ %sub = xor i32 255, %c32
+ %mul1 = mul i32 %a32, %sub
+ %mul2 = mul i32 %c32, %d32
+ %add = add i32 %mul1, %mul2
+ %trunc = trunc i32 %add to i16
+ %cmp = icmp eq i16 %trunc, 1234
+ ret i1 %cmp
+}
+
+define i1 @test_trunc_cmp_arbitrary_b(i8 %a, i8 %b, i8 %c, i8 %d) {
+; CHECK-LABEL: define i1 @test_trunc_cmp_arbitrary_b(
+; CHECK-SAME: i8 [[A:%.*]], i8 [[B:%.*]], i8 [[C:%.*]], i8 [[D:%.*]]) {
+; CHECK-NEXT: [[A32:%.*]] = zext i8 [[A]] to i32
+; CHECK-NEXT: [[B32:%.*]] = zext i8 [[B]] to i32
+; CHECK-NEXT: [[C32:%.*]] = zext i8 [[C]] to i32
+; CHECK-NEXT: [[D32:%.*]] = zext i8 [[D]] to i32
+; CHECK-NEXT: [[SUB:%.*]] = sub nuw nsw i32 [[B32]], [[C32]]
+; CHECK-NEXT: [[MUL1:%.*]] = mul nuw nsw i32 [[SUB]], [[A32]]
+; CHECK-NEXT: [[MUL2:%.*]] = mul nuw nsw i32 [[C32]], [[D32]]
+; CHECK-NEXT: [[ADD:%.*]] = add nuw nsw i32 [[MUL1]], [[MUL2]]
+; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[ADD]], 1234
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+ %a32 = zext i8 %a to i32
+ %b32 = zext i8 %b to i32
+ %c32 = zext i8 %c to i32
+ %d32 = zext i8 %d to i32
+ %sub = sub nsw nuw i32 %b32, %c32
+ %mul1 = mul i32 %a32, %sub
+ %mul2 = mul i32 %c32, %d32
+ %add = add i32 %mul1, %mul2
+ %trunc = trunc i32 %add to i16
+ %cmp = icmp eq i16 %trunc, 1234
+ ret i1 %cmp
+}
+
+
+define i1 @test_trunc_cmp_no_a(i8 %b, i8 %c, i8 %d) {
+; CHECK-LABEL: define i1 @test_trunc_cmp_no_a(
+; CHECK-SAME: i8 [[B:%.*]], i8 [[C:%.*]], i8 [[D:%.*]]) {
+; CHECK-NEXT: [[B32:%.*]] = zext i8 [[B]] to i32
+; CHECK-NEXT: [[C32:%.*]] = zext i8 [[C]] to i32
+; CHECK-NEXT: [[D32:%.*]] = zext i8 [[D]] to i32
+; CHECK-NEXT: [[MUL1:%.*]] = sub nuw nsw i32 [[B32]], [[C32]]
+; CHECK-NEXT: [[MUL2:%.*]] = mul nuw nsw i32 [[C32]], [[D32]]
+; CHECK-NEXT: [[ADD:%.*]] = add nuw nsw i32 [[MUL1]], [[MUL2]]
+; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[ADD]], 1234
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+ %b32 = zext i8 %b to i32
+ %c32 = zext i8 %c to i32
+ %d32 = zext i8 %d to i32
+ %sub = sub nuw i32 %b32, %c32
+ %mul2 = mul i32 %c32, %d32
+ %add = add i32 %sub, %mul2
+ %trunc = trunc i32 %add to i16
+ %cmp = icmp eq i16 %trunc, 1234
+ ret i1 %cmp
+}
+
+define i1 @test_trunc_cmp_no_d(i8 %a, i8 %b, i8 %c) {
+; CHECK-LABEL: define i1 @test_trunc_cmp_no_d(
+; CHECK-SAME: i8 [[A:%.*]], i8 [[B:%.*]], i8 [[C:%.*]]) {
+; CHECK-NEXT: [[A32:%.*]] = zext i8 [[A]] to i32
+; CHECK-NEXT: [[B32:%.*]] = zext i8 [[B]] to i32
+; CHECK-NEXT: [[C32:%.*]] = zext i8 [[C]] to i32
+; CHECK-NEXT: [[SUB:%.*]] = sub nuw nsw i32 [[B32]], [[C32]]
+; CHECK-NEXT: [[MUL1:%.*]] = mul nuw nsw i32 [[SUB]], [[A32]]
+; CHECK-NEXT: [[ADD:%.*]] = add nuw nsw i32 [[MUL1]], [[C32]]
+; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[ADD]], 1234
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+ %a32 = zext i8 %a to i32
+ %b32 = zext i8 %b to i32
+ %c32 = zext i8 %c to i32
+ %sub = sub nsw nuw i32 %b32, %c32
+ %mul1 = mul i32 %a32, %sub
+ %add = add i32 %mul1, %c32
+ %trunc = trunc i32 %add to i16
+ %cmp = icmp eq i16 %trunc, 1234
+ ret i1 %cmp
+}
+
+declare void @llvm.assume(i1)
```
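As a concrete instance of the bound these tests exercise (my illustration, not part of the patch): in `test_clamp`, `a`, `c`, and `d` are zero-extended i8 values and `b = 255`, so

$$U = \max(\max(a), \max(d)) \cdot \max(b) = 255 \cdot 255 = 65025 < 2^{16},$$

which gives the add at least 16 leading zeros in i32. The `icmp ugt %add, 65535` is therefore always false, and the clamping select folds away.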
fhahn left a comment:
> Since that is true for arbitrary a, b, c and d within our constraints, we can conclude that: max(a * (b - c) + c * d) <= max(max(a), max(d)) * max(b) = U

Would it be possible to add an Alive2 proof for the conclusion?
```cpp
const auto MatchSubBC = [&]() {
  // (b - c) can have two forms that interest us:
  //
  // 1. sub nuw %b, %c
  // 2. xor %c, %b
  //
  // For the first case, nuw flag guarantees our requirement b >= c.
  //
  // The second case happens when the analysis can infer that b is a mask for
  // c and we can transform sub operation into xor (that is usually true for
  // constant b's). Even though xor is symmetrical, canonicalization ensures
  // that the constant will be the RHS. xor of two positive integers is
  // guaranteed to be non-negative as well.
  return m_CombineOr(m_NUWSub(m_Value(B), m_Value(C)),
                     m_Xor(m_Value(C), m_Value(B)));
};
```
I think it would make the review easier if the initial patch just added support for one combination, then extended it to the additional combinations as follow-ups.
If one pattern is more common than the others in practice, that would be a good candidate.
While I agree with this sentiment about patches in general, I believe that this patch is quite small (effectively ~50 LoC) and general. Ripping it apart to support individual sub-cases feels like going out of the way just to follow that principle.
For example, splitting out subtraction vs. exclusive or would affect literally one line of code and remove just that line.
If you insist, of course I can do it, but I seriously don't see a good reason for that, especially considering that the proof works in the more general case.
That's a great point, and I believe that such statements should be proven. However, I'd disagree that Alive2 is the best tool for this particular change. Alive2 proves individual transformations, and I will provide Alive2 proofs for the tests that I added. Since I added a more general known-bits calculation that in theory can work well with one transformation but not another, I wrote a z3 proof to show that the inferred known bits are correct: https://gist.github.com/SavchenkoValeriy/a7d3e40868b9d512a70260ae63d8413b
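For reference, the core lemma of that proof, spelled out (this just restates the chain from the patch description):

$$a(b-c) + cd \;\le\; a'(b-c) + a'c \;=\; a'b, \qquad a' = \max(a, d),$$

where the first step uses $a \le a'$, $d \le a'$, $b - c \ge 0$, and $c \ge 0$, and the second step factors out $a'$.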
Alive2 struggles (times out) with 32-bit vectors for this particular proof: https://alive2.llvm.org/ce/z/S-78k-
Please confirm whether I shrunk the inputs correctly: https://alive2.llvm.org/ce/z/AABL4v (I guess you missed poison-generating flags in the source, as otherwise your target function is defined for fewer pairs of inputs than the source.)
In your example, we can always infer that the result is < 65535 just because we are extending i8 to i32. The correct shrinking would extend from N/4 bits to size N, but actually we can infer that the result is < 2^(N/2). Like this: https://alive2.llvm.org/ce/z/5A5mCg
Oh, you're right, I missed that the sum fits in N/2 bits; so inferring nsw/nuw should make sense, as the result will never overflow: https://alive2.llvm.org/ce/z/_rAdQg
Exactly!
dtcxzyw left a comment:
Can you please provide a generalized alive2 proof?
See https://llvm.org/docs/InstCombineContributorGuide.html#proofs.
llvm/lib/Analysis/ValueTracking.cpp (outdated):
```cpp
// that the constant will be the RHS. xor of two positive integers is
// guaranteed to be non-negative as well.
return m_CombineOr(m_NUWSub(m_Value(B), m_Value(C)),
                   m_Xor(m_Value(C), m_Value(B)));
```
Missing check for the constraint C <=u B.
What we actually want to ensure is that (b - c) doesn't become negative or, more generally, doesn't wrap. For sub, that's ensured via the b >= c constraint or the nuw flag. For xor, it comes out of the box, since xor doesn't wrap.
That's why I have the comment above stating that xor of two positive integers is non-negative.
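Spelling that claim out (a one-line argument, not from the thread): the sign bit of an xor is the xor of the sign bits, so for two non-negative operands

$$\operatorname{sign}(b \oplus c) = \operatorname{sign}(b) \oplus \operatorname{sign}(c) = 0 \oplus 0 = 0.$$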
> For xor, this comes out of the box since xor doesn't wrap.

But you cannot refine `xor c, b` into `sub nuw b, c` by assuming that `xor c, b` was introduced as the canonical form of `sub nuw b, c` with a constant mask b. A counterexample: https://alive2.llvm.org/ce/z/sWwAPJ (for instance, with b = 2 and c = 1, `xor c, b` is 3 while `b - c` is 1, and 3 already exceeds b, so the a' * b bound no longer holds).
That is a great catch, thank you!
As a matter of fact, we can't even make this conclusion for xor together with the constraint b >= c: https://gist.github.com/SavchenkoValeriy/fff5f7d58daf8f7d2c5b2eb955fbbf89
We have to match exactly the situation in which we replace sub with xor -- when b is a mask for c. Basically, we need to check that b == ~(known zeros of c). The proof is here: https://gist.github.com/SavchenkoValeriy/6d3d0da4cc68b6d7f6a4ef97d6502c9d
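A minimal standalone sketch of that condition in `KnownBits` terms (hypothetical helper name; the actual patch wires the check into the matcher differently):

```cpp
#include "llvm/ADT/APInt.h"
#include "llvm/Support/KnownBits.h"
using namespace llvm;

// Hypothetical illustration: `xor %c, %b` behaves like `sub nuw %b, %c`
// only when b has a 1 in every position where c may have a 1, i.e. when
// b equals the complement of c's known-zero mask.
static bool xorActsAsSub(const APInt &B, const KnownBits &KnownC) {
  return B == ~KnownC.Zero;
}
```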
I updated the code to include the extra check with the proof for it here https://gist.github.com/SavchenkoValeriy/6d3d0da4cc68b6d7f6a4ef97d6502c9d. Alive2 proof is here https://alive2.llvm.org/ce/z/F45enJ. So, this should be fixed now.
I believe I answered it here: in short, I'm not improving any particular transformation but an analysis, and an analysis can't be proven via Alive2. I provide a general z3 proof for the exact logic that I encoded. Please tell me if you have any questions about that proof.
z3 proof is ok. But you have to handle poison stuff yourself (e.g., …)
With the Alive2 proof it's also not fully representative, again because I need to prove the analysis and not the transformation.
✅ With the latest revision this PR passed the C/C++ code formatter.
dtcxzyw left a comment:
LGTM. Thanks!
```cpp
// ensures that the constant will be the RHS. We have additional checks
// later on to ensure that this xor operation is equivalent to subtraction.
return m_Instruction(SubBC, m_CombineOr(m_NUWSub(m_Value(B), m_Value(C)),
                                        m_Xor(m_Value(C), m_Value(B))));
```
We may fail to match when both b and c are non-constants and c is a subset of b. But I think it is fair enough to handle the constant RHS case first.
That's a really good point. I will follow up with a patch for that, because we will need to change how we match quite a lot: both multiplication and xor are commutative, so we can't tell in advance which operand is b and which is c. A sketch of the difficulty follows.
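A sketch of why commutativity makes this harder (hypothetical; `isMaskFor` stands in for the known-bits mask check discussed above):

```cpp
// With a non-constant mask, xor gives no syntactic way to tell b from c,
// so both bindings have to be tried and validated semantically.
const Value *X = nullptr, *Y = nullptr;
if (match(SubBC, m_Xor(m_Value(X), m_Value(Y)))) {
  if (isMaskFor(X, Y)) {
    // X covers every possibly-set bit of Y.
    B = X;
    C = Y;
  } else if (isMaskFor(Y, X)) {
    // Otherwise try the opposite binding.
    B = Y;
    C = X;
  }
}
```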
LLVM Buildbot has detected a new failure on a builder. Full details are available at: https://lab.llvm.org/buildbot/#/builders/88/builds/17919. Here is the relevant piece of the build log for the reference.
☝️ This test seems to be flaky. We had exactly the same failure earlier (https://lab.llvm.org/buildbot/#/builders/88/builds/17887), and the next build after mine was successful.
In this patch, we try to detect the lerp pattern: a * (b - c) + c * d
where a >= 0, b >= 0, c >= 0, d >= 0, and b >= c.
In that particular case, we can use the following chain of reasoning:
a * (b - c) + c * d <= a' * (b - c) + a' * c = a' * b, where a' = max(a, d)
Since that is true for arbitrary a, b, c, and d within our constraints, we can conclude that:
max(a * (b - c) + c * d) <= max(max(a), max(d)) * max(b) = U
Considering that any result of the lerp would be less than or equal to U, it would have at least as many leading zeros as U.
While this is quite a specific situation, it is fairly common in computer graphics in the shape of alpha blending.
In conjunction with #165877, this increases the vectorization factor for lerp loops.