Skip to content

Commit 4779aef

Browse files
committed
Test coverage for saturating shift + fix the bug that it turned up.
1 parent dbffff1 commit 4779aef

File tree

2 files changed

+93
-16
lines changed

2 files changed

+93
-16
lines changed

Sources/IntegerUtilities/SaturatingArithmetic.swift

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
extension FixedWidthInteger {
1313
@_transparent @usableFromInline
14-
var signExtension: Self { self &>> -1 }
14+
var sextOrZext: Self { self >> Self.bitWidth }
1515

1616
/// Saturating integer addition
1717
///
@@ -30,7 +30,7 @@ extension FixedWidthInteger {
3030
public func addingWithSaturation(_ other: Self) -> Self {
3131
let (wrapped, overflow) = addingReportingOverflow(other)
3232
if !overflow { return wrapped }
33-
return Self.max &- signExtension
33+
return Self.max &- sextOrZext
3434
}
3535

3636
/// Saturating integer subtraction
@@ -54,7 +54,7 @@ extension FixedWidthInteger {
5454
public func subtractingWithSaturation(_ other: Self) -> Self {
5555
let (wrapped, overflow) = subtractingReportingOverflow(other)
5656
if !overflow { return wrapped }
57-
return Self.max &- signExtension
57+
return Self.max &- sextOrZext
5858
}
5959

6060
/// Saturating integer negation
@@ -85,8 +85,8 @@ extension FixedWidthInteger {
8585
public func multipliedWithSaturation(by other: Self) -> Self {
8686
let (high, low) = multipliedFullWidth(by: other)
8787
let wrapped = Self(truncatingIfNeeded: low)
88-
if high == wrapped.signExtension { return wrapped }
89-
return Self.max &- high.signExtension
88+
if high == wrapped.sextOrZext { return wrapped }
89+
return Self.max &- high.sextOrZext
9090
}
9191

9292
/// Bitwise left with rounding and saturation.
@@ -106,23 +106,45 @@ extension FixedWidthInteger {
106106
leftBy count: Count, rounding rule: RoundingRule = .down
107107
) -> Self {
108108
// If count is zero or negative, negate it and do a right
109-
// shift without saturation instead, as that's easier.
109+
// shift without saturation instead, since we already have
110+
// that implemented.
110111
guard count > 0 else {
112+
// negating count is tricky, because count's type can be
113+
// an arbitrary BinaryInteger; in particular, it could be
114+
// .min of a signed type, so that its negation cannot be
115+
// represented in the same type. Fortunately, Int64 is
116+
// always big enough to represent arbitrary shifts of
117+
// arbitrary types, so we can use that as an intermediate
118+
// type, and then we can use negatedWithSaturation() to
119+
// handle the .min case.
120+
let int64Count = Int64(clamping: count)
111121
return shifted(
112-
rightBy: Self(clamping: count).negatedWithSaturation(),
122+
rightBy: int64Count.negatedWithSaturation(),
113123
rounding: rule
114124
)
115125
}
126+
let clamped = Self.max &- sextOrZext
116127
guard count < Self.bitWidth else {
117128
// If count is bitWidth or greater, we always overflow
118129
// unless self is zero.
119-
return self == 0 ? 0 : Self.max &- signExtension
130+
return self == 0 ? 0 : clamped
120131
}
121132
// Now we have 0 < count < bitWidth, so we can use a nice
122-
// straightforward implementation; the shift overflows if
123-
// the complementary shift doesn't match sign extension.
124-
let wrapped = self << count
125-
if self &>> ~count == signExtension { return wrapped }
126-
return Self.max &- signExtension
133+
// straightforward implementation; a shift overflows if
134+
// the complementary shift doesn't match sign-or-zero
135+
// extension. E.g.:
136+
//
137+
// - signed 0b0010_1111 << 2 overflows, because
138+
// 0b0010_1111 >> 5 is 0b0000_0001, which does not
139+
// equal 0b0000_0000
140+
//
141+
// - unsigned 0b0010_1111 << 2 does not overflow,
142+
// because 0b0010_0000 >> 6 is 0b0000_0000, which
143+
// does equal 0b0000_0000.
144+
let valueBits = Self.bitWidth &- (Self.isSigned ? 1 : 0)
145+
let wrapped = self &<< count
146+
let complement = valueBits &- Int(count)
147+
if self &>> complement == sextOrZext { return wrapped }
148+
return clamped
127149
}
128150
}

Tests/IntegerUtilitiesTests/ShiftTests.swift

Lines changed: 58 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ import _TestSupport
1616

1717
final class IntegerUtilitiesShiftTests: XCTestCase {
1818

19-
func testRoundingShift<T: FixedWidthInteger>(
20-
_ value: T, _ count: Int, rounding rule: RoundingRule
21-
) {
19+
func testRoundingShift<T, C>(
20+
_ value: T, _ count: C, rounding rule: RoundingRule
21+
) where T: FixedWidthInteger, C: BinaryInteger {
2222
let floor = value >> count
2323
let lost = value &- floor << count
2424
let exact = count <= 0 || lost == 0
@@ -95,6 +95,10 @@ final class IntegerUtilitiesShiftTests: XCTestCase {
9595
testRoundingShift(T.random(in: .min ... .max), count, rounding: rule)
9696
}
9797
}
98+
99+
for count in Int8.min ... .max {
100+
testRoundingShift(T.random(in: .min ... .max), count, rounding: rule)
101+
}
98102
}
99103

100104
func testRoundingShifts() {
@@ -190,4 +194,55 @@ final class IntegerUtilitiesShiftTests: XCTestCase {
190194
testStochasticAverage(DoubleWidth<Int64>.random(in: .min ... .max))
191195
testStochasticAverage(DoubleWidth<UInt64>.random(in: .min ... .max))
192196
}
197+
198+
func testSaturatingShift<T, C>(
199+
_ value: T, _ count: C, rounding rule: RoundingRule
200+
) where T: FixedWidthInteger, C: FixedWidthInteger {
201+
let observed = value.shiftedWithSaturation(leftBy: count, rounding: rule)
202+
var expected: T = 0
203+
if count <= 0 {
204+
expected = value.shifted(rightBy: -Int64(count), rounding: rule)
205+
} else {
206+
let multiplier: T = 1 << count
207+
if multiplier <= 0 {
208+
expected = value == 0 ? 0 :
209+
value < 0 ? .min : .max
210+
} else {
211+
expected = value.multipliedWithSaturation(by: multiplier)
212+
}
213+
}
214+
if observed != expected {
215+
print("Error found in \(T.self).shiftedWithSaturation(leftBy: \(count), rounding: \(rule)).")
216+
print(" Value: \(String(value, radix: 16))")
217+
print("Expected: \(String(expected, radix: 16))")
218+
print("Observed: \(String(observed, radix: 16))")
219+
XCTFail()
220+
return
221+
}
222+
}
223+
224+
func testSaturatingShift<T: FixedWidthInteger>(
225+
_ type: T.Type, rounding rule: RoundingRule
226+
) {
227+
for count in Int8.min ... .max {
228+
testSaturatingShift(0, count, rounding: rule)
229+
for bits in 0 ..< T.bitWidth {
230+
let msb: T.Magnitude = 1 << bits
231+
let value = T(truncatingIfNeeded: msb) | .random(in: 0 ... T(msb-1))
232+
testSaturatingShift(value, count, rounding: rule)
233+
testSaturatingShift(0 &- value, count, rounding: rule)
234+
}
235+
}
236+
}
237+
238+
func testSaturatingShifts() {
239+
testSaturatingShift(Int8.self, rounding: .toOdd)
240+
testSaturatingShift(UInt8.self, rounding: .toOdd)
241+
testSaturatingShift(Int.self, rounding: .toOdd)
242+
testSaturatingShift(UInt.self, rounding: .toOdd)
243+
}
244+
245+
func testEdgeCaseForNegativeCount() {
246+
XCTAssertEqual(1.shiftedWithSaturation(leftBy: Int.min), 0)
247+
}
193248
}

0 commit comments

Comments
 (0)