@@ -350,6 +350,136 @@ unsigned llvm::ComputeMaxSignificantBits(const Value *V, const DataLayout &DL,
350350 return V->getType ()->getScalarSizeInBits () - SignBits + 1 ;
351351}
352352
353+ // / Try to detect the lerp pattern: a * (b - c) + c * d
354+ // / where a >= 0, b >= 0, c >= 0, d >= 0, and b >= c.
355+ // /
356+ // / In that particular case, we can use the following chain of reasoning:
357+ // /
358+ // / a * (b - c) + c * d <= a' * (b - c) + a' * c = a' * b where a' = max(a, d)
359+ // /
360+ // / Since that is true for arbitrary a, b, c and d within our constraints, we
361+ // / can conclude that:
362+ // /
363+ // / max(a * (b - c) + c * d) <= max(max(a), max(d)) * max(b) = U
364+ // /
365+ // / Considering that any result of the lerp would be less or equal to U, it
366+ // / would have at least the number of leading 0s as in U.
367+ // /
368+ // / While being quite a specific situation, it is fairly common in computer
369+ // / graphics in the shape of alpha blending.
370+ // /
371+ // / Returns unknown bits if the pattern doesn't match or constraints don't apply
372+ // / to the given operands.
373+ static KnownBits computeKnownBitsFromLerpPattern (const Value *Op0,
374+ const Value *Op1,
375+ const APInt &DemandedElts,
376+ const SimplifyQuery &Q,
377+ unsigned Depth) {
378+
379+ Type *Ty = Op0->getType ();
380+ const unsigned BitWidth = Ty->getScalarSizeInBits ();
381+
382+ KnownBits Result (BitWidth);
383+
384+ // Only handle scalar types for now
385+ if (Ty->isVectorTy ())
386+ return Result;
387+
388+ // Try to match: a * (b - c) + c * d.
389+ // When a == 1 => A == nullptr, the same applies to d/D as well.
390+ const Value *A = nullptr , *B = nullptr , *C = nullptr , *D = nullptr ;
391+
392+ const auto MatchSubBC = [&]() {
393+ // (b - c) can have two forms that interest us:
394+ //
395+ // 1. sub nuw %b, %c
396+ // 2. xor %c, %b
397+ //
398+ // For the first case, nuw flag guarantees our requirement b >= c.
399+ //
400+ // The second case happens when the analysis can infer that b is a mask for
401+ // c and we can transform sub operation into xor (that is usually true for
402+ // constant b's). Even though xor is symmetrical, canonicalization ensures
403+ // that the constant will be the RHS. xor of two positive integers is
404+ // guaranteed to be non-negative as well.
405+ return m_CombineOr (m_NUWSub (m_Value (B), m_Value (C)),
406+ m_Xor (m_Value (C), m_Value (B)));
407+ };
408+
409+ const auto MatchASubBC = [&]() {
410+ // Cases:
411+ // - a * (b - c)
412+ // - (b - c) * a
413+ // - (b - c) <- a implicitly equals 1
414+ return m_CombineOr (m_c_Mul (m_Value (A), MatchSubBC ()), MatchSubBC ());
415+ };
416+
417+ const auto MatchCD = [&]() {
418+ // Cases:
419+ // - d * c
420+ // - c * d
421+ // - c <- d implicitly equals 1
422+ return m_CombineOr (m_c_Mul (m_Value (D), m_Specific (C)), m_Specific (C));
423+ };
424+
425+ const auto Match = [&](const Value *LHS, const Value *RHS) {
426+ // We do use m_Specific(C) in MatchCD, so we have to make sure that
427+ // it's bound to anything and match(LHS, MatchASubBC()) absolutely
428+ // has to evaluate first and return true.
429+ //
430+ // If Match returns true, it is guaranteed that B != nullptr, C != nullptr.
431+ return match (LHS, MatchASubBC ()) && match (RHS, MatchCD ());
432+ };
433+
434+ if (!Match (Op0, Op1) && !Match (Op1, Op0))
435+ return Result;
436+
437+ const auto ComputeKnownBitsOrOne = [&](const Value *V) {
438+ // For some of the values we use the convention of leaving
439+ // it nullptr to signify an implicit constant 1.
440+ return V ? computeKnownBits (V, DemandedElts, Q, Depth + 1 )
441+ : KnownBits::makeConstant (APInt (BitWidth, 1 ));
442+ };
443+
444+ // Check that all operands are non-negative
445+ const KnownBits KnownA = ComputeKnownBitsOrOne (A);
446+ if (!KnownA.isNonNegative ())
447+ return Result;
448+
449+ const KnownBits KnownD = ComputeKnownBitsOrOne (D);
450+ if (!KnownD.isNonNegative ())
451+ return Result;
452+
453+ const KnownBits KnownB = computeKnownBits (B, DemandedElts, Q, Depth + 1 );
454+ if (!KnownB.isNonNegative ())
455+ return Result;
456+
457+ const KnownBits KnownC = computeKnownBits (C, DemandedElts, Q, Depth + 1 );
458+ if (!KnownC.isNonNegative ())
459+ return Result;
460+
461+ // Compute max(a, d)
462+ const APInt MaxA = KnownA.getMaxValue ();
463+ const APInt MaxD = KnownD.getMaxValue ();
464+ const APInt MaxAD = APIntOps::umax (MaxA, MaxD);
465+
466+ // Compute max(a, d) * max(b)
467+ const APInt MaxB = KnownB.getMaxValue ();
468+ bool Overflow;
469+ const APInt UpperBound = MaxAD.umul_ov (MaxB, Overflow);
470+
471+ if (Overflow)
472+ return Result;
473+
474+ // Count leading zeros in upper bound
475+ const unsigned MinimumNumberOfLeadingZeros = UpperBound.countl_zero ();
476+
477+ // Create KnownBits with only leading zeros set
478+ Result.Zero .setHighBits (MinimumNumberOfLeadingZeros);
479+
480+ return Result;
481+ }
482+
353483static void computeKnownBitsAddSub (bool Add, const Value *Op0, const Value *Op1,
354484 bool NSW, bool NUW,
355485 const APInt &DemandedElts,
@@ -369,6 +499,15 @@ static void computeKnownBitsAddSub(bool Add, const Value *Op0, const Value *Op1,
369499 isImpliedByDomCondition (ICmpInst::ICMP_SLE, Op1, Op0, Q.CxtI , Q.DL )
370500 .value_or (false ))
371501 KnownOut.makeNonNegative ();
502+
503+ if (Add) {
504+ // Try to match lerp pattern and combine results
505+ const KnownBits LerpKnown =
506+ computeKnownBitsFromLerpPattern (Op0, Op1, DemandedElts, Q, Depth);
507+ // Union of any two conservative estimates results in a conservative
508+ // estimate that is at least as precise as each individual estimate.
509+ KnownOut = KnownOut.unionWith (LerpKnown);
510+ }
372511}
373512
374513static void computeKnownBitsMul (const Value *Op0, const Value *Op1, bool NSW,
0 commit comments