14 changes: 14 additions & 0 deletions llvm/include/llvm/IR/PatternMatch.h
@@ -872,18 +872,32 @@ inline bind_and_match_ty<const Value, MatchTy> m_Value(const Value *&V,

/// Match an instruction, capturing it if we match.
inline bind_ty<Instruction> m_Instruction(Instruction *&I) { return I; }
inline bind_ty<const Instruction> m_Instruction(const Instruction *&I) {
return I;
}

/// Match against the nested pattern, and capture the instruction if we match.
template <typename MatchTy>
inline bind_and_match_ty<Instruction, MatchTy>
m_Instruction(Instruction *&I, const MatchTy &Match) {
return {I, Match};
}
template <typename MatchTy>
inline bind_and_match_ty<const Instruction, MatchTy>
m_Instruction(const Instruction *&I, const MatchTy &Match) {
return {I, Match};
}
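
As a quick usage sketch of the new const-qualified overloads (a hypothetical helper; matchAddInst is an invented name, not part of this patch):

```cpp
#include "llvm/IR/PatternMatch.h"

using namespace llvm;
using namespace llvm::PatternMatch;

// Hypothetical helper: if V is an 'add' instruction, return it, binding it
// through the new const overload of m_Instruction without a const_cast.
static const Instruction *matchAddInst(const Value *V) {
  const Instruction *I = nullptr;
  if (match(V, m_Instruction(I, m_Add(m_Value(), m_Value()))))
    return I; // I is bound to the matched add instruction.
  return nullptr;
}
```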

/// Match a unary operator, capturing it if we match.
inline bind_ty<UnaryOperator> m_UnOp(UnaryOperator *&I) { return I; }
inline bind_ty<const UnaryOperator> m_UnOp(const UnaryOperator *&I) {
return I;
}
/// Match a binary operator, capturing it if we match.
inline bind_ty<BinaryOperator> m_BinOp(BinaryOperator *&I) { return I; }
inline bind_ty<const BinaryOperator> m_BinOp(const BinaryOperator *&I) {
return I;
}
/// Match a with overflow intrinsic, capturing it if we match.
inline bind_ty<WithOverflowInst> m_WithOverflowInst(WithOverflowInst *&I) {
return I;
137 changes: 137 additions & 0 deletions llvm/lib/Analysis/ValueTracking.cpp
@@ -350,6 +350,139 @@ unsigned llvm::ComputeMaxSignificantBits(const Value *V, const DataLayout &DL,
return V->getType()->getScalarSizeInBits() - SignBits + 1;
}

/// Try to detect the lerp pattern: a * (b - c) + c * d
/// where a >= 0, b >= 0, c >= 0, d >= 0, and b >= c.
///
/// In that particular case, we can use the following chain of reasoning:
///
/// a * (b - c) + c * d <= a' * (b - c) + a' * c = a' * b where a' = max(a, d)
///
/// Since that is true for arbitrary a, b, c and d within our constraints, we
/// can conclude that:
///
/// max(a * (b - c) + c * d) <= max(max(a), max(d)) * max(b) = U
///
/// Since any result of the lerp is less than or equal to U, it has at least
/// as many leading 0s as U does.
///
/// While quite specific, this situation is fairly common in computer
/// graphics in the form of alpha blending.
///
/// Modifies the given KnownOut in-place with the inferred information.
static void computeKnownBitsFromLerpPattern(const Value *Op0, const Value *Op1,
const APInt &DemandedElts,
KnownBits &KnownOut,
const SimplifyQuery &Q,
unsigned Depth) {

Type *Ty = Op0->getType();
const unsigned BitWidth = Ty->getScalarSizeInBits();

// Only handle scalar types for now
if (Ty->isVectorTy())
return;

// Try to match: a * (b - c) + c * d.
// When a == 1, A is left as nullptr; the same applies to d and D.
const Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr;
const Instruction *SubBC = nullptr;

const auto MatchSubBC = [&]() {
// (b - c) can have two forms that interest us:
//
// 1. sub nuw %b, %c
// 2. xor %c, %b
//
// For the first case, nuw flag guarantees our requirement b >= c.
//
// The second case can happen when the analysis has inferred that b is a mask
// for c and the sub operation has been transformed into xor (that is usually
// true for constant b's). Even though xor is commutative, canonicalization
// ensures that the constant will be on the RHS. We have additional checks
// later on to ensure that this xor operation is equivalent to subtraction.
return m_Instruction(SubBC, m_CombineOr(m_NUWSub(m_Value(B), m_Value(C)),
m_Xor(m_Value(C), m_Value(B))));
Review comment (Member):
We may fail to match when both b and c are non-constants and c is a subset of b. But I think it is fair enough to handle the constant RHS case first.

Reply (Member, Author):
That's a really good point. I will follow up with a patch for that, because we will need to change how we match quite a lot: both multiplication and xor are commutative, so we can't tell in advance which operand is b and which is c.

};

const auto MatchASubBC = [&]() {
// Cases:
// - a * (b - c)
// - (b - c) * a
// - (b - c) <- a implicitly equals 1
return m_CombineOr(m_c_Mul(m_Value(A), MatchSubBC()), MatchSubBC());
};

const auto MatchCD = [&]() {
// Cases:
// - d * c
// - c * d
// - c <- d implicitly equals 1
return m_CombineOr(m_c_Mul(m_Value(D), m_Specific(C)), m_Specific(C));
};

const auto Match = [&](const Value *LHS, const Value *RHS) {
// MatchCD uses m_Specific(C), so C has to be bound before it runs:
// match(LHS, MatchASubBC()) absolutely has to evaluate first and succeed.
//
// If Match returns true, it is guaranteed that B != nullptr and C != nullptr.
return match(LHS, MatchASubBC()) && match(RHS, MatchCD());
};

if (!Match(Op0, Op1) && !Match(Op1, Op0))
return;

const auto ComputeKnownBitsOrOne = [&](const Value *V) {
// For some of the values we use the convention of leaving
// it nullptr to signify an implicit constant 1.
return V ? computeKnownBits(V, DemandedElts, Q, Depth + 1)
: KnownBits::makeConstant(APInt(BitWidth, 1));
};

// Check that all operands are non-negative
const KnownBits KnownA = ComputeKnownBitsOrOne(A);
if (!KnownA.isNonNegative())
return;

const KnownBits KnownD = ComputeKnownBitsOrOne(D);
if (!KnownD.isNonNegative())
return;

const KnownBits KnownB = computeKnownBits(B, DemandedElts, Q, Depth + 1);
if (!KnownB.isNonNegative())
return;

const KnownBits KnownC = computeKnownBits(C, DemandedElts, Q, Depth + 1);
if (!KnownC.isNonNegative())
return;

// If we matched subtraction as xor, we need to actually check that xor
// is semantically equivalent to subtraction.
//
// For that to be true, b has to be a mask for c, i.e. b's known
// ones have to cover all possible ones of c.
if (SubBC->getOpcode() == Instruction::Xor &&
!KnownC.getMaxValue().isSubsetOf(KnownB.getMinValue()))
return;
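
Why the subset check above is sufficient: if every bit that can possibly be set in c is known to be set in b, the subtraction b - c borrows nothing and therefore agrees with b ^ c bit for bit. A minimal standalone illustration (plain C++, an editorial aside, not patch code):

```cpp
#include <cassert>
#include <cstdint>

int main() {
  const uint32_t B = 0xFF; // b with all low 8 bits known to be one
  for (uint32_t C = 0; C <= 0xFF; ++C) // every such C is a bit-subset of B
    assert(B - C == (B ^ C)); // no borrow occurs, so sub and xor coincide
  return 0;
}
```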

const APInt MaxA = KnownA.getMaxValue();
const APInt MaxD = KnownD.getMaxValue();
const APInt MaxAD = APIntOps::umax(MaxA, MaxD);
const APInt MaxB = KnownB.getMaxValue();

// We can't infer leading zeros info if the upper-bound estimate wraps.
bool Overflow;
const APInt UpperBound = MaxAD.umul_ov(MaxB, Overflow);

if (Overflow)
return;

// If we know that x <= y and both are non-negative, then x has at least as
// many leading zeros as y.
const unsigned MinimumNumberOfLeadingZeros = UpperBound.countl_zero();
KnownOut.Zero.setHighBits(MinimumNumberOfLeadingZeros);
}
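
As a concrete instance of the bound from the function comment (illustrative numbers, using the 8-bit alpha-blending shape where b == 255): U = max(255, 255) * 255 = 65025 < 2^16, so any i32 result has at least 16 leading zeros. A brute-force sketch in plain C++:

```cpp
#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t a = 0; a <= 255; ++a)
    for (uint32_t c = 0; c <= 255; ++c)
      for (uint32_t d = 0; d <= 255; ++d) {
        uint32_t lerp = a * (255 - c) + c * d; // b == 255 >= c
        assert(lerp <= 255u * 255u); // U = max(max(a), max(d)) * max(b)
      }
  return 0;
}
```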

static void computeKnownBitsAddSub(bool Add, const Value *Op0, const Value *Op1,
bool NSW, bool NUW,
const APInt &DemandedElts,
@@ -369,6 +502,10 @@ static void computeKnownBitsAddSub(bool Add, const Value *Op0, const Value *Op1,
isImpliedByDomCondition(ICmpInst::ICMP_SLE, Op1, Op0, Q.CxtI, Q.DL)
.value_or(false))
KnownOut.makeNonNegative();

if (Add)
// Try to match the lerp pattern and combine the results.
computeKnownBitsFromLerpPattern(Op0, Op1, DemandedElts, KnownOut, Q, Depth);
}
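
To see the refinement end to end, here is a minimal sketch (an editorial example under assumed setup, not a test from this patch) that builds the lerp shape with IRBuilder and queries computeKnownBits; with this patch applied, at least 16 leading zeros should be reported:

```cpp
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/KnownBits.h"

using namespace llvm;

int main() {
  LLVMContext Ctx;
  Module M("m", Ctx);
  Type *I8 = Type::getInt8Ty(Ctx), *I32 = Type::getInt32Ty(Ctx);
  auto *FT = FunctionType::get(I32, {I8, I8, I8}, /*isVarArg=*/false);
  Function *F = Function::Create(FT, Function::ExternalLinkage, "lerp", M);
  IRBuilder<> B(BasicBlock::Create(Ctx, "entry", F));
  Value *A = B.CreateZExt(F->getArg(0), I32);
  Value *C = B.CreateZExt(F->getArg(1), I32);
  Value *D = B.CreateZExt(F->getArg(2), I32);
  // sub nuw i32 255, %c : the nuw flag encodes b >= c for the matcher.
  Value *Sub = B.CreateSub(ConstantInt::get(I32, 255), C, "sub",
                           /*HasNUW=*/true);
  Value *Add = B.CreateAdd(B.CreateMul(A, Sub), B.CreateMul(C, D));
  B.CreateRet(Add);
  KnownBits Known = computeKnownBits(Add, M.getDataLayout());
  return Known.countMinLeadingZeros() >= 16 ? 0 : 1; // expect 0
}
```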

static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW,
181 changes: 181 additions & 0 deletions llvm/test/Transforms/InstCombine/known-bits-lerp-pattern.ll
@@ -0,0 +1,181 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
; RUN: opt < %s -passes=instcombine -S | FileCheck %s

; Test known-bits refinements for the pattern: a * (b - c) + c * d
; where a >= 0, b >= 0, c >= 0, d >= 0, and b >= c.
; This pattern is a generalization of lerp, and it appears frequently in graphics operations.
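; With i8 operands zero-extended to i32, the lerp result is at most
; 255 * 255 = 65025 <= 65535, so the clamp in @test_clamp below is provably
; dead and instcombine removes the icmp/select pair.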

define i32 @test_clamp(i8 %a, i8 %c, i8 %d) {
; CHECK-LABEL: define i32 @test_clamp(
; CHECK-SAME: i8 [[A:%.*]], i8 [[C:%.*]], i8 [[D:%.*]]) {
; CHECK-NEXT: [[A32:%.*]] = zext i8 [[A]] to i32
; CHECK-NEXT: [[C32:%.*]] = zext i8 [[C]] to i32
; CHECK-NEXT: [[D32:%.*]] = zext i8 [[D]] to i32
; CHECK-NEXT: [[SUB:%.*]] = xor i32 [[C32]], 255
; CHECK-NEXT: [[MUL1:%.*]] = mul nuw nsw i32 [[SUB]], [[A32]]
; CHECK-NEXT: [[MUL2:%.*]] = mul nuw nsw i32 [[C32]], [[D32]]
; CHECK-NEXT: [[ADD:%.*]] = add nuw nsw i32 [[MUL1]], [[MUL2]]
; CHECK-NEXT: ret i32 [[ADD]]
;
%a32 = zext i8 %a to i32
%c32 = zext i8 %c to i32
%d32 = zext i8 %d to i32
%sub = sub i32 255, %c32
%mul1 = mul i32 %a32, %sub
%mul2 = mul i32 %c32, %d32
%add = add i32 %mul1, %mul2
%cmp = icmp ugt i32 %add, 65535
%result = select i1 %cmp, i32 65535, i32 %add
ret i32 %result
}

define i1 @test_trunc_cmp(i8 %a, i8 %c, i8 %d) {
; CHECK-LABEL: define i1 @test_trunc_cmp(
; CHECK-SAME: i8 [[A:%.*]], i8 [[C:%.*]], i8 [[D:%.*]]) {
; CHECK-NEXT: [[A32:%.*]] = zext i8 [[A]] to i32
; CHECK-NEXT: [[C32:%.*]] = zext i8 [[C]] to i32
; CHECK-NEXT: [[D32:%.*]] = zext i8 [[D]] to i32
; CHECK-NEXT: [[SUB:%.*]] = xor i32 [[C32]], 255
; CHECK-NEXT: [[MUL1:%.*]] = mul nuw nsw i32 [[SUB]], [[A32]]
; CHECK-NEXT: [[MUL2:%.*]] = mul nuw nsw i32 [[C32]], [[D32]]
; CHECK-NEXT: [[ADD:%.*]] = add nuw nsw i32 [[MUL1]], [[MUL2]]
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[ADD]], 1234
; CHECK-NEXT: ret i1 [[CMP]]
;
%a32 = zext i8 %a to i32
%c32 = zext i8 %c to i32
%d32 = zext i8 %d to i32
%sub = sub i32 255, %c32
%mul1 = mul i32 %a32, %sub
%mul2 = mul i32 %c32, %d32
%add = add i32 %mul1, %mul2
%trunc = trunc i32 %add to i16
%cmp = icmp eq i16 %trunc, 1234
ret i1 %cmp
}

define i1 @test_trunc_cmp_xor(i8 %a, i8 %c, i8 %d) {
; CHECK-LABEL: define i1 @test_trunc_cmp_xor(
; CHECK-SAME: i8 [[A:%.*]], i8 [[C:%.*]], i8 [[D:%.*]]) {
; CHECK-NEXT: [[A32:%.*]] = zext i8 [[A]] to i32
; CHECK-NEXT: [[C32:%.*]] = zext i8 [[C]] to i32
; CHECK-NEXT: [[D32:%.*]] = zext i8 [[D]] to i32
; CHECK-NEXT: [[SUB:%.*]] = xor i32 [[C32]], 255
; CHECK-NEXT: [[MUL1:%.*]] = mul nuw nsw i32 [[SUB]], [[A32]]
; CHECK-NEXT: [[MUL2:%.*]] = mul nuw nsw i32 [[C32]], [[D32]]
; CHECK-NEXT: [[ADD:%.*]] = add nuw nsw i32 [[MUL1]], [[MUL2]]
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[ADD]], 1234
; CHECK-NEXT: ret i1 [[CMP]]
;
%a32 = zext i8 %a to i32
%c32 = zext i8 %c to i32
%d32 = zext i8 %d to i32
%sub = xor i32 255, %c32
%mul1 = mul i32 %a32, %sub
%mul2 = mul i32 %c32, %d32
%add = add i32 %mul1, %mul2
%trunc = trunc i32 %add to i16
%cmp = icmp eq i16 %trunc, 1234
ret i1 %cmp
}

define i1 @test_trunc_cmp_arbitrary_b(i8 %a, i8 %b, i8 %c, i8 %d) {
; CHECK-LABEL: define i1 @test_trunc_cmp_arbitrary_b(
; CHECK-SAME: i8 [[A:%.*]], i8 [[B:%.*]], i8 [[C:%.*]], i8 [[D:%.*]]) {
; CHECK-NEXT: [[A32:%.*]] = zext i8 [[A]] to i32
; CHECK-NEXT: [[B32:%.*]] = zext i8 [[B]] to i32
; CHECK-NEXT: [[C32:%.*]] = zext i8 [[C]] to i32
; CHECK-NEXT: [[D32:%.*]] = zext i8 [[D]] to i32
; CHECK-NEXT: [[SUB:%.*]] = sub nuw nsw i32 [[B32]], [[C32]]
; CHECK-NEXT: [[MUL1:%.*]] = mul nuw nsw i32 [[SUB]], [[A32]]
; CHECK-NEXT: [[MUL2:%.*]] = mul nuw nsw i32 [[C32]], [[D32]]
; CHECK-NEXT: [[ADD:%.*]] = add nuw nsw i32 [[MUL1]], [[MUL2]]
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[ADD]], 1234
; CHECK-NEXT: ret i1 [[CMP]]
;
%a32 = zext i8 %a to i32
%b32 = zext i8 %b to i32
%c32 = zext i8 %c to i32
%d32 = zext i8 %d to i32
%sub = sub nsw nuw i32 %b32, %c32
%mul1 = mul i32 %a32, %sub
%mul2 = mul i32 %c32, %d32
%add = add i32 %mul1, %mul2
%trunc = trunc i32 %add to i16
%cmp = icmp eq i16 %trunc, 1234
ret i1 %cmp
}


define i1 @test_trunc_cmp_no_a(i8 %b, i8 %c, i8 %d) {
; CHECK-LABEL: define i1 @test_trunc_cmp_no_a(
; CHECK-SAME: i8 [[B:%.*]], i8 [[C:%.*]], i8 [[D:%.*]]) {
; CHECK-NEXT: [[B32:%.*]] = zext i8 [[B]] to i32
; CHECK-NEXT: [[C32:%.*]] = zext i8 [[C]] to i32
; CHECK-NEXT: [[D32:%.*]] = zext i8 [[D]] to i32
; CHECK-NEXT: [[MUL1:%.*]] = sub nuw nsw i32 [[B32]], [[C32]]
; CHECK-NEXT: [[MUL2:%.*]] = mul nuw nsw i32 [[C32]], [[D32]]
; CHECK-NEXT: [[ADD:%.*]] = add nuw nsw i32 [[MUL1]], [[MUL2]]
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[ADD]], 1234
; CHECK-NEXT: ret i1 [[CMP]]
;
%b32 = zext i8 %b to i32
%c32 = zext i8 %c to i32
%d32 = zext i8 %d to i32
%sub = sub nuw i32 %b32, %c32
%mul2 = mul i32 %c32, %d32
%add = add i32 %sub, %mul2
%trunc = trunc i32 %add to i16
%cmp = icmp eq i16 %trunc, 1234
ret i1 %cmp
}

define i1 @test_trunc_cmp_no_d(i8 %a, i8 %b, i8 %c) {
; CHECK-LABEL: define i1 @test_trunc_cmp_no_d(
; CHECK-SAME: i8 [[A:%.*]], i8 [[B:%.*]], i8 [[C:%.*]]) {
; CHECK-NEXT: [[A32:%.*]] = zext i8 [[A]] to i32
; CHECK-NEXT: [[B32:%.*]] = zext i8 [[B]] to i32
; CHECK-NEXT: [[C32:%.*]] = zext i8 [[C]] to i32
; CHECK-NEXT: [[SUB:%.*]] = sub nuw nsw i32 [[B32]], [[C32]]
; CHECK-NEXT: [[MUL1:%.*]] = mul nuw nsw i32 [[SUB]], [[A32]]
; CHECK-NEXT: [[ADD:%.*]] = add nuw nsw i32 [[MUL1]], [[C32]]
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[ADD]], 1234
; CHECK-NEXT: ret i1 [[CMP]]
;
%a32 = zext i8 %a to i32
%b32 = zext i8 %b to i32
%c32 = zext i8 %c to i32
%sub = sub nsw nuw i32 %b32, %c32
%mul1 = mul i32 %a32, %sub
%add = add i32 %mul1, %c32
%trunc = trunc i32 %add to i16
%cmp = icmp eq i16 %trunc, 1234
ret i1 %cmp
}

define i1 @test_trunc_cmp_xor_negative(i8 %a, i8 %c, i8 %d) {
; CHECK-LABEL: define i1 @test_trunc_cmp_xor_negative(
; CHECK-SAME: i8 [[A:%.*]], i8 [[C:%.*]], i8 [[D:%.*]]) {
; CHECK-NEXT: [[A32:%.*]] = zext i8 [[A]] to i32
; CHECK-NEXT: [[C32:%.*]] = zext i8 [[C]] to i32
; CHECK-NEXT: [[D32:%.*]] = zext i8 [[D]] to i32
; CHECK-NEXT: [[SUB:%.*]] = xor i32 [[C32]], 234
; CHECK-NEXT: [[MUL1:%.*]] = mul nuw nsw i32 [[SUB]], [[A32]]
; CHECK-NEXT: [[MUL2:%.*]] = mul nuw nsw i32 [[C32]], [[D32]]
; CHECK-NEXT: [[ADD:%.*]] = add nuw nsw i32 [[MUL1]], [[MUL2]]
; CHECK-NEXT: [[TRUNC:%.*]] = trunc i32 [[ADD]] to i16
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i16 [[TRUNC]], 1234
; CHECK-NEXT: ret i1 [[CMP]]
;
%a32 = zext i8 %a to i32
%c32 = zext i8 %c to i32
%d32 = zext i8 %d to i32
%sub = xor i32 234, %c32
%mul1 = mul i32 %a32, %sub
%mul2 = mul i32 %c32, %d32
%add = add i32 %mul1, %mul2
; We should keep the trunc in this case: 234 does not cover all possible
; ones of %c32, so the xor is not equivalent to a subtraction and no
; leading-zero bound is inferred.
%trunc = trunc i32 %add to i16
%cmp = icmp eq i16 %trunc, 1234
ret i1 %cmp
}