@@ -350,6 +350,140 @@ unsigned llvm::ComputeMaxSignificantBits(const Value *V, const DataLayout &DL,
350350 return V->getType ()->getScalarSizeInBits () - SignBits + 1 ;
351351}
352352
353+ // Try to detect the lerp pattern: a * (b - c) + c * d
354+ // where a >= 0, b >= 0, c >= 0, d >= 0, and b >= c.
355+ //
356+ // In that particular case, we can use the following chain of reasoning:
357+ //
358+ // a * (b - c) + c * d <= a' * (b - c) + a' * c = a' * b where a' = max(a, d)
359+ //
360+ // Since that is true for arbitrary a, b, c and d within our constraints, we can
361+ // conclude that:
362+ //
363+ // max(a * (b - c) + c * d) <= max(max(a), max(d)) * max(b) = U
364+ //
365+ // Considering that any result of the lerp would be less or equal to U, it would
366+ // have at least the number of leading 0s as in U.
367+ //
368+ // While being quite a specific situation, it is fairly common in computer
369+ // graphics in the shape of alpha blending.
370+ //
371+ // Returns unknown bits if the pattern doesn't match or constraints don't apply
372+ // to the given operands.
373+ static KnownBits computeKnownBitsFromLerpPattern (const Value *Op0,
374+ const Value *Op1,
375+ const APInt &DemandedElts,
376+ const SimplifyQuery &Q,
377+ unsigned Depth) {
378+
379+ Type *Ty = Op0->getType ();
380+ const unsigned BitWidth = Ty->getScalarSizeInBits ();
381+
382+ KnownBits Result (BitWidth);
383+
384+ // Only handle scalar types for now
385+ if (Ty->isVectorTy ())
386+ return Result;
387+
388+ // Try to match: a * (b - c) + c * d.
389+ // When a == 1 => A == nullptr, the same applies to d/D as well.
390+ const Value *A = nullptr , *B = nullptr , *C = nullptr , *D = nullptr ;
391+
392+ const auto MatchSubBC = [&]() {
393+ // (b - c) can have two forms that interest us:
394+ //
395+ // 1. sub nuw %b, %c
396+ // 2. xor %c, %b
397+ //
398+ // For the first case, nuw flag guarantees our requirement b >= c.
399+ //
400+ // The second case happens when the analysis can infer that b is a mask for
401+ // c and we can transform sub operation into xor (that is usually true for
402+ // constant b's). Even though xor is symmetrical, canonicalization ensures
403+ // that the constant will be the RHS. xor of two positive integers is
404+ // guaranteed to be non-negative as well.
405+ return m_CombineOr (m_NUWSub (m_Value (B), m_Value (C)),
406+ m_Xor (m_Value (C), m_Value (B)));
407+ };
408+
409+ const auto MatchASubBC = [&]() {
410+ // Cases:
411+ // - a * (b - c)
412+ // - (b - c) * a
413+ // - (b - c) <- a implicitly equals 1
414+ return m_CombineOr (m_CombineOr (m_Mul (m_Value (A), MatchSubBC ()),
415+ m_Mul (MatchSubBC (), m_Value (A))),
416+ MatchSubBC ());
417+ };
418+
419+ const auto MatchCD = [&]() {
420+ // Cases:
421+ // - d * c
422+ // - c * d
423+ // - c <- d implicitly equals 1
424+ return m_CombineOr (m_CombineOr (m_Mul (m_Value (D), m_Specific (C)),
425+ m_Mul (m_Specific (C), m_Value (D))),
426+ m_Specific (C));
427+ };
428+
429+ const auto Match = [&](const Value *LHS, const Value *RHS) {
430+ // We do use m_Specific(C) in MatchCD, so we have to make sure that
431+ // it's bound to anything and match(LHS, MatchASubBC()) absolutely
432+ // has to evaluate first and return true.
433+ //
434+ // If Match returns true, it is guaranteed that B != nullptr, C != nullptr.
435+ return match (LHS, MatchASubBC ()) && match (RHS, MatchCD ());
436+ };
437+
438+ if (!Match (Op0, Op1) && !Match (Op1, Op0))
439+ return Result;
440+
441+ const auto ComputeKnownBitsOrOne = [&](const Value *V) {
442+ // For some of the values we use the convention of leaving
443+ // it nullptr to signify an implicit constant 1.
444+ return V ? computeKnownBits (V, DemandedElts, Q, Depth + 1 )
445+ : KnownBits::makeConstant (APInt (BitWidth, 1 ));
446+ };
447+
448+ // Check that all operands are non-negative
449+ const KnownBits KnownA = ComputeKnownBitsOrOne (A);
450+ if (!KnownA.isNonNegative ())
451+ return Result;
452+
453+ const KnownBits KnownD = ComputeKnownBitsOrOne (D);
454+ if (!KnownD.isNonNegative ())
455+ return Result;
456+
457+ const KnownBits KnownB = computeKnownBits (B, DemandedElts, Q, Depth + 1 );
458+ if (!KnownB.isNonNegative ())
459+ return Result;
460+
461+ const KnownBits KnownC = computeKnownBits (C, DemandedElts, Q, Depth + 1 );
462+ if (!KnownC.isNonNegative ())
463+ return Result;
464+
465+ // Compute max(a, d)
466+ const APInt MaxA = KnownA.getMaxValue ();
467+ const APInt MaxD = KnownD.getMaxValue ();
468+ const APInt MaxAD = MaxA.ult (MaxD) ? MaxD : MaxA;
469+
470+ // Compute max(a, d) * max(b)
471+ const APInt MaxB = KnownB.getMaxValue ();
472+ bool Overflow;
473+ const APInt UpperBound = MaxAD.umul_ov (MaxB, Overflow);
474+
475+ if (Overflow)
476+ return Result;
477+
478+ // Count leading zeros in upper bound
479+ const unsigned MinimumNumberOfLeadingZeros = UpperBound.countl_zero ();
480+
481+ // Create KnownBits with only leading zeros set
482+ Result.Zero .setHighBits (MinimumNumberOfLeadingZeros);
483+
484+ return Result;
485+ }
486+
353487static void computeKnownBitsAddSub (bool Add, const Value *Op0, const Value *Op1,
354488 bool NSW, bool NUW,
355489 const APInt &DemandedElts,
@@ -369,6 +503,15 @@ static void computeKnownBitsAddSub(bool Add, const Value *Op0, const Value *Op1,
369503 isImpliedByDomCondition (ICmpInst::ICMP_SLE, Op1, Op0, Q.CxtI , Q.DL )
370504 .value_or (false ))
371505 KnownOut.makeNonNegative ();
506+
507+ if (Add) {
508+ // Try to match lerp pattern and combine results
509+ const KnownBits LerpKnown =
510+ computeKnownBitsFromLerpPattern (Op0, Op1, DemandedElts, Q, Depth);
511+ // Union of any two conservative estimates results in a conservative
512+ // estimate that is at least as precise as each individual estimate.
513+ KnownOut = KnownOut.unionWith (LerpKnown);
514+ }
372515}
373516
374517static void computeKnownBitsMul (const Value *Op0, const Value *Op1, bool NSW,
0 commit comments