@@ -350,6 +350,153 @@ unsigned llvm::ComputeMaxSignificantBits(const Value *V, const DataLayout &DL,
350350 return V->getType ()->getScalarSizeInBits () - SignBits + 1 ;
351351}
352352
353+ // / Try to detect the lerp pattern: a * (b - c) + c * d
354+ // / where a >= 0, b >= 0, c >= 0, d >= 0, and b >= c.
355+ // /
356+ // / In that particular case, we can use the following chain of reasoning:
357+ // /
358+ // / a * (b - c) + c * d <= a' * (b - c) + a' * c = a' * b where a' = max(a, d)
359+ // /
360+ // / Since that is true for arbitrary a, b, c and d within our constraints, we
361+ // / can conclude that:
362+ // /
363+ // / max(a * (b - c) + c * d) <= max(max(a), max(d)) * max(b) = U
364+ // /
365+ // / Considering that any result of the lerp would be less or equal to U, it
366+ // / would have at least the number of leading 0s as in U.
367+ // /
368+ // / While being quite a specific situation, it is fairly common in computer
369+ // / graphics in the shape of alpha blending.
370+ // /
371+ // / Returns unknown bits if the pattern doesn't match or constraints don't apply
372+ // / to the given operands.
373+ static void computeKnownBitsFromLerpPattern (const Value *Op0, const Value *Op1,
374+ const APInt &DemandedElts,
375+ KnownBits &KnownOut,
376+ const SimplifyQuery &Q,
377+ unsigned Depth) {
378+
379+ Type *Ty = Op0->getType ();
380+ const unsigned BitWidth = Ty->getScalarSizeInBits ();
381+
382+ // Only handle scalar types for now
383+ if (Ty->isVectorTy ())
384+ return ;
385+
386+ // Try to match: a * (b - c) + c * d.
387+ // When a == 1 => A == nullptr, the same applies to d/D as well.
388+ const Value *A = nullptr , *B = nullptr , *C = nullptr , *D = nullptr ;
389+ const BinaryOperator *SubBC = nullptr ;
390+
391+ const auto MatchSubBC = [&]() {
392+ // (b - c) can have two forms that interest us:
393+ //
394+ // 1. sub nuw %b, %c
395+ // 2. xor %c, %b
396+ //
397+ // For the first case, nuw flag guarantees our requirement b >= c.
398+ //
399+ // The second case happens when the analysis can infer that b is a mask for
400+ // c and we can transform sub operation into xor (that is usually true for
401+ // constant b's). Even though xor is symmetrical, canonicalization ensures
402+ // that the constant will be the RHS. xor of two positive integers is
403+ // guaranteed to be non-negative as well.
404+ return m_CombineAnd (m_BinOp (SubBC),
405+ m_CombineOr (m_NUWSub (m_Value (B), m_Value (C)),
406+ m_Xor (m_Value (C), m_Value (B))));
407+ };
408+
409+ const auto MatchASubBC = [&]() {
410+ // Cases:
411+ // - a * (b - c)
412+ // - (b - c) * a
413+ // - (b - c) <- a implicitly equals 1
414+ return m_CombineOr (m_c_Mul (m_Value (A), MatchSubBC ()), MatchSubBC ());
415+ };
416+
417+ const auto MatchCD = [&]() {
418+ // Cases:
419+ // - d * c
420+ // - c * d
421+ // - c <- d implicitly equals 1
422+ return m_CombineOr (m_c_Mul (m_Value (D), m_Specific (C)), m_Specific (C));
423+ };
424+
425+ const auto Match = [&](const Value *LHS, const Value *RHS) {
426+ // We do use m_Specific(C) in MatchCD, so we have to make sure that
427+ // it's bound to anything and match(LHS, MatchASubBC()) absolutely
428+ // has to evaluate first and return true.
429+ //
430+ // If Match returns true, it is guaranteed that B != nullptr, C != nullptr.
431+ return match (LHS, MatchASubBC ()) && match (RHS, MatchCD ());
432+ };
433+
434+ if (!Match (Op0, Op1) && !Match (Op1, Op0))
435+ return ;
436+
437+ const auto ComputeKnownBitsOrOne = [&](const Value *V) {
438+ // For some of the values we use the convention of leaving
439+ // it nullptr to signify an implicit constant 1.
440+ return V ? computeKnownBits (V, DemandedElts, Q, Depth + 1 )
441+ : KnownBits::makeConstant (APInt (BitWidth, 1 ));
442+ };
443+
444+ // Check that all operands are non-negative
445+ const KnownBits KnownA = ComputeKnownBitsOrOne (A);
446+ if (!KnownA.isNonNegative ())
447+ return ;
448+
449+ const KnownBits KnownD = ComputeKnownBitsOrOne (D);
450+ if (!KnownD.isNonNegative ())
451+ return ;
452+
453+ const KnownBits KnownB = computeKnownBits (B, DemandedElts, Q, Depth + 1 );
454+ if (!KnownB.isNonNegative ())
455+ return ;
456+
457+ const KnownBits KnownC = computeKnownBits (C, DemandedElts, Q, Depth + 1 );
458+ if (!KnownC.isNonNegative ())
459+ return ;
460+
461+ if (SubBC->getOpcode () == Instruction::Xor) {
462+ // If we matched subtraction as xor, we need to actually check that xor
463+ // is semantically equivalent to subtraction.
464+ //
465+ // For that to be true, b has to be a mask for c.
466+ // In known bits terms it would mean the following:
467+ //
468+ // - b is a constant
469+ if (!KnownB.isConstant ())
470+ return ;
471+
472+ // - b has ones at least in every position where c might have ones.
473+ const APInt MaxC = KnownC.getMaxValue ();
474+ if ((KnownB.getConstant () & MaxC) != MaxC)
475+ return ;
476+ }
477+
478+ // Compute max(a, d)
479+ const APInt MaxA = KnownA.getMaxValue ();
480+ const APInt MaxD = KnownD.getMaxValue ();
481+ const APInt MaxAD = APIntOps::umax (MaxA, MaxD);
482+
483+ // Compute max(a, d) * max(b)
484+ const APInt MaxB = KnownB.getMaxValue ();
485+ bool Overflow;
486+ const APInt UpperBound = MaxAD.umul_ov (MaxB, Overflow);
487+
488+ if (Overflow)
489+ return ;
490+
491+ // Count leading zeros in upper bound
492+ const unsigned MinimumNumberOfLeadingZeros = UpperBound.countl_zero ();
493+
494+ // Create KnownBits with only leading zeros set
495+ KnownOut.Zero .setHighBits (MinimumNumberOfLeadingZeros);
496+
497+ return ;
498+ }
499+
353500static void computeKnownBitsAddSub (bool Add, const Value *Op0, const Value *Op1,
354501 bool NSW, bool NUW,
355502 const APInt &DemandedElts,
@@ -369,6 +516,10 @@ static void computeKnownBitsAddSub(bool Add, const Value *Op0, const Value *Op1,
369516 isImpliedByDomCondition (ICmpInst::ICMP_SLE, Op1, Op0, Q.CxtI , Q.DL )
370517 .value_or (false ))
371518 KnownOut.makeNonNegative ();
519+
520+ if (Add)
521+ // Try to match lerp pattern and combine results
522+ computeKnownBitsFromLerpPattern (Op0, Op1, DemandedElts, KnownOut, Q, Depth);
372523}
373524
374525static void computeKnownBitsMul (const Value *Op0, const Value *Op1, bool NSW,
0 commit comments