Skip to content

Commit 00e6f66

Browse files
[KnownBits] Refine known bits for lerp
1 parent aedd1e7 commit 00e6f66

File tree

2 files changed

+149
-12
lines changed

2 files changed

+149
-12
lines changed

llvm/lib/Analysis/ValueTracking.cpp

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,140 @@ unsigned llvm::ComputeMaxSignificantBits(const Value *V, const DataLayout &DL,
350350
return V->getType()->getScalarSizeInBits() - SignBits + 1;
351351
}
352352

353+
// Try to detect the lerp pattern: a * (b - c) + c * d
354+
// where a >= 0, b >= 0, c >= 0, d >= 0, and b >= c.
355+
//
356+
// In that particular case, we can use the following chain of reasoning:
357+
//
358+
// a * (b - c) + c * d <= a' * (b - c) + a' * c = a' * b where a' = max(a, d)
359+
//
360+
// Since that is true for arbitrary a, b, c and d within our constraints, we can
361+
// conclude that:
362+
//
363+
// max(a * (b - c) + c * d) <= max(max(a), max(d)) * max(b) = U
364+
//
365+
// Considering that any result of the lerp would be less or equal to U, it would
366+
// have at least the number of leading 0s as in U.
367+
//
368+
// While being quite a specific situation, it is fairly common in computer
369+
// graphics in the shape of alpha blending.
370+
//
371+
// Returns unknown bits if the pattern doesn't match or constraints don't apply
372+
// to the given operands.
373+
static KnownBits computeKnownBitsFromLerpPattern(const Value *Op0,
374+
const Value *Op1,
375+
const APInt &DemandedElts,
376+
const SimplifyQuery &Q,
377+
unsigned Depth) {
378+
379+
Type *Ty = Op0->getType();
380+
const unsigned BitWidth = Ty->getScalarSizeInBits();
381+
382+
KnownBits Result(BitWidth);
383+
384+
// Only handle scalar types for now
385+
if (Ty->isVectorTy())
386+
return Result;
387+
388+
// Try to match: a * (b - c) + c * d.
389+
// When a == 1 => A == nullptr, the same applies to d/D as well.
390+
const Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr;
391+
392+
const auto MatchSubBC = [&]() {
393+
// (b - c) can have two forms that interest us:
394+
//
395+
// 1. sub nuw %b, %c
396+
// 2. xor %c, %b
397+
//
398+
// For the first case, nuw flag guarantees our requirement b >= c.
399+
//
400+
// The second case happens when the analysis can infer that b is a mask for
401+
// c and we can transform sub operation into xor (that is usually true for
402+
// constant b's). Even though xor is symmetrical, canonicalization ensures
403+
// that the constant will be the RHS. xor of two positive integers is
404+
// guaranteed to be non-negative as well.
405+
return m_CombineOr(m_NUWSub(m_Value(B), m_Value(C)),
406+
m_Xor(m_Value(C), m_Value(B)));
407+
};
408+
409+
const auto MatchASubBC = [&]() {
410+
// Cases:
411+
// - a * (b - c)
412+
// - (b - c) * a
413+
// - (b - c) <- a implicitly equals 1
414+
return m_CombineOr(m_CombineOr(m_Mul(m_Value(A), MatchSubBC()),
415+
m_Mul(MatchSubBC(), m_Value(A))),
416+
MatchSubBC());
417+
};
418+
419+
const auto MatchCD = [&]() {
420+
// Cases:
421+
// - d * c
422+
// - c * d
423+
// - c <- d implicitly equals 1
424+
return m_CombineOr(m_CombineOr(m_Mul(m_Value(D), m_Specific(C)),
425+
m_Mul(m_Specific(C), m_Value(D))),
426+
m_Specific(C));
427+
};
428+
429+
const auto Match = [&](const Value *LHS, const Value *RHS) {
430+
// We do use m_Specific(C) in MatchCD, so we have to make sure that
431+
// it's bound to anything and match(LHS, MatchASubBC()) absolutely
432+
// has to evaluate first and return true.
433+
//
434+
// If Match returns true, it is guaranteed that B != nullptr, C != nullptr.
435+
return match(LHS, MatchASubBC()) && match(RHS, MatchCD());
436+
};
437+
438+
if (!Match(Op0, Op1) && !Match(Op1, Op0))
439+
return Result;
440+
441+
const auto ComputeKnownBitsOrOne = [&](const Value *V) {
442+
// For some of the values we use the convention of leaving
443+
// it nullptr to signify an implicit constant 1.
444+
return V ? computeKnownBits(V, DemandedElts, Q, Depth + 1)
445+
: KnownBits::makeConstant(APInt(BitWidth, 1));
446+
};
447+
448+
// Check that all operands are non-negative
449+
const KnownBits KnownA = ComputeKnownBitsOrOne(A);
450+
if (!KnownA.isNonNegative())
451+
return Result;
452+
453+
const KnownBits KnownD = ComputeKnownBitsOrOne(D);
454+
if (!KnownD.isNonNegative())
455+
return Result;
456+
457+
const KnownBits KnownB = computeKnownBits(B, DemandedElts, Q, Depth + 1);
458+
if (!KnownB.isNonNegative())
459+
return Result;
460+
461+
const KnownBits KnownC = computeKnownBits(C, DemandedElts, Q, Depth + 1);
462+
if (!KnownC.isNonNegative())
463+
return Result;
464+
465+
// Compute max(a, d)
466+
const APInt MaxA = KnownA.getMaxValue();
467+
const APInt MaxD = KnownD.getMaxValue();
468+
const APInt MaxAD = MaxA.ult(MaxD) ? MaxD : MaxA;
469+
470+
// Compute max(a, d) * max(b)
471+
const APInt MaxB = KnownB.getMaxValue();
472+
bool Overflow;
473+
const APInt UpperBound = MaxAD.umul_ov(MaxB, Overflow);
474+
475+
if (Overflow)
476+
return Result;
477+
478+
// Count leading zeros in upper bound
479+
const unsigned MinimumNumberOfLeadingZeros = UpperBound.countl_zero();
480+
481+
// Create KnownBits with only leading zeros set
482+
Result.Zero.setHighBits(MinimumNumberOfLeadingZeros);
483+
484+
return Result;
485+
}
486+
353487
static void computeKnownBitsAddSub(bool Add, const Value *Op0, const Value *Op1,
354488
bool NSW, bool NUW,
355489
const APInt &DemandedElts,
@@ -369,6 +503,15 @@ static void computeKnownBitsAddSub(bool Add, const Value *Op0, const Value *Op1,
369503
isImpliedByDomCondition(ICmpInst::ICMP_SLE, Op1, Op0, Q.CxtI, Q.DL)
370504
.value_or(false))
371505
KnownOut.makeNonNegative();
506+
507+
if (Add) {
508+
// Try to match lerp pattern and combine results
509+
const KnownBits LerpKnown =
510+
computeKnownBitsFromLerpPattern(Op0, Op1, DemandedElts, Q, Depth);
511+
// Union of any two conservative estimates results in a conservative
512+
// estimate that is at least as precise as each individual estimate.
513+
KnownOut = KnownOut.unionWith(LerpKnown);
514+
}
372515
}
373516

374517
static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW,

llvm/test/Transforms/InstCombine/known-bits-lerp-pattern.ll

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,7 @@ define i32 @test_clamp(i8 %a, i8 %c, i8 %d) {
1515
; CHECK-NEXT: [[MUL1:%.*]] = mul nuw nsw i32 [[SUB]], [[A32]]
1616
; CHECK-NEXT: [[MUL2:%.*]] = mul nuw nsw i32 [[C32]], [[D32]]
1717
; CHECK-NEXT: [[ADD:%.*]] = add nuw nsw i32 [[MUL1]], [[MUL2]]
18-
; CHECK-NEXT: [[RESULT:%.*]] = call i32 @llvm.umin.i32(i32 [[ADD]], i32 65535)
19-
; CHECK-NEXT: ret i32 [[RESULT]]
18+
; CHECK-NEXT: ret i32 [[ADD]]
2019
;
2120
%a32 = zext i8 %a to i32
2221
%c32 = zext i8 %c to i32
@@ -40,8 +39,7 @@ define i1 @test_trunc_cmp(i8 %a, i8 %c, i8 %d) {
4039
; CHECK-NEXT: [[MUL1:%.*]] = mul nuw nsw i32 [[SUB]], [[A32]]
4140
; CHECK-NEXT: [[MUL2:%.*]] = mul nuw nsw i32 [[C32]], [[D32]]
4241
; CHECK-NEXT: [[ADD:%.*]] = add nuw nsw i32 [[MUL1]], [[MUL2]]
43-
; CHECK-NEXT: [[TRUNC:%.*]] = trunc i32 [[ADD]] to i16
44-
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i16 [[TRUNC]], 1234
42+
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[ADD]], 1234
4543
; CHECK-NEXT: ret i1 [[CMP]]
4644
;
4745
%a32 = zext i8 %a to i32
@@ -66,8 +64,7 @@ define i1 @test_trunc_cmp_xor(i8 %a, i8 %c, i8 %d) {
6664
; CHECK-NEXT: [[MUL1:%.*]] = mul nuw nsw i32 [[SUB]], [[A32]]
6765
; CHECK-NEXT: [[MUL2:%.*]] = mul nuw nsw i32 [[C32]], [[D32]]
6866
; CHECK-NEXT: [[ADD:%.*]] = add nuw nsw i32 [[MUL1]], [[MUL2]]
69-
; CHECK-NEXT: [[TRUNC:%.*]] = trunc i32 [[ADD]] to i16
70-
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i16 [[TRUNC]], 1234
67+
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[ADD]], 1234
7168
; CHECK-NEXT: ret i1 [[CMP]]
7269
;
7370
%a32 = zext i8 %a to i32
@@ -93,8 +90,7 @@ define i1 @test_trunc_cmp_arbitrary_b(i8 %a, i8 %b, i8 %c, i8 %d) {
9390
; CHECK-NEXT: [[MUL1:%.*]] = mul nuw nsw i32 [[SUB]], [[A32]]
9491
; CHECK-NEXT: [[MUL2:%.*]] = mul nuw nsw i32 [[C32]], [[D32]]
9592
; CHECK-NEXT: [[ADD:%.*]] = add nuw nsw i32 [[MUL1]], [[MUL2]]
96-
; CHECK-NEXT: [[TRUNC:%.*]] = trunc i32 [[ADD]] to i16
97-
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i16 [[TRUNC]], 1234
93+
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[ADD]], 1234
9894
; CHECK-NEXT: ret i1 [[CMP]]
9995
;
10096
%a32 = zext i8 %a to i32
@@ -120,8 +116,7 @@ define i1 @test_trunc_cmp_no_a(i8 %b, i8 %c, i8 %d) {
120116
; CHECK-NEXT: [[MUL1:%.*]] = sub nuw nsw i32 [[B32]], [[C32]]
121117
; CHECK-NEXT: [[MUL2:%.*]] = mul nuw nsw i32 [[C32]], [[D32]]
122118
; CHECK-NEXT: [[ADD:%.*]] = add nuw nsw i32 [[MUL1]], [[MUL2]]
123-
; CHECK-NEXT: [[TRUNC:%.*]] = trunc i32 [[ADD]] to i16
124-
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i16 [[TRUNC]], 1234
119+
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[ADD]], 1234
125120
; CHECK-NEXT: ret i1 [[CMP]]
126121
;
127122
%b32 = zext i8 %b to i32
@@ -144,8 +139,7 @@ define i1 @test_trunc_cmp_no_d(i8 %a, i8 %b, i8 %c) {
144139
; CHECK-NEXT: [[SUB:%.*]] = sub nuw nsw i32 [[B32]], [[C32]]
145140
; CHECK-NEXT: [[MUL1:%.*]] = mul nuw nsw i32 [[SUB]], [[A32]]
146141
; CHECK-NEXT: [[ADD:%.*]] = add nuw nsw i32 [[MUL1]], [[C32]]
147-
; CHECK-NEXT: [[TRUNC:%.*]] = trunc i32 [[ADD]] to i16
148-
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i16 [[TRUNC]], 1234
142+
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[ADD]], 1234
149143
; CHECK-NEXT: ret i1 [[CMP]]
150144
;
151145
%a32 = zext i8 %a to i32

0 commit comments

Comments
 (0)