Skip to content

Commit 600d360

Browse files
[KnownBits] Refine known bits for lerp
1 parent aedd1e7 commit 600d360

File tree

3 files changed

+184
-13
lines changed

3 files changed

+184
-13
lines changed

llvm/include/llvm/IR/PatternMatch.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -882,8 +882,10 @@ m_Instruction(Instruction *&I, const MatchTy &Match) {
882882

883883
/// Match a unary operator, capturing it if we match.
884884
/// The returned matcher accepts any UnaryOperator and stores it in \p I.
inline bind_ty<UnaryOperator> m_UnOp(UnaryOperator *&I) { return I; }
885+
/// Const overload: match a unary operator, capturing it through a
/// pointer-to-const if we match.
inline bind_ty<const UnaryOperator> m_UnOp(const UnaryOperator *&I) { return I; }
885886
/// Match a binary operator, capturing it if we match.
886887
/// The returned matcher accepts any BinaryOperator and stores it in \p I.
inline bind_ty<BinaryOperator> m_BinOp(BinaryOperator *&I) { return I; }
888+
/// Const overload: match a binary operator, capturing it through a
/// pointer-to-const if we match.
inline bind_ty<const BinaryOperator> m_BinOp(const BinaryOperator *&I) { return I; }
887889
/// Match a with overflow intrinsic, capturing it if we match.
888890
inline bind_ty<WithOverflowInst> m_WithOverflowInst(WithOverflowInst *&I) {
889891
return I;

llvm/lib/Analysis/ValueTracking.cpp

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,152 @@ unsigned llvm::ComputeMaxSignificantBits(const Value *V, const DataLayout &DL,
350350
return V->getType()->getScalarSizeInBits() - SignBits + 1;
351351
}
352352

353+
/// Try to detect the lerp pattern: a * (b - c) + c * d
354+
/// where a >= 0, b >= 0, c >= 0, d >= 0, and b >= c.
355+
///
356+
/// In that particular case, we can use the following chain of reasoning:
357+
///
358+
/// a * (b - c) + c * d <= a' * (b - c) + a' * c = a' * b where a' = max(a, d)
359+
///
360+
/// Since that is true for arbitrary a, b, c and d within our constraints, we
361+
/// can conclude that:
362+
///
363+
/// max(a * (b - c) + c * d) <= max(max(a), max(d)) * max(b) = U
364+
///
365+
/// Considering that any result of the lerp would be less or equal to U, it
366+
/// would have at least the number of leading 0s as in U.
367+
///
368+
/// While being quite a specific situation, it is fairly common in computer
369+
/// graphics in the shape of alpha blending.
370+
///
371+
/// Modifies given KnownOut in-place with the inferred information.
372+
static void computeKnownBitsFromLerpPattern(const Value *Op0, const Value *Op1,
373+
const APInt &DemandedElts,
374+
KnownBits &KnownOut,
375+
const SimplifyQuery &Q,
376+
unsigned Depth) {
377+
378+
Type *Ty = Op0->getType();
379+
const unsigned BitWidth = Ty->getScalarSizeInBits();
380+
381+
// Only handle scalar types for now
382+
if (Ty->isVectorTy())
383+
return;
384+
385+
// Try to match: a * (b - c) + c * d.
386+
// When a == 1 => A == nullptr, the same applies to d/D as well.
387+
const Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr;
388+
const BinaryOperator *SubBC = nullptr;
389+
390+
const auto MatchSubBC = [&]() {
391+
// (b - c) can have two forms that interest us:
392+
//
393+
// 1. sub nuw %b, %c
394+
// 2. xor %c, %b
395+
//
396+
// For the first case, nuw flag guarantees our requirement b >= c.
397+
//
398+
// The second case might happen when the analysis can infer that b is a mask
399+
// for c and we can transform sub operation into xor (that is usually true
400+
// for constant b's). Even though xor is symmetrical, canonicalization
401+
// ensures that the constant will be the RHS. We have additional checks
402+
// later on to ensure that this xor operation is equivalent to subtraction.
403+
return m_CombineAnd(m_BinOp(SubBC),
404+
m_CombineOr(m_NUWSub(m_Value(B), m_Value(C)),
405+
m_Xor(m_Value(C), m_Value(B))));
406+
};
407+
408+
const auto MatchASubBC = [&]() {
409+
// Cases:
410+
// - a * (b - c)
411+
// - (b - c) * a
412+
// - (b - c) <- a implicitly equals 1
413+
return m_CombineOr(m_c_Mul(m_Value(A), MatchSubBC()), MatchSubBC());
414+
};
415+
416+
const auto MatchCD = [&]() {
417+
// Cases:
418+
// - d * c
419+
// - c * d
420+
// - c <- d implicitly equals 1
421+
return m_CombineOr(m_c_Mul(m_Value(D), m_Specific(C)), m_Specific(C));
422+
};
423+
424+
const auto Match = [&](const Value *LHS, const Value *RHS) {
425+
// We do use m_Specific(C) in MatchCD, so we have to make sure that
426+
// it's bound to anything and match(LHS, MatchASubBC()) absolutely
427+
// has to evaluate first and return true.
428+
//
429+
// If Match returns true, it is guaranteed that B != nullptr, C != nullptr.
430+
return match(LHS, MatchASubBC()) && match(RHS, MatchCD());
431+
};
432+
433+
if (!Match(Op0, Op1) && !Match(Op1, Op0))
434+
return;
435+
436+
const auto ComputeKnownBitsOrOne = [&](const Value *V) {
437+
// For some of the values we use the convention of leaving
438+
// it nullptr to signify an implicit constant 1.
439+
return V ? computeKnownBits(V, DemandedElts, Q, Depth + 1)
440+
: KnownBits::makeConstant(APInt(BitWidth, 1));
441+
};
442+
443+
// Check that all operands are non-negative
444+
const KnownBits KnownA = ComputeKnownBitsOrOne(A);
445+
if (!KnownA.isNonNegative())
446+
return;
447+
448+
const KnownBits KnownD = ComputeKnownBitsOrOne(D);
449+
if (!KnownD.isNonNegative())
450+
return;
451+
452+
const KnownBits KnownB = computeKnownBits(B, DemandedElts, Q, Depth + 1);
453+
if (!KnownB.isNonNegative())
454+
return;
455+
456+
const KnownBits KnownC = computeKnownBits(C, DemandedElts, Q, Depth + 1);
457+
if (!KnownC.isNonNegative())
458+
return;
459+
460+
if (SubBC->getOpcode() == Instruction::Xor) {
461+
// If we matched subtraction as xor, we need to actually check that xor
462+
// is semantically equivalent to subtraction.
463+
//
464+
// For that to be true, b has to be a mask for c.
465+
// In known bits terms it would mean the following:
466+
//
467+
// - b is a constant
468+
if (!KnownB.isConstant())
469+
return;
470+
471+
// - b has ones at least in every position where c might have ones.
472+
const APInt MaxC = KnownC.getMaxValue();
473+
if ((KnownB.getConstant() & MaxC) != MaxC)
474+
return;
475+
}
476+
477+
// Compute max(a, d)
478+
const APInt MaxA = KnownA.getMaxValue();
479+
const APInt MaxD = KnownD.getMaxValue();
480+
const APInt MaxAD = APIntOps::umax(MaxA, MaxD);
481+
482+
// Compute max(a, d) * max(b)
483+
const APInt MaxB = KnownB.getMaxValue();
484+
bool Overflow;
485+
const APInt UpperBound = MaxAD.umul_ov(MaxB, Overflow);
486+
487+
if (Overflow)
488+
return;
489+
490+
// Count leading zeros in upper bound
491+
const unsigned MinimumNumberOfLeadingZeros = UpperBound.countl_zero();
492+
493+
// Create KnownBits with only leading zeros set
494+
KnownOut.Zero.setHighBits(MinimumNumberOfLeadingZeros);
495+
496+
return;
497+
}
498+
353499
static void computeKnownBitsAddSub(bool Add, const Value *Op0, const Value *Op1,
354500
bool NSW, bool NUW,
355501
const APInt &DemandedElts,
@@ -369,6 +515,10 @@ static void computeKnownBitsAddSub(bool Add, const Value *Op0, const Value *Op1,
369515
isImpliedByDomCondition(ICmpInst::ICMP_SLE, Op1, Op0, Q.CxtI, Q.DL)
370516
.value_or(false))
371517
KnownOut.makeNonNegative();
518+
519+
if (Add)
520+
// Try to match lerp pattern and combine results
521+
computeKnownBitsFromLerpPattern(Op0, Op1, DemandedElts, KnownOut, Q, Depth);
372522
}
373523

374524
static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW,

llvm/test/Transforms/InstCombine/known-bits-lerp-pattern.ll

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,7 @@ define i32 @test_clamp(i8 %a, i8 %c, i8 %d) {
1515
; CHECK-NEXT: [[MUL1:%.*]] = mul nuw nsw i32 [[SUB]], [[A32]]
1616
; CHECK-NEXT: [[MUL2:%.*]] = mul nuw nsw i32 [[C32]], [[D32]]
1717
; CHECK-NEXT: [[ADD:%.*]] = add nuw nsw i32 [[MUL1]], [[MUL2]]
18-
; CHECK-NEXT: [[RESULT:%.*]] = call i32 @llvm.umin.i32(i32 [[ADD]], i32 65535)
19-
; CHECK-NEXT: ret i32 [[RESULT]]
18+
; CHECK-NEXT: ret i32 [[ADD]]
2019
;
2120
%a32 = zext i8 %a to i32
2221
%c32 = zext i8 %c to i32
@@ -40,8 +39,7 @@ define i1 @test_trunc_cmp(i8 %a, i8 %c, i8 %d) {
4039
; CHECK-NEXT: [[MUL1:%.*]] = mul nuw nsw i32 [[SUB]], [[A32]]
4140
; CHECK-NEXT: [[MUL2:%.*]] = mul nuw nsw i32 [[C32]], [[D32]]
4241
; CHECK-NEXT: [[ADD:%.*]] = add nuw nsw i32 [[MUL1]], [[MUL2]]
43-
; CHECK-NEXT: [[TRUNC:%.*]] = trunc i32 [[ADD]] to i16
44-
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i16 [[TRUNC]], 1234
42+
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[ADD]], 1234
4543
; CHECK-NEXT: ret i1 [[CMP]]
4644
;
4745
%a32 = zext i8 %a to i32
@@ -66,8 +64,7 @@ define i1 @test_trunc_cmp_xor(i8 %a, i8 %c, i8 %d) {
6664
; CHECK-NEXT: [[MUL1:%.*]] = mul nuw nsw i32 [[SUB]], [[A32]]
6765
; CHECK-NEXT: [[MUL2:%.*]] = mul nuw nsw i32 [[C32]], [[D32]]
6866
; CHECK-NEXT: [[ADD:%.*]] = add nuw nsw i32 [[MUL1]], [[MUL2]]
69-
; CHECK-NEXT: [[TRUNC:%.*]] = trunc i32 [[ADD]] to i16
70-
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i16 [[TRUNC]], 1234
67+
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[ADD]], 1234
7168
; CHECK-NEXT: ret i1 [[CMP]]
7269
;
7370
%a32 = zext i8 %a to i32
@@ -93,8 +90,7 @@ define i1 @test_trunc_cmp_arbitrary_b(i8 %a, i8 %b, i8 %c, i8 %d) {
9390
; CHECK-NEXT: [[MUL1:%.*]] = mul nuw nsw i32 [[SUB]], [[A32]]
9491
; CHECK-NEXT: [[MUL2:%.*]] = mul nuw nsw i32 [[C32]], [[D32]]
9592
; CHECK-NEXT: [[ADD:%.*]] = add nuw nsw i32 [[MUL1]], [[MUL2]]
96-
; CHECK-NEXT: [[TRUNC:%.*]] = trunc i32 [[ADD]] to i16
97-
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i16 [[TRUNC]], 1234
93+
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[ADD]], 1234
9894
; CHECK-NEXT: ret i1 [[CMP]]
9995
;
10096
%a32 = zext i8 %a to i32
@@ -120,8 +116,7 @@ define i1 @test_trunc_cmp_no_a(i8 %b, i8 %c, i8 %d) {
120116
; CHECK-NEXT: [[MUL1:%.*]] = sub nuw nsw i32 [[B32]], [[C32]]
121117
; CHECK-NEXT: [[MUL2:%.*]] = mul nuw nsw i32 [[C32]], [[D32]]
122118
; CHECK-NEXT: [[ADD:%.*]] = add nuw nsw i32 [[MUL1]], [[MUL2]]
123-
; CHECK-NEXT: [[TRUNC:%.*]] = trunc i32 [[ADD]] to i16
124-
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i16 [[TRUNC]], 1234
119+
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[ADD]], 1234
125120
; CHECK-NEXT: ret i1 [[CMP]]
126121
;
127122
%b32 = zext i8 %b to i32
@@ -144,8 +139,7 @@ define i1 @test_trunc_cmp_no_d(i8 %a, i8 %b, i8 %c) {
144139
; CHECK-NEXT: [[SUB:%.*]] = sub nuw nsw i32 [[B32]], [[C32]]
145140
; CHECK-NEXT: [[MUL1:%.*]] = mul nuw nsw i32 [[SUB]], [[A32]]
146141
; CHECK-NEXT: [[ADD:%.*]] = add nuw nsw i32 [[MUL1]], [[C32]]
147-
; CHECK-NEXT: [[TRUNC:%.*]] = trunc i32 [[ADD]] to i16
148-
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i16 [[TRUNC]], 1234
142+
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[ADD]], 1234
149143
; CHECK-NEXT: ret i1 [[CMP]]
150144
;
151145
%a32 = zext i8 %a to i32
@@ -159,4 +153,29 @@ define i1 @test_trunc_cmp_no_d(i8 %a, i8 %b, i8 %c) {
159153
ret i1 %cmp
160154
}
161155

162-
declare void @llvm.assume(i1)
156+
define i1 @test_trunc_cmp_xor_negative(i8 %a, i8 %c, i8 %d) {
157+
; CHECK-LABEL: define i1 @test_trunc_cmp_xor_negative(
158+
; CHECK-SAME: i8 [[A:%.*]], i8 [[C:%.*]], i8 [[D:%.*]]) {
159+
; CHECK-NEXT: [[A32:%.*]] = zext i8 [[A]] to i32
160+
; CHECK-NEXT: [[C32:%.*]] = zext i8 [[C]] to i32
161+
; CHECK-NEXT: [[D32:%.*]] = zext i8 [[D]] to i32
162+
; CHECK-NEXT: [[SUB:%.*]] = xor i32 [[C32]], 234
163+
; CHECK-NEXT: [[MUL1:%.*]] = mul nuw nsw i32 [[SUB]], [[A32]]
164+
; CHECK-NEXT: [[MUL2:%.*]] = mul nuw nsw i32 [[C32]], [[D32]]
165+
; CHECK-NEXT: [[ADD:%.*]] = add nuw nsw i32 [[MUL1]], [[MUL2]]
166+
; CHECK-NEXT: [[TRUNC:%.*]] = trunc i32 [[ADD]] to i16
167+
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i16 [[TRUNC]], 1234
168+
; CHECK-NEXT: ret i1 [[CMP]]
169+
;
170+
%a32 = zext i8 %a to i32
171+
%c32 = zext i8 %c to i32
172+
%d32 = zext i8 %d to i32
173+
%sub = xor i32 234, %c32
174+
%mul1 = mul i32 %a32, %sub
175+
%mul2 = mul i32 %c32, %d32
176+
%add = add i32 %mul1, %mul2
177+
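; 234 (0b11101010) does not cover every bit that %c32 (a zext'd i8, up to 255)
; may have set, so this xor is not equivalent to a subtraction and the lerp
; pattern must not apply. We should keep the trunc in this case.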
178+
%trunc = trunc i32 %add to i16
179+
%cmp = icmp eq i16 %trunc, 1234
180+
ret i1 %cmp
181+
}

0 commit comments

Comments
 (0)