Skip to content

Commit d6f1184

Browse files
BillyDonahueevergreen
authored andcommitted
SERVER-43032 simplify overflow_arithmetic.h
1 parent 0a0625f commit d6f1184

File tree

8 files changed

+60
-136
lines changed

8 files changed

+60
-136
lines changed

src/mongo/base/parse_number.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,9 +142,9 @@ inline StatusWith<uint64_t> parseMagnitudeFromStringWithBase(uint64_t base,
142142

143143
// This block is (n = (n * base) + digitValue) with overflow checking at each step.
144144
uint64_t multiplied;
145-
if (mongoUnsignedMultiplyOverflow64(n, base, &multiplied))
145+
if (overflow::mul(n, base, &multiplied))
146146
return Status(ErrorCodes::Overflow, "Overflow");
147-
if (mongoUnsignedAddOverflow64(multiplied, digitValue, &n))
147+
if (overflow::add(multiplied, digitValue, &n))
148148
return Status(ErrorCodes::Overflow, "Overflow");
149149
++charsConsumed;
150150
}

src/mongo/db/pipeline/document_source_sort.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -130,11 +130,10 @@ Pipeline::SourceContainer::iterator DocumentSourceSort::doOptimizeAt(
130130

131131
// The skip and limit values can be very large, so we need to make sure the sum doesn't
132132
// overflow before applying an optimization to pull the limit into the sort stage.
133-
if (nextSkip && !mongoSignedAddOverflow64(skipSum, nextSkip->getSkip(), &safeSum)) {
133+
if (nextSkip && !overflow::add(skipSum, nextSkip->getSkip(), &safeSum)) {
134134
skipSum = safeSum;
135135
++stageItr;
136-
} else if (nextLimit &&
137-
!mongoSignedAddOverflow64(nextLimit->getLimit(), skipSum, &safeSum)) {
136+
} else if (nextLimit && !overflow::add(nextLimit->getLimit(), skipSum, &safeSum)) {
138137
_sortExecutor->setLimit(safeSum);
139138
container->erase(stageItr);
140139
stageItr = std::next(itr);

src/mongo/db/pipeline/expression.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2724,8 +2724,9 @@ Value ExpressionMultiply::evaluate(const Document& root, Variables* variables) c
27242724
decimalProduct = decimalProduct.multiply(val.coerceToDecimal());
27252725
} else {
27262726
doubleProduct *= val.coerceToDouble();
2727+
27272728
if (!std::isfinite(val.coerceToDouble()) ||
2728-
mongoSignedMultiplyOverflow64(longProduct, val.coerceToLong(), &longProduct)) {
2729+
overflow::mul(longProduct, val.coerceToLong(), &longProduct)) {
27292730
// The number is either Infinity or NaN, or the 'longProduct' would have
27302731
// overflowed, so we're abandoning it.
27312732
productType = NumberDouble;

src/mongo/platform/overflow_arithmetic.h

Lines changed: 33 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -35,123 +35,50 @@
3535
#include <SafeInt.hpp>
3636
#endif
3737

38-
namespace mongo {
38+
#include "mongo/stdx/type_traits.h"
3939

40-
/**
41-
* Returns true if multiplying lhs by rhs would overflow. Otherwise, multiplies 64-bit signed
42-
* or unsigned integers lhs by rhs and stores the result in *product.
43-
*/
44-
constexpr bool mongoSignedMultiplyOverflow64(int64_t lhs, int64_t rhs, int64_t* product);
45-
constexpr bool mongoUnsignedMultiplyOverflow64(uint64_t lhs, uint64_t rhs, uint64_t* product);
40+
namespace mongo::overflow {
4641

4742
/**
48-
* Returns true if adding lhs and rhs would overflow. Otherwise, adds 64-bit signed or unsigned
49-
* integers lhs and rhs and stores the result in *sum.
50-
*/
51-
constexpr bool mongoSignedAddOverflow64(int64_t lhs, int64_t rhs, int64_t* sum);
52-
constexpr bool mongoUnsignedAddOverflow64(uint64_t lhs, uint64_t rhs, uint64_t* sum);
53-
54-
/**
55-
* Returns true if subtracting rhs from lhs would overflow. Otherwise, subtracts 64-bit signed or
56-
* unsigned integers rhs from lhs and stores the result in *difference.
43+
* Synopsis:
44+
*
45+
* bool mul(A a, A b, T* r);
46+
* bool add(A a, A b, T* r);
47+
* bool sub(A a, A b, T* r);
48+
*
49+
* The domain type `A` evaluates to `T`, which is deduced from the `r` parameter.
50+
* That is, the input parameters are coerced into the type accepted by the output parameter.
51+
* All functions return true if operation would overflow, otherwise they store result in `*r`.
5752
*/
58-
constexpr bool mongoSignedSubtractOverflow64(int64_t lhs, int64_t rhs, int64_t* difference);
59-
constexpr bool mongoUnsignedSubtractOverflow64(uint64_t lhs, uint64_t rhs, uint64_t* difference);
6053

54+
// MSVC : The SafeInt functions return false on overflow.
55+
// GCC, Clang: The __builtin_*_overflow functions return true on overflow.
6156

57+
template <typename T>
58+
constexpr bool mul(stdx::type_identity_t<T> a, stdx::type_identity_t<T> b, T* r) {
6259
#ifdef _MSC_VER
63-
64-
// The SafeInt functions return true on success, false on overflow.
65-
66-
constexpr bool mongoSignedMultiplyOverflow64(int64_t lhs, int64_t rhs, int64_t* product) {
67-
return !SafeMultiply(lhs, rhs, *product);
68-
}
69-
70-
constexpr bool mongoUnsignedMultiplyOverflow64(uint64_t lhs, uint64_t rhs, uint64_t* product) {
71-
return !SafeMultiply(lhs, rhs, *product);
72-
}
73-
74-
constexpr bool mongoSignedAddOverflow64(int64_t lhs, int64_t rhs, int64_t* sum) {
75-
return !SafeAdd(lhs, rhs, *sum);
76-
}
77-
78-
constexpr bool mongoUnsignedAddOverflow64(uint64_t lhs, uint64_t rhs, uint64_t* sum) {
79-
return !SafeAdd(lhs, rhs, *sum);
80-
}
81-
82-
constexpr bool mongoSignedSubtractOverflow64(int64_t lhs, int64_t rhs, int64_t* difference) {
83-
return !SafeSubtract(lhs, rhs, *difference);
84-
}
85-
86-
constexpr bool mongoUnsignedSubtractOverflow64(uint64_t lhs, uint64_t rhs, uint64_t* difference) {
87-
return !SafeSubtract(lhs, rhs, *difference);
88-
}
89-
60+
return !SafeMultiply(a, b, *r);
9061
#else
91-
92-
// On GCC and CLANG we can use __builtin functions to perform these calculations. These return true
93-
// on overflow and false on success.
94-
95-
constexpr bool mongoSignedMultiplyOverflow64(long lhs, long rhs, long* product) {
96-
return __builtin_mul_overflow(lhs, rhs, product);
97-
}
98-
99-
constexpr bool mongoSignedMultiplyOverflow64(long long lhs, long long rhs, long long* product) {
100-
return __builtin_mul_overflow(lhs, rhs, product);
101-
}
102-
103-
constexpr bool mongoUnsignedMultiplyOverflow64(unsigned long lhs,
104-
unsigned long rhs,
105-
unsigned long* product) {
106-
return __builtin_mul_overflow(lhs, rhs, product);
107-
}
108-
109-
constexpr bool mongoUnsignedMultiplyOverflow64(unsigned long long lhs,
110-
unsigned long long rhs,
111-
unsigned long long* product) {
112-
return __builtin_mul_overflow(lhs, rhs, product);
113-
}
114-
115-
constexpr bool mongoSignedAddOverflow64(long lhs, long rhs, long* sum) {
116-
return __builtin_add_overflow(lhs, rhs, sum);
117-
}
118-
119-
constexpr bool mongoSignedAddOverflow64(long long lhs, long long rhs, long long* sum) {
120-
return __builtin_add_overflow(lhs, rhs, sum);
121-
}
122-
123-
constexpr bool mongoUnsignedAddOverflow64(unsigned long lhs,
124-
unsigned long rhs,
125-
unsigned long* sum) {
126-
return __builtin_add_overflow(lhs, rhs, sum);
127-
}
128-
129-
constexpr bool mongoUnsignedAddOverflow64(unsigned long long lhs,
130-
unsigned long long rhs,
131-
unsigned long long* sum) {
132-
return __builtin_add_overflow(lhs, rhs, sum);
133-
}
134-
135-
constexpr bool mongoSignedSubtractOverflow64(long lhs, long rhs, long* difference) {
136-
return __builtin_sub_overflow(lhs, rhs, difference);
137-
}
138-
139-
constexpr bool mongoSignedSubtractOverflow64(long long lhs, long long rhs, long long* difference) {
140-
return __builtin_sub_overflow(lhs, rhs, difference);
141-
}
142-
143-
constexpr bool mongoUnsignedSubtractOverflow64(unsigned long lhs,
144-
unsigned long rhs,
145-
unsigned long* sum) {
146-
return __builtin_sub_overflow(lhs, rhs, sum);
62+
return __builtin_mul_overflow(a, b, r);
63+
#endif
14764
}
14865

149-
constexpr bool mongoUnsignedSubtractOverflow64(unsigned long long lhs,
150-
unsigned long long rhs,
151-
unsigned long long* sum) {
152-
return __builtin_sub_overflow(lhs, rhs, sum);
66+
template <typename T>
67+
constexpr bool add(stdx::type_identity_t<T> a, stdx::type_identity_t<T> b, T* r) {
68+
#ifdef _MSC_VER
69+
return !SafeAdd(a, b, *r);
70+
#else
71+
return __builtin_add_overflow(a, b, r);
72+
#endif
15373
}
15474

75+
template <typename T>
76+
constexpr bool sub(stdx::type_identity_t<T> a, stdx::type_identity_t<T> b, T* r) {
77+
#ifdef _MSC_VER
78+
return !SafeSubtract(a, b, *r);
79+
#else
80+
return __builtin_sub_overflow(a, b, r);
15581
#endif
82+
}
15683

157-
} // namespace mongo
84+
} // namespace mongo::overflow

src/mongo/platform/overflow_arithmetic_test.cpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -49,34 +49,34 @@ namespace {
4949
} while (false)
5050

5151
#define assertSignedMultiplyNoOverflow(LHS, RHS, EXPECTED) \
52-
assertArithOverflow(int64_t, mongoSignedMultiplyOverflow64, LHS, RHS, false, EXPECTED)
52+
assertArithOverflow(int64_t, overflow::mul, LHS, RHS, false, EXPECTED)
5353
#define assertSignedMultiplyWithOverflow(LHS, RHS) \
54-
assertArithOverflow(int64_t, mongoSignedMultiplyOverflow64, LHS, RHS, true, 0)
54+
assertArithOverflow(int64_t, overflow::mul, LHS, RHS, true, 0)
5555

5656
#define assertUnsignedMultiplyNoOverflow(LHS, RHS, EXPECTED) \
57-
assertArithOverflow(uint64_t, mongoUnsignedMultiplyOverflow64, LHS, RHS, false, EXPECTED)
57+
assertArithOverflow(uint64_t, overflow::mul, LHS, RHS, false, EXPECTED)
5858
#define assertUnsignedMultiplyWithOverflow(LHS, RHS) \
59-
assertArithOverflow(uint64_t, mongoUnsignedMultiplyOverflow64, LHS, RHS, true, 0)
59+
assertArithOverflow(uint64_t, overflow::mul, LHS, RHS, true, 0)
6060

6161
#define assertSignedAddNoOverflow(LHS, RHS, EXPECTED) \
62-
assertArithOverflow(int64_t, mongoSignedAddOverflow64, LHS, RHS, false, EXPECTED)
62+
assertArithOverflow(int64_t, overflow::add, LHS, RHS, false, EXPECTED)
6363
#define assertSignedAddWithOverflow(LHS, RHS) \
64-
assertArithOverflow(int64_t, mongoSignedAddOverflow64, LHS, RHS, true, 0)
64+
assertArithOverflow(int64_t, overflow::add, LHS, RHS, true, 0)
6565

6666
#define assertUnsignedAddNoOverflow(LHS, RHS, EXPECTED) \
67-
assertArithOverflow(uint64_t, mongoUnsignedAddOverflow64, LHS, RHS, false, EXPECTED)
67+
assertArithOverflow(uint64_t, overflow::add, LHS, RHS, false, EXPECTED)
6868
#define assertUnsignedAddWithOverflow(LHS, RHS) \
69-
assertArithOverflow(uint64_t, mongoUnsignedAddOverflow64, LHS, RHS, true, 0)
69+
assertArithOverflow(uint64_t, overflow::add, LHS, RHS, true, 0)
7070

7171
#define assertSignedSubtractNoOverflow(LHS, RHS, EXPECTED) \
72-
assertArithOverflow(int64_t, mongoSignedSubtractOverflow64, LHS, RHS, false, EXPECTED)
72+
assertArithOverflow(int64_t, overflow::sub, LHS, RHS, false, EXPECTED)
7373
#define assertSignedSubtractWithOverflow(LHS, RHS) \
74-
assertArithOverflow(int64_t, mongoSignedSubtractOverflow64, LHS, RHS, true, 0)
74+
assertArithOverflow(int64_t, overflow::sub, LHS, RHS, true, 0)
7575

7676
#define assertUnsignedSubtractNoOverflow(LHS, RHS, EXPECTED) \
77-
assertArithOverflow(uint64_t, mongoUnsignedSubtractOverflow64, LHS, RHS, false, EXPECTED)
77+
assertArithOverflow(uint64_t, overflow::sub, LHS, RHS, false, EXPECTED)
7878
#define assertUnsignedSubtractWithOverflow(LHS, RHS) \
79-
assertArithOverflow(uint64_t, mongoUnsignedSubtractOverflow64, LHS, RHS, true, 0)
79+
assertArithOverflow(uint64_t, overflow::sub, LHS, RHS, true, 0)
8080

8181
TEST(OverflowArithmetic, SignedMultiplicationTests) {
8282
using limits = std::numeric_limits<int64_t>;

src/mongo/s/query/cluster_find.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ StatusWith<std::unique_ptr<QueryRequest>> transformQueryForShards(
9494
boost::optional<long long> newLimit;
9595
if (qr.getLimit()) {
9696
long long newLimitValue;
97-
if (mongoSignedAddOverflow64(*qr.getLimit(), qr.getSkip().value_or(0), &newLimitValue)) {
97+
if (overflow::add(*qr.getLimit(), qr.getSkip().value_or(0), &newLimitValue)) {
9898
return Status(
9999
ErrorCodes::Overflow,
100100
str::stream()
@@ -110,8 +110,7 @@ StatusWith<std::unique_ptr<QueryRequest>> transformQueryForShards(
110110
// !wantMore and ntoreturn mean the same as !wantMore and limit, so perform the conversion.
111111
if (!qr.wantMore()) {
112112
long long newLimitValue;
113-
if (mongoSignedAddOverflow64(
114-
*qr.getNToReturn(), qr.getSkip().value_or(0), &newLimitValue)) {
113+
if (overflow::add(*qr.getNToReturn(), qr.getSkip().value_or(0), &newLimitValue)) {
115114
return Status(ErrorCodes::Overflow,
116115
str::stream()
117116
<< "sum of ntoreturn and skip cannot be represented as a 64-bit "
@@ -121,8 +120,7 @@ StatusWith<std::unique_ptr<QueryRequest>> transformQueryForShards(
121120
newLimit = newLimitValue;
122121
} else {
123122
long long newNToReturnValue;
124-
if (mongoSignedAddOverflow64(
125-
*qr.getNToReturn(), qr.getSkip().value_or(0), &newNToReturnValue)) {
123+
if (overflow::add(*qr.getNToReturn(), qr.getSkip().value_or(0), &newNToReturnValue)) {
126124
return Status(ErrorCodes::Overflow,
127125
str::stream()
128126
<< "sum of ntoreturn and skip cannot be represented as a 64-bit "

src/mongo/util/duration.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ constexpr ToDuration duration_cast(const Duration<FromPeriod>& from) {
110110
typename ToDuration::rep toCount = 0;
111111
uassert(ErrorCodes::DurationOverflow,
112112
"Overflow casting from a lower-precision duration to a higher-precision duration",
113-
!mongoSignedMultiplyOverflow64(from.count(), FromOverTo::num, &toCount));
113+
!overflow::mul(from.count(), FromOverTo::num, &toCount));
114114
return ToDuration{toCount};
115115
}
116116
return ToDuration{from.count() / FromOverTo::den};
@@ -281,7 +281,7 @@ class Duration {
281281
}
282282
using OtherOverThis = std::ratio_divide<OtherPeriod, period>;
283283
rep otherCount;
284-
if (mongoSignedMultiplyOverflow64(other.count(), OtherOverThis::num, &otherCount)) {
284+
if (overflow::mul(other.count(), OtherOverThis::num, &otherCount)) {
285285
return other.count() < 0 ? 1 : -1;
286286
}
287287
if (count() < otherCount) {
@@ -329,14 +329,14 @@ class Duration {
329329
Duration& operator+=(const Duration& other) {
330330
uassert(ErrorCodes::DurationOverflow,
331331
str::stream() << "Overflow while adding " << other << " to " << *this,
332-
!mongoSignedAddOverflow64(count(), other.count(), &_count));
332+
!overflow::add(count(), other.count(), &_count));
333333
return *this;
334334
}
335335

336336
Duration& operator-=(const Duration& other) {
337337
uassert(ErrorCodes::DurationOverflow,
338338
str::stream() << "Overflow while subtracting " << other << " from " << *this,
339-
!mongoSignedSubtractOverflow64(count(), other.count(), &_count));
339+
!overflow::sub(count(), other.count(), &_count));
340340
return *this;
341341
}
342342

@@ -347,7 +347,7 @@ class Duration {
347347
"Durations may only be multiplied by values of signed integral type");
348348
uassert(ErrorCodes::DurationOverflow,
349349
str::stream() << "Overflow while multiplying " << *this << " by " << scale,
350-
!mongoSignedMultiplyOverflow64(count(), scale, &_count));
350+
!overflow::mul(count(), scale, &_count));
351351
return *this;
352352
}
353353

src/mongo/util/net/ssl_manager.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -938,8 +938,7 @@ StatusWith<DERToken> DERToken::parse(ConstDataRange cdr, size_t* outLength) {
938938
const uint64_t tagAndLengthByteCount = kTagLength + encodedLengthBytesCount;
939939

940940
// This may overflow since derLength is from user data so check our arithmetic carefully.
941-
if (mongoUnsignedAddOverflow64(tagAndLengthByteCount, derLength, outLength) ||
942-
*outLength > cdr.length()) {
941+
if (overflow::add(tagAndLengthByteCount, derLength, outLength) || *outLength > cdr.length()) {
943942
return Status(ErrorCodes::InvalidSSLConfiguration, "Invalid DER length");
944943
}
945944

0 commit comments

Comments
 (0)