@@ -350,6 +350,141 @@ unsigned llvm::ComputeMaxSignificantBits(const Value *V, const DataLayout &DL,
350350 return V->getType ()->getScalarSizeInBits () - SignBits + 1 ;
351351}
352352
353+ // / Try to detect the lerp pattern: a * (b - c) + c * d
354+ // / where a >= 0, b >= 0, c >= 0, d >= 0, and b >= c.
355+ // /
356+ // / In that particular case, we can use the following chain of reasoning:
357+ // /
358+ // / a * (b - c) + c * d <= a' * (b - c) + a' * c = a' * b where a' = max(a, d)
359+ // /
360+ // / Since that is true for arbitrary a, b, c and d within our constraints, we
361+ // / can conclude that:
362+ // /
363+ // / max(a * (b - c) + c * d) <= max(max(a), max(d)) * max(b) = U
364+ // /
365+ // / Considering that any result of the lerp would be less or equal to U, it
366+ // / would have at least the number of leading 0s as in U.
367+ // /
368+ // / While being quite a specific situation, it is fairly common in computer
369+ // / graphics in the shape of alpha blending.
370+ // /
371+ // / Modifies given KnownOut in-place with the inferred information.
372+ static void computeKnownBitsFromLerpPattern (const Value *Op0, const Value *Op1,
373+ const APInt &DemandedElts,
374+ KnownBits &KnownOut,
375+ const SimplifyQuery &Q,
376+ unsigned Depth) {
377+
378+ Type *Ty = Op0->getType ();
379+ const unsigned BitWidth = Ty->getScalarSizeInBits ();
380+
381+ // Only handle scalar types for now
382+ if (Ty->isVectorTy ())
383+ return ;
384+
385+ // Try to match: a * (b - c) + c * d.
386+ // When a == 1 => A == nullptr, the same applies to d/D as well.
387+ const Value *A = nullptr , *B = nullptr , *C = nullptr , *D = nullptr ;
388+ const Instruction *SubBC = nullptr ;
389+
390+ const auto MatchSubBC = [&]() {
391+ // (b - c) can have two forms that interest us:
392+ //
393+ // 1. sub nuw %b, %c
394+ // 2. xor %c, %b
395+ //
396+ // For the first case, nuw flag guarantees our requirement b >= c.
397+ //
398+ // The second case might happen when the analysis can infer that b is a mask
399+ // for c and we can transform sub operation into xor (that is usually true
400+ // for constant b's). Even though xor is symmetrical, canonicalization
401+ // ensures that the constant will be the RHS. We have additional checks
402+ // later on to ensure that this xor operation is equivalent to subtraction.
403+ return m_Instruction (SubBC, m_CombineOr (m_NUWSub (m_Value (B), m_Value (C)),
404+ m_Xor (m_Value (C), m_Value (B))));
405+ };
406+
407+ const auto MatchASubBC = [&]() {
408+ // Cases:
409+ // - a * (b - c)
410+ // - (b - c) * a
411+ // - (b - c) <- a implicitly equals 1
412+ return m_CombineOr (m_c_Mul (m_Value (A), MatchSubBC ()), MatchSubBC ());
413+ };
414+
415+ const auto MatchCD = [&]() {
416+ // Cases:
417+ // - d * c
418+ // - c * d
419+ // - c <- d implicitly equals 1
420+ return m_CombineOr (m_c_Mul (m_Value (D), m_Specific (C)), m_Specific (C));
421+ };
422+
423+ const auto Match = [&](const Value *LHS, const Value *RHS) {
424+ // We do use m_Specific(C) in MatchCD, so we have to make sure that
425+ // it's bound to anything and match(LHS, MatchASubBC()) absolutely
426+ // has to evaluate first and return true.
427+ //
428+ // If Match returns true, it is guaranteed that B != nullptr, C != nullptr.
429+ return match (LHS, MatchASubBC ()) && match (RHS, MatchCD ());
430+ };
431+
432+ if (!Match (Op0, Op1) && !Match (Op1, Op0))
433+ return ;
434+
435+ const auto ComputeKnownBitsOrOne = [&](const Value *V) {
436+ // For some of the values we use the convention of leaving
437+ // it nullptr to signify an implicit constant 1.
438+ return V ? computeKnownBits (V, DemandedElts, Q, Depth + 1 )
439+ : KnownBits::makeConstant (APInt (BitWidth, 1 ));
440+ };
441+
442+ // Check that all operands are non-negative
443+ const KnownBits KnownA = ComputeKnownBitsOrOne (A);
444+ if (!KnownA.isNonNegative ())
445+ return ;
446+
447+ const KnownBits KnownD = ComputeKnownBitsOrOne (D);
448+ if (!KnownD.isNonNegative ())
449+ return ;
450+
451+ const KnownBits KnownB = computeKnownBits (B, DemandedElts, Q, Depth + 1 );
452+ if (!KnownB.isNonNegative ())
453+ return ;
454+
455+ const KnownBits KnownC = computeKnownBits (C, DemandedElts, Q, Depth + 1 );
456+ if (!KnownC.isNonNegative ())
457+ return ;
458+
459+ // If we matched subtraction as xor, we need to actually check that xor
460+ // is semantically equivalent to subtraction.
461+ //
462+ // For that to be true, b has to be a mask for c or that b's known
463+ // ones cover all known and possible ones of c.
464+ if (SubBC->getOpcode () == Instruction::Xor &&
465+ !KnownC.getMaxValue ().isSubsetOf (KnownB.getMinValue ()))
466+ return ;
467+
468+ // Compute max(a, d)
469+ const APInt MaxA = KnownA.getMaxValue ();
470+ const APInt MaxD = KnownD.getMaxValue ();
471+ const APInt MaxAD = APIntOps::umax (MaxA, MaxD);
472+
473+ // Compute max(a, d) * max(b)
474+ const APInt MaxB = KnownB.getMaxValue ();
475+ bool Overflow;
476+ const APInt UpperBound = MaxAD.umul_ov (MaxB, Overflow);
477+
478+ if (Overflow)
479+ return ;
480+
481+ // Count leading zeros in upper bound
482+ const unsigned MinimumNumberOfLeadingZeros = UpperBound.countl_zero ();
483+
484+ // Create KnownBits with only leading zeros set
485+ KnownOut.Zero .setHighBits (MinimumNumberOfLeadingZeros);
486+ }
487+
353488static void computeKnownBitsAddSub (bool Add, const Value *Op0, const Value *Op1,
354489 bool NSW, bool NUW,
355490 const APInt &DemandedElts,
@@ -369,6 +504,10 @@ static void computeKnownBitsAddSub(bool Add, const Value *Op0, const Value *Op1,
369504 isImpliedByDomCondition (ICmpInst::ICMP_SLE, Op1, Op0, Q.CxtI , Q.DL )
370505 .value_or (false ))
371506 KnownOut.makeNonNegative ();
507+
508+ if (Add)
509+ // Try to match lerp pattern and combine results
510+ computeKnownBitsFromLerpPattern (Op0, Op1, DemandedElts, KnownOut, Q, Depth);
372511}
373512
374513static void computeKnownBitsMul (const Value *Op0, const Value *Op1, bool NSW,
0 commit comments