@@ -766,32 +766,61 @@ KnownBits KnownBits::usub_sat(const KnownBits &LHS, const KnownBits &RHS) {
766766static KnownBits avgCompute (KnownBits LHS, KnownBits RHS, bool IsCeil,
767767 bool IsSigned) {
768768 unsigned BitWidth = LHS.getBitWidth ();
769- LHS = IsSigned ? LHS.sext (BitWidth + 1 ) : LHS.zext (BitWidth + 1 );
770- RHS = IsSigned ? RHS.sext (BitWidth + 1 ) : RHS.zext (BitWidth + 1 );
771- LHS =
772- computeForAddCarry (LHS, RHS, /* CarryZero*/ !IsCeil, /* CarryOne*/ IsCeil);
773- LHS = LHS.extractBits (BitWidth, 1 );
774- return LHS;
769+ KnownBits ExtLHS = IsSigned ? LHS.sext (BitWidth + 1 ) : LHS.zext (BitWidth + 1 );
770+ KnownBits ExtRHS = IsSigned ? RHS.sext (BitWidth + 1 ) : RHS.zext (BitWidth + 1 );
771+ KnownBits Res = computeForAddCarry (ExtLHS, ExtRHS, /* CarryZero=*/ !IsCeil,
772+ /* CarryOne=*/ IsCeil);
773+ Res = Res.extractBits (BitWidth, 1 );
774+
775+ // If we have only 1 known signbit between LHS/RHS we can try to figure
776+ // out result signbit.
777+ // NB: If we know both signbits `computeForAddCarry` gets the optimal result
778+ // already.
779+ if (IsSigned && Res.isSignUnknown () &&
780+ LHS.isSignUnknown () != RHS.isSignUnknown ()) {
781+ if (LHS.isSignUnknown ())
782+ std::swap (LHS, RHS);
783+ KnownBits UnsignedLHS = LHS;
784+ KnownBits UnsignedRHS = RHS;
785+ UnsignedLHS.One .clearSignBit ();
786+ UnsignedLHS.Zero .setSignBit ();
787+ UnsignedRHS.One .clearSignBit ();
788+ UnsignedRHS.Zero .setSignBit ();
789+ KnownBits ResOf =
790+ computeForAddCarry (UnsignedLHS, UnsignedRHS, /* CarryZero=*/ !IsCeil,
791+ /* CarryOne=*/ IsCeil);
792+ // Assuming no overflow (which is the case since we extend the addition when
793+ // taking the average):
794+ // Neg + Neg -> Neg
795+ // Neg + Pos -> Neg if the signbit doesn't overflow
796+ if (LHS.isNegative () && ResOf.isNonNegative ())
797+ Res.makeNegative ();
798+ // Pos + Pos -> Pos
799+ // Pos + Neg -> Pos if the signbit does overflow
800+ else if (LHS.isNonNegative () && ResOf.isNegative ())
801+ Res.makeNonNegative ();
802+ }
803+ return Res;
775804}
776805
777806KnownBits KnownBits::avgFloorS (const KnownBits &LHS, const KnownBits &RHS) {
778- return avgCompute (LHS, RHS, /* IsCeil */ false ,
779- /* IsSigned */ true );
807+ return avgCompute (LHS, RHS, /* IsCeil= */ false ,
808+ /* IsSigned= */ true );
780809}
781810
782811KnownBits KnownBits::avgFloorU (const KnownBits &LHS, const KnownBits &RHS) {
783- return avgCompute (LHS, RHS, /* IsCeil */ false ,
784- /* IsSigned */ false );
812+ return avgCompute (LHS, RHS, /* IsCeil= */ false ,
813+ /* IsSigned= */ false );
785814}
786815
787816KnownBits KnownBits::avgCeilS (const KnownBits &LHS, const KnownBits &RHS) {
788- return avgCompute (LHS, RHS, /* IsCeil */ true ,
789- /* IsSigned */ true );
817+ return avgCompute (LHS, RHS, /* IsCeil= */ true ,
818+ /* IsSigned= */ true );
790819}
791820
792821KnownBits KnownBits::avgCeilU (const KnownBits &LHS, const KnownBits &RHS) {
793- return avgCompute (LHS, RHS, /* IsCeil */ true ,
794- /* IsSigned */ false );
822+ return avgCompute (LHS, RHS, /* IsCeil= */ true ,
823+ /* IsSigned= */ false );
795824}
796825
797826KnownBits KnownBits::mul (const KnownBits &LHS, const KnownBits &RHS,
0 commit comments