Skip to content

Commit fe15113

Browse files
committed
Further saturating arithmetic test and implementation cleanups
1 parent f00824e commit fe15113

File tree

4 files changed

+120
-34
lines changed

4 files changed

+120
-34
lines changed

Sources/IntegerUtilities/SaturatingArithmetic.swift

Lines changed: 44 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,15 @@
1010
//===----------------------------------------------------------------------===//
1111

1212
extension FixedWidthInteger {
13-
@_transparent @usableFromInline
14-
var sextOrZext: Self { self >> Self.bitWidth }
13+
/// `~0` (all-ones) if this value is negative, otherwise `0`.
14+
///
15+
/// Note that if `Self` is unsigned, this always returns `0`,
16+
/// but it is useful for writing algorithms that are generic over
17+
/// signed and unsigned integers.
18+
@inline(__always) @usableFromInline
19+
var signbit: Self {
20+
return self < .zero ? ~.zero : .zero
21+
}
1522

1623
/// Saturating integer addition
1724
///
@@ -29,8 +36,7 @@ extension FixedWidthInteger {
2936
@inlinable
3037
public func addingWithSaturation(_ other: Self) -> Self {
3138
let (wrapped, overflow) = addingReportingOverflow(other)
32-
if !overflow { return wrapped }
33-
return Self.max &- sextOrZext
39+
return overflow ? Self.max &- signbit : wrapped
3440
}
3541

3642
/// Saturating integer subtraction
@@ -54,7 +60,7 @@ extension FixedWidthInteger {
5460
public func subtractingWithSaturation(_ other: Self) -> Self {
5561
let (wrapped, overflow) = subtractingReportingOverflow(other)
5662
if !overflow { return wrapped }
57-
return Self.isSigned ? Self.max &- sextOrZext : 0
63+
return Self.isSigned ? Self.max &- signbit : 0
5864
}
5965

6066
/// Saturating integer negation
@@ -85,10 +91,10 @@ extension FixedWidthInteger {
8591
public func multipliedWithSaturation(by other: Self) -> Self {
8692
let (high, low) = multipliedFullWidth(by: other)
8793
let wrapped = Self(truncatingIfNeeded: low)
88-
if high == wrapped.sextOrZext { return wrapped }
89-
return Self.max &- high.sextOrZext
94+
if high == wrapped.signbit { return wrapped }
95+
return Self.max &- high.signbit
9096
}
91-
97+
9298
/// Bitwise left with rounding and saturation.
9399
///
94100
/// `self` multiplied by the rational number 2^(`count`), saturated to the
@@ -102,28 +108,20 @@ extension FixedWidthInteger {
102108
/// and if negative a right shift.
103109
/// - rounding rule: the direction in which to round if `count` is negative.
104110
@inlinable
105-
public func shiftedWithSaturation<Count: BinaryInteger>(
106-
leftBy count: Count, rounding rule: RoundingRule = .down
111+
public func shiftedWithSaturation(
112+
leftBy count: Int,
113+
rounding rule: RoundingRule = .down
107114
) -> Self {
108-
// If count is zero or negative, negate it and do a right
109-
// shift without saturation instead, since we already have
110-
// that implemented.
115+
if count == 0 { return self }
116+
// If count is negative, negate it and do a right shift without
117+
// saturation instead, since we already have that implemented.
111118
guard count > 0 else {
112-
// negating count is tricky, because count's type can be
113-
// an arbitrary BinaryInteger; in particular, it could be
114-
// .min of a signed type, so that its negation cannot be
115-
// represented in the same type. Fortunately, Int64 is
116-
// always big enough to represent arbitrary shifts of
117-
// arbitrary types, so we can use that as an intermediate
118-
// type, and then we can use negatedWithSaturation() to
119-
// handle the .min case.
120-
let int64Count = Int64(clamping: count)
121119
return shifted(
122-
rightBy: int64Count.negatedWithSaturation(),
120+
rightBy: count.negatedWithSaturation(),
123121
rounding: rule
124122
)
125123
}
126-
let clamped = Self.max &- sextOrZext
124+
let clamped = Self.max &- signbit
127125
guard count < Self.bitWidth else {
128126
// If count is bitWidth or greater, we always overflow
129127
// unless self is zero.
@@ -143,7 +141,27 @@ extension FixedWidthInteger {
143141
// does equal 0b0000_0000.
144142
let valueBits = Self.bitWidth &- (Self.isSigned ? 1 : 0)
145143
let wrapped = self &<< count
146-
let complement = valueBits &- Int(count)
147-
return self &>> complement == sextOrZext ? wrapped : clamped
144+
let complement = valueBits &- count
145+
return self &>> complement == signbit ? wrapped : clamped
146+
}
147+
148+
/// Bitwise left with rounding and saturation.
149+
///
150+
/// `self` multiplied by the rational number 2^(`count`), saturated to the
151+
/// range `Self.min ... Self.max`, and rounded according to `rule`.
152+
///
153+
/// See `shifted(rightBy:rounding:)` for more discussion of rounding
154+
/// shifts with examples.
155+
///
156+
/// - Parameters:
157+
/// - leftBy count: the number of bits to shift by. If positive, this is a left-shift,
158+
/// and if negative a right shift.
159+
/// - rounding rule: the direction in which to round if `count` is negative.
160+
@_transparent
161+
public func shiftedWithSaturation(
162+
leftBy count: some BinaryInteger,
163+
rounding rule: RoundingRule = .down
164+
) -> Self {
165+
self.shiftedWithSaturation(leftBy: Int(clamping: count), rounding: rule)
148166
}
149167
}

Sources/IntegerUtilities/ShiftWithRounding.swift

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ extension BinaryInteger {
4242
/// a.shifted(rightBy: count, rounding: rule)
4343
/// a.divided(by: 1 << count, rounding: rule)
4444
@inlinable
45-
public func shifted<Count: BinaryInteger>(
46-
rightBy count: Count,
45+
public func shifted(
46+
rightBy count: Int,
4747
rounding rule: RoundingRule = .down
4848
) -> Self {
4949
// Easiest case: count is zero or negative, so shift is always exact;
@@ -61,7 +61,7 @@ extension BinaryInteger {
6161
// shifts by first shifting all but bitWidth - 1 bits with sticky
6262
// rounding, and then shifting the remaining bitWidth - 1 bits with
6363
// the desired rounding mode.
64-
let count = count - Count(bitWidth - 1)
64+
let count = count - (bitWidth - 1)
6565
let floor = self >> count
6666
let lost = self - (floor << count)
6767
let sticky = floor | (lost == 0 ? 0 : 1)
@@ -155,4 +155,43 @@ extension BinaryInteger {
155155
return floor
156156
}
157157
}
158+
159+
/// `self` divided by 2^(`count`), rounding the result according to `rule`.
160+
///
161+
/// The default rounding rule is `.down`, which matches the behavior of
162+
/// the `>>` operator from the standard library.
163+
///
164+
/// Some examples of different rounding rules:
165+
///
166+
/// // 3/2 is 1.5, which rounds (down by default) to 1.
167+
/// 3.shifted(rightBy: 1)
168+
///
169+
/// // 1.5 rounds up to 2.
170+
/// 3.shifted(rightBy: 1, rounding: .up)
171+
///
172+
/// // The two closest values are 1 and 2, 1 is returned because it
173+
/// // is odd.
174+
/// 3.shifted(rightBy: 1, rounding: .toOdd)
175+
///
176+
/// // 7/2^2 = 1.75, so the result is 1 with probability 1/4, and 2
177+
/// // with probability 3/4.
178+
/// 7.shifted(rightBy: 2, rounding: .stochastically)
179+
///
180+
/// // 4/2^2 = 4/4 = 1, exactly.
181+
/// 4.shifted(rightBy: 2, rounding: .trap)
182+
///
183+
/// // 5/2 is 2.5, which is not exact, so this traps.
184+
/// 5.shifted(rightBy: 1, rounding: .requireExact)
185+
///
186+
/// When `Self(1) << count` is positive, the following are equivalent:
187+
///
188+
/// a.shifted(rightBy: count, rounding: rule)
189+
/// a.divided(by: 1 << count, rounding: rule)
190+
@_transparent
191+
public func shifted(
192+
rightBy count: some BinaryInteger,
193+
rounding rule: RoundingRule = .down
194+
) -> Self {
195+
self.shifted(rightBy: Int(clamping: count), rounding: rule)
196+
}
158197
}

Tests/IntegerUtilitiesTests/SaturatingArithmeticTests.swift

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ final class IntegerUtilitiesSaturatingTests: XCTestCase {
3232
}
3333
}
3434

35-
func testSaturatingSubtractSigned() {
35+
func testSaturatingSubSigned() {
3636
for a in Int8.min ... Int8.max {
3737
for b in Int8.min ... Int8.max {
3838
let expected = Int8(clamping: Int16(a) - Int16(b))
@@ -48,7 +48,7 @@ final class IntegerUtilitiesSaturatingTests: XCTestCase {
4848
}
4949
}
5050

51-
func testSaturatingNegation() {
51+
func testSaturatingNegSigned() {
5252
for a in Int8.min ... Int8.max {
5353
let expected = Int8(clamping: 0 - Int16(a))
5454
let observed = a.negatedWithSaturation()
@@ -62,7 +62,7 @@ final class IntegerUtilitiesSaturatingTests: XCTestCase {
6262
}
6363
}
6464

65-
func testSaturatingMultiplicationSigned() {
65+
func testSaturatingMulSigned() {
6666
for a in Int8.min ... Int8.max {
6767
for b in Int8.min ... Int8.max {
6868
let expected = Int8(clamping: Int16(a) * Int16(b))
@@ -94,7 +94,7 @@ final class IntegerUtilitiesSaturatingTests: XCTestCase {
9494
}
9595
}
9696

97-
func testSaturatingSubtractUnsigned() {
97+
func testSaturatingSubUnsigned() {
9898
for a in UInt8.min ... UInt8.max {
9999
for b in UInt8.min ... UInt8.max {
100100
let expected = UInt8(clamping: Int16(a) - Int16(b))
@@ -110,7 +110,20 @@ final class IntegerUtilitiesSaturatingTests: XCTestCase {
110110
}
111111
}
112112

113-
func testSaturatingMultiplicationUnsigned() {
113+
func testSaturatingNegUnsigned() {
114+
for a in UInt8.min ... UInt8.max {
115+
let observed = a.negatedWithSaturation()
116+
if 0 != observed {
117+
print("Error found in (\(a)).negatedWithSaturation().")
118+
print("Expected: zero")
119+
print("Observed: \(String(observed, radix: 16))")
120+
XCTFail()
121+
return
122+
}
123+
}
124+
}
125+
126+
func testSaturatingMulUnsigned() {
114127
for a in UInt8.min ... UInt8.max {
115128
for b in UInt8.min ... UInt8.max {
116129
let expected = UInt8(clamping: UInt16(a) * UInt16(b))

Tests/WindowsMain.swift

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,21 @@ extension IntegerUtilitiesRotateTests {
119119
])
120120
}
121121

122+
extension IntegerUtilitiesSaturatingTests {
123+
static var all = testCase([
124+
("testSaturatingAddSigned", IntegerUtilitiesSaturatingTests.testSaturatingAddSigned),
125+
("testSaturatingSubSigned", IntegerUtilitiesSaturatingTests.testSaturatingSubSigned),
126+
("testSaturatingNegSigned", IntegerUtilitiesSaturatingTests.testSaturatingNegSigned),
127+
("testSaturatingMulSigned", IntegerUtilitiesSaturatingTests.testSaturatingMulSigned),
128+
("testSaturatingAddUnsigned", IntegerUtilitiesSaturatingTests.testSaturatingAddUnsigned),
129+
("testSaturatingSubUnsigned", IntegerUtilitiesSaturatingTests.testSaturatingSubUnsigned),
130+
("testSaturatingNegUnsigned", IntegerUtilitiesSaturatingTests.testSaturatingNegUnsigned),
131+
("testSaturatingMulUnsigned", IntegerUtilitiesSaturatingTests.testSaturatingMulUnsigned),
132+
("testSaturatingShifts", IntegerUtilitiesSaturatingTests.testSaturatingShifts),
133+
("testEdgeCaseForNegativeCount", IntegerUtilitiesSaturatingTests.testEdgeCaseForNegativeCount)
134+
])
135+
}
136+
122137
extension IntegerUtilitiesShiftTests {
123138
static var all = testCase([
124139
("testRoundingShifts", IntegerUtilitiesShiftTests.testRoundingShifts),
@@ -170,6 +185,7 @@ var testCases = [
170185
IntegerUtilitiesGCDTests.all,
171186
IntegerUtilitiesRotateTests.all,
172187
IntegerUtilitiesShiftTests.all,
188+
IntegerUtilitiesSaturatingTests.all,
173189
IntegerUtilitiesTests.DoubleWidthTests.all,
174190
]
175191

0 commit comments

Comments
 (0)