diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index bb3d7fff5c9bc..88aef4a368f29 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -872,6 +872,9 @@ inline bind_and_match_ty<const Value, MatchTy> m_Value(const Value *&V,
 
 /// Match an instruction, capturing it if we match.
 inline bind_ty<Instruction> m_Instruction(Instruction *&I) { return I; }
+inline bind_ty<const Instruction> m_Instruction(const Instruction *&I) {
+  return I;
+}
 
 /// Match against the nested pattern, and capture the instruction if we match.
 template <typename MatchTy>
@@ -879,11 +882,22 @@ inline bind_and_match_ty<Instruction, MatchTy>
 m_Instruction(Instruction *&I, const MatchTy &Match) {
   return {I, Match};
 }
+template <typename MatchTy>
+inline bind_and_match_ty<const Instruction, MatchTy>
+m_Instruction(const Instruction *&I, const MatchTy &Match) {
+  return {I, Match};
+}
 
 /// Match a unary operator, capturing it if we match.
 inline bind_ty<UnaryOperator> m_UnOp(UnaryOperator *&I) { return I; }
+inline bind_ty<const UnaryOperator> m_UnOp(const UnaryOperator *&I) {
+  return I;
+}
 /// Match a binary operator, capturing it if we match.
 inline bind_ty<BinaryOperator> m_BinOp(BinaryOperator *&I) { return I; }
+inline bind_ty<const BinaryOperator> m_BinOp(const BinaryOperator *&I) {
+  return I;
+}
 /// Match a with overflow intrinsic, capturing it if we match.
 inline bind_ty<WithOverflowInst> m_WithOverflowInst(WithOverflowInst *&I) {
   return I;
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 789a98366cead..41ff816a33262 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -350,6 +350,139 @@ unsigned llvm::ComputeMaxSignificantBits(const Value *V, const DataLayout &DL,
   return V->getType()->getScalarSizeInBits() - SignBits + 1;
 }
 
+/// Try to detect the lerp pattern: a * (b - c) + c * d
+/// where a >= 0, b >= 0, c >= 0, d >= 0, and b >= c.
+///
+/// In that particular case, we can use the following chain of reasoning:
+///
+/// a * (b - c) + c * d <= a' * (b - c) + a' * c = a' * b where a' = max(a, d)
+///
+/// Since that is true for arbitrary a, b, c, and d within our constraints, we
+/// can conclude that:
+///
+/// max(a * (b - c) + c * d) <= max(max(a), max(d)) * max(b) = U
+///
+/// Since any result of the lerp is less than or equal to U, it has at least
+/// as many leading zeros as U.
+///
+/// While this is quite a specific situation, it is fairly common in computer
+/// graphics in the shape of alpha blending.
+///
+/// Modifies the given KnownOut in-place with the inferred information.
+static void computeKnownBitsFromLerpPattern(const Value *Op0, const Value *Op1,
+                                            const APInt &DemandedElts,
+                                            KnownBits &KnownOut,
+                                            const SimplifyQuery &Q,
+                                            unsigned Depth) {
+
+  Type *Ty = Op0->getType();
+  const unsigned BitWidth = Ty->getScalarSizeInBits();
+
+  // Only handle scalar types for now.
+  if (Ty->isVectorTy())
+    return;
+
+  // Try to match: a * (b - c) + c * d.
+  // When a == 1 => A == nullptr; the same applies to d/D as well.
+  const Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr;
+  const Instruction *SubBC = nullptr;
+
+  const auto MatchSubBC = [&]() {
+    // (b - c) can have two forms that interest us:
+    //
+    // 1. sub nuw %b, %c
+    // 2. xor %c, %b
+    //
+    // For the first case, the nuw flag guarantees our requirement b >= c.
+    //
+    // The second case can happen when the analysis infers that b is a mask
+    // for c and transforms the sub operation into an xor (that is usually true
+    // for constant b's). Even though xor is commutative, canonicalization
+    // ensures that the constant ends up on the RHS. We have additional checks
+    // later on to ensure that this xor operation is equivalent to subtraction.
+    return m_Instruction(SubBC, m_CombineOr(m_NUWSub(m_Value(B), m_Value(C)),
+                                            m_Xor(m_Value(C), m_Value(B))));
+  };
+
+  const auto MatchASubBC = [&]() {
+    // Cases:
+    // - a * (b - c)
+    // - (b - c) * a
+    // - (b - c) <- a implicitly equals 1
+    return m_CombineOr(m_c_Mul(m_Value(A), MatchSubBC()), MatchSubBC());
+  };
+
+  const auto MatchCD = [&]() {
+    // Cases:
+    // - d * c
+    // - c * d
+    // - c <- d implicitly equals 1
+    return m_CombineOr(m_c_Mul(m_Value(D), m_Specific(C)), m_Specific(C));
+  };
+
+  const auto Match = [&](const Value *LHS, const Value *RHS) {
+    // We use m_Specific(C) in MatchCD, so we have to make sure that C is
+    // bound to something, which means match(LHS, MatchASubBC()) absolutely
+    // has to be evaluated first and return true.
+    //
+    // If Match returns true, both B and C are guaranteed to be non-null.
+    return match(LHS, MatchASubBC()) && match(RHS, MatchCD());
+  };
+
+  if (!Match(Op0, Op1) && !Match(Op1, Op0))
+    return;
+
+  const auto ComputeKnownBitsOrOne = [&](const Value *V) {
+    // For some of the values we use the convention of leaving them
+    // nullptr to signify an implicit constant 1.
+    return V ? computeKnownBits(V, DemandedElts, Q, Depth + 1)
+             : KnownBits::makeConstant(APInt(BitWidth, 1));
+  };
+
+  // Check that all operands are non-negative.
+  const KnownBits KnownA = ComputeKnownBitsOrOne(A);
+  if (!KnownA.isNonNegative())
+    return;
+
+  const KnownBits KnownD = ComputeKnownBitsOrOne(D);
+  if (!KnownD.isNonNegative())
+    return;
+
+  const KnownBits KnownB = computeKnownBits(B, DemandedElts, Q, Depth + 1);
+  if (!KnownB.isNonNegative())
+    return;
+
+  const KnownBits KnownC = computeKnownBits(C, DemandedElts, Q, Depth + 1);
+  if (!KnownC.isNonNegative())
+    return;
+
+  // If we matched the subtraction as an xor, we need to actually check
+  // that the xor is semantically equivalent to a subtraction.
+  //
+  // For that to be true, b has to be a mask for c, i.e. b's known ones
+  // have to cover all possible ones of c.
+  if (SubBC->getOpcode() == Instruction::Xor &&
+      !KnownC.getMaxValue().isSubsetOf(KnownB.getMinValue()))
+    return;
+
+  const APInt MaxA = KnownA.getMaxValue();
+  const APInt MaxD = KnownD.getMaxValue();
+  const APInt MaxAD = APIntOps::umax(MaxA, MaxD);
+  const APInt MaxB = KnownB.getMaxValue();
+
+  // We can't infer leading zeros info if the upper-bound estimate wraps.
+  bool Overflow;
+  const APInt UpperBound = MaxAD.umul_ov(MaxB, Overflow);
+
+  if (Overflow)
+    return;
+
+  // If we know that x <= y and both are positive, then x has at least as
+  // many leading zeros as y.
+  const unsigned MinimumNumberOfLeadingZeros = UpperBound.countl_zero();
+  KnownOut.Zero.setHighBits(MinimumNumberOfLeadingZeros);
+}
+
 static void computeKnownBitsAddSub(bool Add, const Value *Op0, const Value *Op1,
                                    bool NSW, bool NUW,
                                    const APInt &DemandedElts,
@@ -369,6 +502,10 @@ static void computeKnownBitsAddSub(bool Add, const Value *Op0, const Value *Op1,
       isImpliedByDomCondition(ICmpInst::ICMP_SLE, Op1, Op0, Q.CxtI, Q.DL)
           .value_or(false))
     KnownOut.makeNonNegative();
+
+  // Try to match the lerp pattern and combine the results.
+  if (Add)
+    computeKnownBitsFromLerpPattern(Op0, Op1, DemandedElts, KnownOut, Q, Depth);
 }
 
 static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW,
diff --git a/llvm/test/Transforms/InstCombine/known-bits-lerp-pattern.ll b/llvm/test/Transforms/InstCombine/known-bits-lerp-pattern.ll
new file mode 100644
index 0000000000000..5a33d35aa1cf1
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/known-bits-lerp-pattern.ll
@@ -0,0 +1,181 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+; Test known bits refinements for the pattern: a * (b - c) + c * d
+; where a >= 0, b >= 0, c >= 0, d >= 0, and b >= c.
+; This pattern is a generalization of lerp, and it appears frequently in graphics operations.
+
+define i32 @test_clamp(i8 %a, i8 %c, i8 %d) {
+; CHECK-LABEL: define i32 @test_clamp(
+; CHECK-SAME: i8 [[A:%.*]], i8 [[C:%.*]], i8 [[D:%.*]]) {
+; CHECK-NEXT:    [[A32:%.*]] = zext i8 [[A]] to i32
+; CHECK-NEXT:    [[C32:%.*]] = zext i8 [[C]] to i32
+; CHECK-NEXT:    [[D32:%.*]] = zext i8 [[D]] to i32
+; CHECK-NEXT:    [[SUB:%.*]] = xor i32 [[C32]], 255
+; CHECK-NEXT:    [[MUL1:%.*]] = mul nuw nsw i32 [[SUB]], [[A32]]
+; CHECK-NEXT:    [[MUL2:%.*]] = mul nuw nsw i32 [[C32]], [[D32]]
+; CHECK-NEXT:    [[ADD:%.*]] = add nuw nsw i32 [[MUL1]], [[MUL2]]
+; CHECK-NEXT:    ret i32 [[ADD]]
+;
+  %a32 = zext i8 %a to i32
+  %c32 = zext i8 %c to i32
+  %d32 = zext i8 %d to i32
+  %sub = sub i32 255, %c32
+  %mul1 = mul i32 %a32, %sub
+  %mul2 = mul i32 %c32, %d32
+  %add = add i32 %mul1, %mul2
+  %cmp = icmp ugt i32 %add, 65535
+  %result = select i1 %cmp, i32 65535, i32 %add
+  ret i32 %result
+}
+
+define i1 @test_trunc_cmp(i8 %a, i8 %c, i8 %d) {
+; CHECK-LABEL: define i1 @test_trunc_cmp(
+; CHECK-SAME: i8 [[A:%.*]], i8 [[C:%.*]], i8 [[D:%.*]]) {
+; CHECK-NEXT:    [[A32:%.*]] = zext i8 [[A]] to i32
+; CHECK-NEXT:    [[C32:%.*]] = zext i8 [[C]] to i32
+; CHECK-NEXT:    [[D32:%.*]] = zext i8 [[D]] to i32
+; CHECK-NEXT:    [[SUB:%.*]] = xor i32 [[C32]], 255
+; CHECK-NEXT:    [[MUL1:%.*]] = mul nuw nsw i32 [[SUB]], [[A32]]
+; CHECK-NEXT:    [[MUL2:%.*]] = mul nuw nsw i32 [[C32]], [[D32]]
+; CHECK-NEXT:    [[ADD:%.*]] = add nuw nsw i32 [[MUL1]], [[MUL2]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[ADD]], 1234
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %a32 = zext i8 %a to i32
+  %c32 = zext i8 %c to i32
+  %d32 = zext i8 %d to i32
+  %sub = sub i32 255, %c32
+  %mul1 = mul i32 %a32, %sub
+  %mul2 = mul i32 %c32, %d32
+  %add = add i32 %mul1, %mul2
+  %trunc = trunc i32 %add to i16
+  %cmp = icmp eq i16 %trunc, 1234
+  ret i1 %cmp
+}
+
+define i1 @test_trunc_cmp_xor(i8 %a, i8 %c, i8 %d) {
+; CHECK-LABEL: define i1 @test_trunc_cmp_xor(
+; CHECK-SAME: i8 [[A:%.*]], i8 [[C:%.*]], i8 [[D:%.*]]) {
+; CHECK-NEXT:    [[A32:%.*]] = zext i8 [[A]] to i32
+; CHECK-NEXT:    [[C32:%.*]] = zext i8 [[C]] to i32
+; CHECK-NEXT:    [[D32:%.*]] = zext i8 [[D]] to i32
+; CHECK-NEXT:    [[SUB:%.*]] = xor i32 [[C32]], 255
+; CHECK-NEXT:    [[MUL1:%.*]] = mul nuw nsw i32 [[SUB]], [[A32]]
+; CHECK-NEXT:    [[MUL2:%.*]] = mul nuw nsw i32 [[C32]], [[D32]]
+; CHECK-NEXT:    [[ADD:%.*]] = add nuw nsw i32 [[MUL1]], [[MUL2]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[ADD]], 1234
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %a32 = zext i8 %a to i32
+  %c32 = zext i8 %c to i32
+  %d32 = zext i8 %d to i32
+  %sub = xor i32 255, %c32
+  %mul1 = mul i32 %a32, %sub
+  %mul2 = mul i32 %c32, %d32
+  %add = add i32 %mul1, %mul2
+  %trunc = trunc i32 %add to i16
+  %cmp = icmp eq i16 %trunc, 1234
+  ret i1 %cmp
+}
+
+define i1 @test_trunc_cmp_arbitrary_b(i8 %a, i8 %b, i8 %c, i8 %d) {
+; CHECK-LABEL: define i1 @test_trunc_cmp_arbitrary_b(
+; CHECK-SAME: i8 [[A:%.*]], i8 [[B:%.*]], i8 [[C:%.*]], i8 [[D:%.*]]) {
+; CHECK-NEXT:    [[A32:%.*]] = zext i8 [[A]] to i32
+; CHECK-NEXT:    [[B32:%.*]] = zext i8 [[B]] to i32
+; CHECK-NEXT:    [[C32:%.*]] = zext i8 [[C]] to i32
+; CHECK-NEXT:    [[D32:%.*]] = zext i8 [[D]] to i32
+; CHECK-NEXT:    [[SUB:%.*]] = sub nuw nsw i32 [[B32]], [[C32]]
+; CHECK-NEXT:    [[MUL1:%.*]] = mul nuw nsw i32 [[SUB]], [[A32]]
+; CHECK-NEXT:    [[MUL2:%.*]] = mul nuw nsw i32 [[C32]], [[D32]]
+; CHECK-NEXT:    [[ADD:%.*]] = add nuw nsw i32 [[MUL1]], [[MUL2]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[ADD]], 1234
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %a32 = zext i8 %a to i32
+  %b32 = zext i8 %b to i32
+  %c32 = zext i8 %c to i32
+  %d32 = zext i8 %d to i32
+  %sub = sub nsw nuw i32 %b32, %c32
+  %mul1 = mul i32 %a32, %sub
+  %mul2 = mul i32 %c32, %d32
+  %add = add i32 %mul1, %mul2
+  %trunc = trunc i32 %add to i16
+  %cmp = icmp eq i16 %trunc, 1234
+  ret i1 %cmp
+}
+
+
+define i1 @test_trunc_cmp_no_a(i8 %b, i8 %c, i8 %d) {
+; CHECK-LABEL: define i1 @test_trunc_cmp_no_a(
+; CHECK-SAME: i8 [[B:%.*]], i8 [[C:%.*]], i8 [[D:%.*]]) {
+; CHECK-NEXT:    [[B32:%.*]] = zext i8 [[B]] to i32
+; CHECK-NEXT:    [[C32:%.*]] = zext i8 [[C]] to i32
+; CHECK-NEXT:    [[D32:%.*]] = zext i8 [[D]] to i32
+; CHECK-NEXT:    [[MUL1:%.*]] = sub nuw nsw i32 [[B32]], [[C32]]
+; CHECK-NEXT:    [[MUL2:%.*]] = mul nuw nsw i32 [[C32]], [[D32]]
+; CHECK-NEXT:    [[ADD:%.*]] = add nuw nsw i32 [[MUL1]], [[MUL2]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[ADD]], 1234
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %b32 = zext i8 %b to i32
+  %c32 = zext i8 %c to i32
+  %d32 = zext i8 %d to i32
+  %sub = sub nuw i32 %b32, %c32
+  %mul2 = mul i32 %c32, %d32
+  %add = add i32 %sub, %mul2
+  %trunc = trunc i32 %add to i16
+  %cmp = icmp eq i16 %trunc, 1234
+  ret i1 %cmp
+}
+
+define i1 @test_trunc_cmp_no_d(i8 %a, i8 %b, i8 %c) {
+; CHECK-LABEL: define i1 @test_trunc_cmp_no_d(
+; CHECK-SAME: i8 [[A:%.*]], i8 [[B:%.*]], i8 [[C:%.*]]) {
+; CHECK-NEXT:    [[A32:%.*]] = zext i8 [[A]] to i32
+; CHECK-NEXT:    [[B32:%.*]] = zext i8 [[B]] to i32
+; CHECK-NEXT:    [[C32:%.*]] = zext i8 [[C]] to i32
+; CHECK-NEXT:    [[SUB:%.*]] = sub nuw nsw i32 [[B32]], [[C32]]
+; CHECK-NEXT:    [[MUL1:%.*]] = mul nuw nsw i32 [[SUB]], [[A32]]
+; CHECK-NEXT:    [[ADD:%.*]] = add nuw nsw i32 [[MUL1]], [[C32]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[ADD]], 1234
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %a32 = zext i8 %a to i32
+  %b32 = zext i8 %b to i32
+  %c32 = zext i8 %c to i32
+  %sub = sub nsw nuw i32 %b32, %c32
+  %mul1 = mul i32 %a32, %sub
+  %add = add i32 %mul1, %c32
+  %trunc = trunc i32 %add to i16
+  %cmp = icmp eq i16 %trunc, 1234
+  ret i1 %cmp
+}
+
+define i1 @test_trunc_cmp_xor_negative(i8 %a, i8 %c, i8 %d) {
+; CHECK-LABEL: define i1 @test_trunc_cmp_xor_negative(
+; CHECK-SAME: i8 [[A:%.*]], i8 [[C:%.*]], i8 [[D:%.*]]) {
+; CHECK-NEXT:    [[A32:%.*]] = zext i8 [[A]] to i32
+; CHECK-NEXT:    [[C32:%.*]] = zext i8 [[C]] to i32
+; CHECK-NEXT:    [[D32:%.*]] = zext i8 [[D]] to i32
+; CHECK-NEXT:    [[SUB:%.*]] = xor i32 [[C32]], 234
+; CHECK-NEXT:    [[MUL1:%.*]] = mul nuw nsw i32 [[SUB]], [[A32]]
+; CHECK-NEXT:    [[MUL2:%.*]] = mul nuw nsw i32 [[C32]], [[D32]]
+; CHECK-NEXT:    [[ADD:%.*]] = add nuw nsw i32 [[MUL1]], [[MUL2]]
+; CHECK-NEXT:    [[TRUNC:%.*]] = trunc i32 [[ADD]] to i16
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i16 [[TRUNC]], 1234
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %a32 = zext i8 %a to i32
+  %c32 = zext i8 %c to i32
+  %d32 = zext i8 %d to i32
+  %sub = xor i32 234, %c32
+  %mul1 = mul i32 %a32, %sub
+  %mul2 = mul i32 %c32, %d32
+  %add = add i32 %mul1, %mul2
+  ; We should keep the trunc in this case
+  %trunc = trunc i32 %add to i16
+  %cmp = icmp eq i16 %trunc, 1234
+  ret i1 %cmp
+}
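
A quick arithmetic sanity check of the bound for the alpha-blending case exercised by @test_clamp above (this is just worked math over its operands, not part of the patch itself): %a32, %c32 and %d32 are zero-extended from i8, so max(a) = max(c) = max(d) = 255, and b is the constant 255. The derived upper bound is U = max(max(a), max(d)) * max(b) = 255 * 255 = 65025, which fits in 16 bits, so the i32 %add is known to have at least 16 leading zeros. That makes the icmp ugt i32 %add, 65535 known false, which is why the clamping select disappears from the autogenerated CHECK lines.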