Skip to content

Commit ddc8ce3

Browse files
committed
Make all rounding rules (from IntegerUtilities) available on FloatingPoint
IntegerUtilities has RoundingRule that is a superset of FloatingPointRoundingRule; many of these are also useful when working with floating-point. This change provides a `rounding(_:)` function similar to the standard libraries `rounded(_:)` that makes all of them available. It has a different name instead of being a shadow because otherwise existing use sites like `rounded(.down)` would become ambiguous.
1 parent b5adc8c commit ddc8ce3

File tree

2 files changed

+306
-1
lines changed

2 files changed

+306
-1
lines changed

Sources/IntegerUtilities/RoundingRule.swift

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,5 +256,67 @@ extension RoundingRule {
256256
/// > Deprecated: Use `.toNearestOrAway` instead.
257257
@inlinable
258258
@available(*, deprecated, renamed: "toNearestOrAway")
259-
static var toNearestOrAwayFromZero: Self { .toNearestOrAway }
259+
public static var toNearestOrAwayFromZero: Self { .toNearestOrAway }
260+
}
261+
262+
extension FloatingPoint {
263+
/// `self` rounded to integer according to `rule`.
264+
///
265+
/// This mirrors the standard library `rounded` API, providing access to
266+
/// the expanded set of rounding rules defined in IntegerUtilities. It is
267+
/// not just a shadow because that would lead to ambiguity errors in
268+
/// existing code that uses the shortened `rounded(.down)` form.
269+
@inlinable @inline(__always)
270+
public func rounding(_ rule: RoundingRule) -> Self {
271+
switch rule {
272+
case .down:
273+
return rounded(.down)
274+
case .up:
275+
return rounded(.up)
276+
case .towardZero:
277+
return rounded(.towardZero)
278+
case .awayFromZero:
279+
return rounded(.awayFromZero)
280+
case .toNearestOrDown:
281+
// FP doesn't have toNearestOrDown, so round toNearestOrEven and fixup
282+
// any exact-halfway cases.
283+
let nearest = rounded(.toNearestOrEven)
284+
return nearest - self == 1/2 ? rounded(.down) : nearest
285+
case .toNearestOrUp:
286+
// FP doesn't have toNearestOrUp, so round toNearestOrEven and fixup
287+
// any exact-halfway cases.
288+
let nearest = rounded(.toNearestOrEven)
289+
return self - nearest == 1/2 ? rounded(.up) : nearest
290+
case .toNearestOrZero:
291+
// FP doesn't have toNearestOrZero, so round toNearestOrEven and fixup
292+
// any exact-halfway cases.
293+
let nearest = rounded(.toNearestOrEven)
294+
return (self - nearest).magnitude == 1/2 ? rounded(.towardZero) : nearest
295+
case .toNearestOrAway:
296+
return self.rounded(.toNearestOrAwayFromZero)
297+
case .toNearestOrEven:
298+
return self.rounded(.toNearestOrEven)
299+
case .toOdd:
300+
let trunc = rounded(.towardZero)
301+
if trunc == self { return trunc }
302+
let one = Self(signOf: self, magnitudeOf: 1)
303+
// We have eliminated all large values at this point; add ±0.5, and see
304+
// which way that rounds, then select the other value.
305+
let even = (trunc + one/2).rounded(.toNearestOrEven)
306+
return trunc == even ? trunc + one : trunc
307+
case .stochastically:
308+
let trunc = rounded(.towardZero)
309+
if trunc == self { return trunc }
310+
// We have eliminated all large values at this point; add dither in
311+
// ±[0,1) and then truncate.
312+
let bits = Swift.min(-Self.ulpOfOne.exponent, 32)
313+
let random = Self(UInt32.random(in: 0 ... (1 << bits &- 1)))
314+
let dither = Self(sign: sign, exponent: -bits, significand: random)
315+
return (self + dither).rounded(.towardZero)
316+
case .requireExact:
317+
let trunc = rounded(.towardZero)
318+
precondition(isInfinite || trunc == self, "\(self) is not an exact integer.")
319+
return self
320+
}
321+
}
260322
}
Lines changed: 243 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,243 @@
1+
//===--- RoundingTests.swift ----------------------------------*- swift -*-===//
2+
//
3+
// This source file is part of the Swift Numerics open source project
4+
//
5+
// Copyright (c) 2024 Apple Inc. and the Swift Numerics project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
import IntegerUtilities
14+
import _TestSupport
15+
import XCTest
16+
17+
final class IntegerUtilitiesRoundingTests: XCTestCase {
18+
func testRoundingDirected<T: BinaryFloatingPoint>(_ type: T.Type) {
19+
let inf = T.infinity
20+
let gfm = T.greatestFiniteMagnitude
21+
let big: T = 1 / .ulpOfOne
22+
let two: T = 2
23+
let threeHalves: T = 1.5
24+
let one: T = 1
25+
let half: T = 0.5
26+
let nrm = T.leastNormalMagnitude
27+
let lnm = T.leastNonzeroMagnitude
28+
let vectors: [(input: T, down: T, up: T, zero: T, away: T)] = [
29+
(-inf, -inf, -inf, -inf, -inf),
30+
(-gfm, -gfm, -gfm, -gfm, -gfm),
31+
(-big.nextUp, -big-1, -big-1, -big-1, -big-1),
32+
(-big, -big, -big, -big, -big),
33+
(-big.nextDown, -big, -big+1, -big+1, -big),
34+
(-two, -two, -two, -two, -two),
35+
(-two.nextDown, -two, -one, -one, -two),
36+
(-threeHalves.nextUp, -two, -one, -one, -two),
37+
(-threeHalves, -two, -one, -one, -two),
38+
(-threeHalves.nextDown, -two, -one, -one, -two),
39+
(-one.nextUp, -two, -one, -one, -two),
40+
(-one, -one, -one, -one, -one),
41+
(-one.nextDown, -one, 0, 0, -one),
42+
(-half.nextUp, -one, 0, 0, -one),
43+
(-half, -one, 0, 0, -one),
44+
(-half.nextDown, -one, 0, 0, -one),
45+
(-nrm, -one, 0, 0, -one),
46+
(-lnm, -one, 0, 0, -one),
47+
(-.zero, 0, 0, 0, 0),
48+
( .zero, 0, 0, 0, 0),
49+
( lnm, 0, one, 0, one),
50+
( nrm, 0, one, 0, one),
51+
( half.nextDown, 0, one, 0, one),
52+
( half, 0, one, 0, one),
53+
( half.nextUp, 0, one, 0, one),
54+
( one.nextDown, 0, one, 0, one),
55+
( one, one, one, one, one),
56+
( one.nextUp, one, two, one, two),
57+
( threeHalves.nextDown, one, two, one, two),
58+
( threeHalves, one, two, one, two),
59+
( threeHalves.nextUp, one, two, one, two),
60+
( two.nextDown, one, two, one, two),
61+
( two, two, two, two, two),
62+
( big.nextDown, big-1, big, big-1, big),
63+
( big, big, big, big, big),
64+
( big.nextUp, big+1, big+1, big+1, big+1),
65+
( gfm, gfm, gfm, gfm, gfm),
66+
( inf, inf, inf, inf, inf),
67+
]
68+
for vector in vectors {
69+
70+
XCTAssertEqual(vector.input.rounding(.down), vector.down)
71+
if vector.down == 0 {
72+
XCTAssertEqual(vector.input.rounding(.down).sign, vector.input.sign)
73+
}
74+
75+
XCTAssertEqual(vector.input.rounding(.up), vector.up)
76+
if vector.up == 0 {
77+
XCTAssertEqual(vector.input.rounding(.up).sign, vector.input.sign)
78+
}
79+
80+
XCTAssertEqual(vector.input.rounding(.towardZero), vector.zero)
81+
if vector.zero == 0 {
82+
XCTAssertEqual(vector.input.rounding(.towardZero).sign, vector.input.sign)
83+
}
84+
85+
XCTAssertEqual(vector.input.rounding(.awayFromZero), vector.away)
86+
if vector.away == 0 {
87+
XCTAssertEqual(vector.input.rounding(.awayFromZero).sign, vector.input.sign)
88+
}
89+
}
90+
}
91+
92+
func testRoundingDirected() {
93+
testRoundingDirected(Float.self)
94+
testRoundingDirected(Double.self)
95+
}
96+
97+
func testRoundingNearest<T: BinaryFloatingPoint>(_ type: T.Type) {
98+
let inf = T.infinity
99+
let gfm = T.greatestFiniteMagnitude
100+
let big: T = 1 / .ulpOfOne
101+
let two: T = 2
102+
let threeHalves: T = 1.5
103+
let one: T = 1
104+
let half: T = 0.5
105+
let nrm = T.leastNormalMagnitude
106+
let lnm = T.leastNonzeroMagnitude
107+
let vectors: [(input: T, down: T, up: T, zero: T, away: T, even: T)] = [
108+
(-inf, -inf, -inf, -inf, -inf, -inf),
109+
(-gfm, -gfm, -gfm, -gfm, -gfm, -gfm),
110+
(-big.nextUp, -big-1, -big-1, -big-1, -big-1, -big-1),
111+
(-big, -big, -big, -big, -big, -big),
112+
(-big.nextDown, -big, -big+1, -big+1, -big, -big),
113+
(-two, -two, -two, -two, -two, -two),
114+
(-two.nextDown, -two, -two, -two, -two, -two),
115+
(-threeHalves.nextUp, -two, -two, -two, -two, -two),
116+
(-threeHalves, -two, -one, -one, -two, -two),
117+
(-threeHalves.nextDown, -one, -one, -one, -one, -one),
118+
(-one.nextUp, -one, -one, -one, -one, -one),
119+
(-one, -one, -one, -one, -one, -one),
120+
(-one.nextDown, -one, -one, -one, -one, -one),
121+
(-half.nextUp, -one, -one, -one, -one, -one),
122+
(-half, -one, 0, 0, -one, 0),
123+
(-half.nextDown, 0, 0, 0, 0, 0),
124+
(-nrm, 0, 0, 0, 0, 0),
125+
(-lnm, 0, 0, 0, 0, 0),
126+
(-.zero, 0, 0, 0, 0, 0),
127+
( .zero, 0, 0, 0, 0, 0),
128+
( lnm, 0, 0, 0, 0, 0),
129+
( nrm, 0, 0, 0, 0, 0),
130+
( half.nextDown, 0, 0, 0, 0, 0),
131+
( half, 0, one, 0, one, 0),
132+
( half.nextUp, one, one, one, one, one),
133+
( one.nextDown, one, one, one, one, one),
134+
( one, one, one, one, one, one),
135+
( one.nextUp, one, one, one, one, one),
136+
( threeHalves.nextDown, one, one, one, one, one),
137+
( threeHalves, one, two, one, two, two),
138+
( threeHalves.nextUp, two, two, two, two, two),
139+
( two.nextDown, two, two, two, two, two),
140+
( two, two, two, two, two, two),
141+
( big.nextDown, big-1, big, big-1, big, big),
142+
( big, big, big, big, big, big),
143+
( big.nextUp, big+1, big+1, big+1, big+1, big+1),
144+
( gfm, gfm, gfm, gfm, gfm, gfm),
145+
( inf, inf, inf, inf, inf, inf),
146+
]
147+
for vector in vectors {
148+
149+
XCTAssertEqual(vector.input.rounding(.toNearestOrDown), vector.down, "\(vector.input).rounding(.toNearestOrDown)")
150+
if vector.down == 0 {
151+
XCTAssertEqual(vector.input.rounding(.toNearestOrDown).sign, vector.input.sign, "\(vector.input).rounding(.toNearestOrDown)")
152+
}
153+
154+
XCTAssertEqual(vector.input.rounding(.toNearestOrUp), vector.up, "\(vector.input).rounding(.toNearestOrUp)")
155+
if vector.up == 0 {
156+
XCTAssertEqual(vector.input.rounding(.toNearestOrUp).sign, vector.input.sign, "\(vector.input).rounding(.toNearestOrUp)")
157+
}
158+
159+
XCTAssertEqual(vector.input.rounding(.toNearestOrZero), vector.zero, "\(vector.input).rounding(.toNearestOrZero)")
160+
if vector.zero == 0 {
161+
XCTAssertEqual(vector.input.rounding(.toNearestOrZero).sign, vector.input.sign, "\(vector.input).rounding(.toNearestOrZero)")
162+
}
163+
164+
XCTAssertEqual(vector.input.rounding(.toNearestOrAway), vector.away, "\(vector.input).rounding(.toNearestOrAway)")
165+
if vector.away == 0 {
166+
XCTAssertEqual(vector.input.rounding(.toNearestOrAway).sign, vector.input.sign, "\(vector.input).rounding(.toNearestOrAway)")
167+
}
168+
169+
XCTAssertEqual(vector.input.rounding(.toNearestOrEven), vector.even, "\(vector.input).rounding(.toNearestOrEven)")
170+
if vector.even == 0 {
171+
XCTAssertEqual(vector.input.rounding(.toNearestOrEven).sign, vector.input.sign, "\(vector.input).rounding(.toNearestOrEven)")
172+
}
173+
}
174+
}
175+
176+
func testRoundingNearest() {
177+
testRoundingNearest(Float.self)
178+
testRoundingNearest(Double.self)
179+
}
180+
181+
func testRoundingOdd<T: BinaryFloatingPoint>(_ type: T.Type) {
182+
let inf = T.infinity
183+
let gfm = T.greatestFiniteMagnitude
184+
let big: T = 1 / .ulpOfOne
185+
let two: T = 2
186+
let threeHalves: T = 1.5
187+
let one: T = 1
188+
let half: T = 0.5
189+
let nrm = T.leastNormalMagnitude
190+
let lnm = T.leastNonzeroMagnitude
191+
let vectors: [(input: T, odd: T)] = [
192+
(-inf, -inf),
193+
(-gfm, -gfm),
194+
(-big.nextUp, -big-1),
195+
(-big, -big),
196+
(-big.nextDown, -big+1),
197+
(-two, -two),
198+
(-two.nextDown, -one),
199+
(-threeHalves.nextUp, -one),
200+
(-threeHalves, -one),
201+
(-threeHalves.nextDown, -one),
202+
(-one.nextUp, -one),
203+
(-one, -one),
204+
(-one.nextDown, -one),
205+
(-half.nextUp, -one),
206+
(-half, -one),
207+
(-half.nextDown, -one),
208+
(-nrm, -one),
209+
(-lnm, -one),
210+
(-.zero, 0),
211+
( .zero, 0),
212+
( lnm, one),
213+
( nrm, one),
214+
( half.nextDown, one),
215+
( half, one),
216+
( half.nextUp, one),
217+
( one.nextDown, one),
218+
( one, one),
219+
( one.nextUp, one),
220+
( threeHalves.nextDown, one),
221+
( threeHalves, one),
222+
( threeHalves.nextUp, one),
223+
( two.nextDown, one),
224+
( two, two),
225+
( big.nextDown, big-1),
226+
( big, big),
227+
( big.nextUp, big+1),
228+
( gfm, gfm),
229+
( inf, inf)
230+
]
231+
for vector in vectors {
232+
XCTAssertEqual(vector.input.rounding(.toOdd), vector.odd, "\(vector.input).rounding(.toOdd)")
233+
if vector.odd == 0 {
234+
XCTAssertEqual(vector.input.rounding(.toOdd).sign, vector.input.sign, "\(vector.input).rounding(.toOdd)")
235+
}
236+
}
237+
}
238+
239+
func testRoundingOdd() {
240+
testRoundingOdd(Float.self)
241+
testRoundingOdd(Double.self)
242+
}
243+
}

0 commit comments

Comments
 (0)