@@ -350,6 +350,152 @@ unsigned llvm::ComputeMaxSignificantBits(const Value *V, const DataLayout &DL,
350350 return V->getType ()->getScalarSizeInBits () - SignBits + 1 ;
351351}
352352
353+ // / Try to detect the lerp pattern: a * (b - c) + c * d
354+ // / where a >= 0, b >= 0, c >= 0, d >= 0, and b >= c.
355+ // /
356+ // / In that particular case, we can use the following chain of reasoning:
357+ // /
358+ // / a * (b - c) + c * d <= a' * (b - c) + a' * c = a' * b where a' = max(a, d)
359+ // /
360+ // / Since that is true for arbitrary a, b, c and d within our constraints, we
361+ // / can conclude that:
362+ // /
363+ // / max(a * (b - c) + c * d) <= max(max(a), max(d)) * max(b) = U
364+ // /
365+ // / Considering that any result of the lerp would be less or equal to U, it
366+ // / would have at least the number of leading 0s as in U.
367+ // /
368+ // / While being quite a specific situation, it is fairly common in computer
369+ // / graphics in the shape of alpha blending.
370+ // /
371+ // / Modifies given KnownOut in-place with the inferred information.
372+ static void computeKnownBitsFromLerpPattern (const Value *Op0, const Value *Op1,
373+ const APInt &DemandedElts,
374+ KnownBits &KnownOut,
375+ const SimplifyQuery &Q,
376+ unsigned Depth) {
377+
378+ Type *Ty = Op0->getType ();
379+ const unsigned BitWidth = Ty->getScalarSizeInBits ();
380+
381+ // Only handle scalar types for now
382+ if (Ty->isVectorTy ())
383+ return ;
384+
385+ // Try to match: a * (b - c) + c * d.
386+ // When a == 1 => A == nullptr, the same applies to d/D as well.
387+ const Value *A = nullptr , *B = nullptr , *C = nullptr , *D = nullptr ;
388+ const BinaryOperator *SubBC = nullptr ;
389+
390+ const auto MatchSubBC = [&]() {
391+ // (b - c) can have two forms that interest us:
392+ //
393+ // 1. sub nuw %b, %c
394+ // 2. xor %c, %b
395+ //
396+ // For the first case, nuw flag guarantees our requirement b >= c.
397+ //
398+ // The second case might happen when the analysis can infer that b is a mask
399+ // for c and we can transform sub operation into xor (that is usually true
400+ // for constant b's). Even though xor is symmetrical, canonicalization
401+ // ensures that the constant will be the RHS. We have additional checks
402+ // later on to ensure that this xor operation is equivalent to subtraction.
403+ return m_CombineAnd (m_BinOp (SubBC),
404+ m_CombineOr (m_NUWSub (m_Value (B), m_Value (C)),
405+ m_Xor (m_Value (C), m_Value (B))));
406+ };
407+
408+ const auto MatchASubBC = [&]() {
409+ // Cases:
410+ // - a * (b - c)
411+ // - (b - c) * a
412+ // - (b - c) <- a implicitly equals 1
413+ return m_CombineOr (m_c_Mul (m_Value (A), MatchSubBC ()), MatchSubBC ());
414+ };
415+
416+ const auto MatchCD = [&]() {
417+ // Cases:
418+ // - d * c
419+ // - c * d
420+ // - c <- d implicitly equals 1
421+ return m_CombineOr (m_c_Mul (m_Value (D), m_Specific (C)), m_Specific (C));
422+ };
423+
424+ const auto Match = [&](const Value *LHS, const Value *RHS) {
425+ // We do use m_Specific(C) in MatchCD, so we have to make sure that
426+ // it's bound to anything and match(LHS, MatchASubBC()) absolutely
427+ // has to evaluate first and return true.
428+ //
429+ // If Match returns true, it is guaranteed that B != nullptr, C != nullptr.
430+ return match (LHS, MatchASubBC ()) && match (RHS, MatchCD ());
431+ };
432+
433+ if (!Match (Op0, Op1) && !Match (Op1, Op0))
434+ return ;
435+
436+ const auto ComputeKnownBitsOrOne = [&](const Value *V) {
437+ // For some of the values we use the convention of leaving
438+ // it nullptr to signify an implicit constant 1.
439+ return V ? computeKnownBits (V, DemandedElts, Q, Depth + 1 )
440+ : KnownBits::makeConstant (APInt (BitWidth, 1 ));
441+ };
442+
443+ // Check that all operands are non-negative
444+ const KnownBits KnownA = ComputeKnownBitsOrOne (A);
445+ if (!KnownA.isNonNegative ())
446+ return ;
447+
448+ const KnownBits KnownD = ComputeKnownBitsOrOne (D);
449+ if (!KnownD.isNonNegative ())
450+ return ;
451+
452+ const KnownBits KnownB = computeKnownBits (B, DemandedElts, Q, Depth + 1 );
453+ if (!KnownB.isNonNegative ())
454+ return ;
455+
456+ const KnownBits KnownC = computeKnownBits (C, DemandedElts, Q, Depth + 1 );
457+ if (!KnownC.isNonNegative ())
458+ return ;
459+
460+ if (SubBC->getOpcode () == Instruction::Xor) {
461+ // If we matched subtraction as xor, we need to actually check that xor
462+ // is semantically equivalent to subtraction.
463+ //
464+ // For that to be true, b has to be a mask for c.
465+ // In known bits terms it would mean the following:
466+ //
467+ // - b is a constant
468+ if (!KnownB.isConstant ())
469+ return ;
470+
471+ // - b has ones at least in every position where c might have ones.
472+ const APInt MaxC = KnownC.getMaxValue ();
473+ if ((KnownB.getConstant () & MaxC) != MaxC)
474+ return ;
475+ }
476+
477+ // Compute max(a, d)
478+ const APInt MaxA = KnownA.getMaxValue ();
479+ const APInt MaxD = KnownD.getMaxValue ();
480+ const APInt MaxAD = APIntOps::umax (MaxA, MaxD);
481+
482+ // Compute max(a, d) * max(b)
483+ const APInt MaxB = KnownB.getMaxValue ();
484+ bool Overflow;
485+ const APInt UpperBound = MaxAD.umul_ov (MaxB, Overflow);
486+
487+ if (Overflow)
488+ return ;
489+
490+ // Count leading zeros in upper bound
491+ const unsigned MinimumNumberOfLeadingZeros = UpperBound.countl_zero ();
492+
493+ // Create KnownBits with only leading zeros set
494+ KnownOut.Zero .setHighBits (MinimumNumberOfLeadingZeros);
495+
496+ return ;
497+ }
498+
353499static void computeKnownBitsAddSub (bool Add, const Value *Op0, const Value *Op1,
354500 bool NSW, bool NUW,
355501 const APInt &DemandedElts,
@@ -369,6 +515,10 @@ static void computeKnownBitsAddSub(bool Add, const Value *Op0, const Value *Op1,
369515 isImpliedByDomCondition (ICmpInst::ICMP_SLE, Op1, Op0, Q.CxtI , Q.DL )
370516 .value_or (false ))
371517 KnownOut.makeNonNegative ();
518+
519+ if (Add)
520+ // Try to match lerp pattern and combine results
521+ computeKnownBitsFromLerpPattern (Op0, Op1, DemandedElts, KnownOut, Q, Depth);
372522}
373523
374524static void computeKnownBitsMul (const Value *Op0, const Value *Op1, bool NSW,
0 commit comments