Skip to content

Commit 877fd0a

Browse files
Merge pull request #289 from stephentyrone/rework-complex-division
Replaces the rescaling algorithm for Complex division with one inspired by Doug Priest's "Efficient Scaling for Complex Division," with some further tweaks to:
- allow it to work for arbitrary FloatingPoint types, including Float16
- get exactly the same rounding behavior as the un-rescaled path, so that z/w = tz/tw when tz and tw are computed exactly.
- allow future optimizations to hoist a rescaled reciprocal for more speedups.

Unlike Priest, we do not try to avoid spurious overflow in the final computation when the result is very near the overflow boundary but cancellation brings us just inside it. We do not believe that this is a good tradeoff, as complex multiplication overflows in exactly the same way. We will investigate providing opt-in API to avoid this overflow case in a future PR.
2 parents ab63ebd + 262cbea commit 877fd0a

File tree

2 files changed

+144
-54
lines changed

2 files changed

+144
-54
lines changed

Sources/ComplexModule/Complex+AlgebraicField.swift

Lines changed: 78 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22
//
33
// This source file is part of the Swift Numerics open source project
44
//
5-
// Copyright (c) 2019-2021 Apple Inc. and the Swift Numerics project authors
5+
// Copyright (c) 2019-2024 Apple Inc. and the Swift Numerics project authors
66
// Licensed under Apache License v2.0 with Runtime Library Exception
77
//
88
// See https://swift.org/LICENSE.txt for license information
@@ -27,52 +27,93 @@ extension Complex: AlgebraicField {
2727
}
2828

2929
@_transparent
30-
public static func /(z: Complex, w: Complex) -> Complex {
31-
// Try the naive expression z/w = z*conj(w) / |w|^2; if we can compute
32-
// this without over/underflow, everything is fine and the result is
33-
// correct. If not, we have to rescale and do the computation carefully.
34-
let lenSq = w.lengthSquared
35-
guard lenSq.isNormal else { return rescaledDivide(z, w) }
36-
return z * (w.conjugate.divided(by: lenSq))
30+
public static func /=(z: inout Complex, w: Complex) {
31+
z = z / w
3732
}
3833

3934
@_transparent
40-
public static func /=(z: inout Complex, w: Complex) {
41-
z = z / w
35+
public static func /(z: Complex, w: Complex) -> Complex {
36+
// Try the naive expression z/w = z * (conj(w) / |w|^2); if we can
37+
// compute this without over/underflow, everything is fine and the
38+
// result is correct. If not, we have to rescale and do the
39+
// computation carefully (see below).
40+
let lenSq = w.lengthSquared
41+
guard lenSq.isNormal else { return rescaledDivide(z, w) }
42+
return z * w.conjugate.divided(by: lenSq)
4243
}
4344

4445
@usableFromInline @_alwaysEmitIntoClient @inline(never)
4546
internal static func rescaledDivide(_ z: Complex, _ w: Complex) -> Complex {
4647
if w.isZero { return .infinity }
47-
if z.isZero || !w.isFinite { return .zero }
48-
// TODO: detect when RealType is Float and just promote to Double, then
49-
// use the naive algorithm.
50-
let zScale = z.magnitude
51-
let wScale = w.magnitude
52-
let zNorm = z.divided(by: zScale)
53-
let wNorm = w.divided(by: wScale)
54-
let r = (zNorm * wNorm.conjugate).divided(by: wNorm.lengthSquared)
55-
// At this point, the result is (r * zScale)/wScale computed without
56-
// undue overflow or underflow. We know that r is close to unity, so
57-
// the question is simply what order in which to do this computation
58-
// to avoid spurious overflow or underflow. There are three options
59-
// to choose from:
48+
if !w.isFinite { return .zero }
49+
// Scaling algorithm adapted from Doug Priest's "Efficient Scaling for
50+
// Complex Division":
51+
if w.magnitude < .leastNormalMagnitude {
52+
// A difference from Priest's algorithm is that he didn't have to worry
53+
// about types like Float16, where the significand width is comparable
54+
// to the exponent range, such that |leastNonzeroMagnitude|^(-¾) isn't
55+
// representable (e.g. for Float16 it would want to be 2¹⁸, but the
56+
// largest allowed exponent is 15). Note that it's critical to use zʹ/wʹ
57+
// after rescaling to avoid this, rather than falling through into the
58+
// normal rescaling, because otherwise we might end up back in the
59+
// situation where |w| ~ 1.
60+
let s = 1/(RealType(RealType.radix) * .leastNormalMagnitude)
61+
let wʹ = w.multiplied(by: s)
62+
let zʹ = z.multiplied(by: s)
63+
return zʹ / wʹ
64+
}
65+
// Having handled that case, we proceed pretty similarly to Priest:
6066
//
61-
// - r * (zScale / wScale)
62-
// - (r * zScale) / wScale
63-
// - (r / wScale) * zScale
67+
// 1. Choose real scale s ~ |w|^(-¾), an exact power of the radix.
68+
// 2. wʹ ← sw
69+
// 3. zʹ ← sz
70+
// 4. return zʹ * (wʹ.conjugate / wʹ.lengthSquared) (i.e. zʹ/wʹ).
6471
//
65-
// The simplest case is when zScale / wScale is normal:
66-
if (zScale / wScale).isNormal {
67-
return r.multiplied(by: zScale / wScale)
68-
}
69-
// Otherwise, we need to compute either rNorm * zScale or rNorm / wScale
70-
// first. Choose the first if the first scaling behaves well, otherwise
71-
// choose the other one.
72-
if (r.magnitude * zScale).isNormal {
73-
return r.multiplied(by: zScale).divided(by: wScale)
74-
}
75-
return r.divided(by: wScale).multiplied(by: zScale)
72+
// Why is this safe and accurate? First, observe that wʹ and zʹ are both
73+
// computed exactly because:
74+
//
75+
// - s is an exact power of radix.
76+
// - wʹ ~ |w|^(¼), and hence cannot overflow or underflow.
77+
// - zʹ might overflow or underflow, but only if the final result also
78+
// overflows or underflows. (This is more subtle than I make it
79+
// sound. In particular, most of the fast ways one might try to
80+
// compute s give rise to a situation where when |w| is close to
81+
// one, multiplication by s is a dilation even though the actual
82+
// division is a contraction or vice-versa, and thus intermediate
83+
// computations might incorrectly overflow or underflow. Priest
84+
// had to take some care to avoid this situation, but we do not,
85+
// because we have already ruled out |w| ~ 1 before we call this
86+
// function.)
87+
//
88+
// Next observe that |wʹ.lengthSquared| ~ |w|^(½), so again this cannot
89+
// overflow or underflow, and neither can (wʹ.conjugate/wʹ.lengthSquared),
90+
// which has magnitude like |w|^(-¼).
91+
//
92+
// Note that because the scale factor is always a power of the radix,
93+
// the rescaling does not affect rounding, and so this algorithm is scale-
94+
// invariant compared to the mainline `/` implementation, up to the
95+
// underflow boundary.
96+
//
97+
// Note that our final assembly of the result is different from Priest;
98+
// he applies s to w twice, instead of once to w and once to z, and
99+
// does the product as (zw̅ʺ)*(1/|wʹ|²), while we do zʹ(w̅ʹ/|wʹ|²). We
100+
// prefer our version for three reasons:
101+
//
102+
// 1. it extracts a little more ILP
103+
// 2. it makes it so that we get exactly the same roundings on the
104+
// rescaled divide path as on the fast path, so that z/w = tz/tw
105+
// when tz and tw are computed exactly.
106+
// 3. it unlocks a future optimization where we hoist s and
107+
// (w̅ʹ/|wʹ|²) and make divisions all fast-path without perturbing
108+
// rounding.
109+
let s = RealType(
110+
sign: .plus,
111+
exponent: -3*w.magnitude.exponent/4,
112+
significand: 1
113+
)
114+
let wʹ = w.multiplied(by: s)
115+
let zʹ = z.multiplied(by: s)
116+
return zʹ * wʹ.conjugate.divided(by: wʹ.lengthSquared)
76117
}
77118

78119
/// A normalized complex number with the same phase as this value.

Tests/ComplexTests/ArithmeticTests.swift

Lines changed: 66 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22
//
33
// This source file is part of the Swift Numerics open source project
44
//
5-
// Copyright (c) 2019 Apple Inc. and the Swift Numerics project authors
5+
// Copyright (c) 2019-2024 Apple Inc. and the Swift Numerics project authors
66
// Licensed under Apache License v2.0 with Runtime Library Exception
77
//
88
// See https://swift.org/LICENSE.txt for license information
@@ -13,11 +13,23 @@ import XCTest
1313
import ComplexModule
1414
import RealModule
1515

16+
func ulpsFromInfinity<T: Real>(_ a: T) -> T {
17+
(.greatestFiniteMagnitude - a) / .greatestFiniteMagnitude.ulp + 1
18+
}
19+
1620
// TODO: improve this to be a general-purpose complex comparison with tolerance
1721
func relativeError<T>(_ a: Complex<T>, _ b: Complex<T>) -> T {
1822
if a == b { return 0 }
19-
let scale = max(a.magnitude, b.magnitude, T.leastNormalMagnitude).ulp
20-
return (a - b).magnitude / scale
23+
if a.isFinite && b.isFinite {
24+
let scale = max(a.magnitude, b.magnitude, T.leastNormalMagnitude).ulp
25+
return (a - b).magnitude / scale
26+
} else {
27+
if a.isFinite {
28+
return ulpsFromInfinity(a.magnitude)
29+
} else {
30+
return ulpsFromInfinity(b.magnitude)
31+
}
32+
}
2133
}
2234

2335
func closeEnough<T: Real>(_ a: T, _ b: T, ulps allowed: T) -> Bool {
@@ -29,11 +41,15 @@ func checkMultiply<T>(
2941
_ a: Complex<T>, _ b: Complex<T>, expected: Complex<T>, ulps allowed: T
3042
) -> Bool {
3143
let observed = a*b
44+
if observed == expected { return false }
45+
// Even if the expected result is finite, we allow overflow if
46+
// the two-norm of the expected result overflows.
47+
if !observed.isFinite && !expected.length.isFinite { return false }
3248
let rel = relativeError(observed, expected)
33-
if rel > allowed {
49+
guard rel <= allowed else {
3450
print("Over-large error in \(a)*\(b)")
3551
print("Expected: \(expected)\nObserved: \(observed)")
36-
print("Relative error was \(rel) (tolerance: \(allowed).")
52+
print("Relative error was \(rel) (tolerance: \(allowed)).")
3753
return true
3854
}
3955
return false
@@ -43,11 +59,15 @@ func checkDivide<T>(
4359
_ a: Complex<T>, _ b: Complex<T>, expected: Complex<T>, ulps allowed: T
4460
) -> Bool {
4561
let observed = a/b
62+
if observed == expected { return false }
63+
// Even if the expected result is finite, we allow overflow if
64+
// the two-norm of the expected result overflows.
65+
if !observed.isFinite && !expected.length.isFinite { return false }
4666
let rel = relativeError(observed, expected)
47-
if rel > allowed {
67+
guard rel <= allowed else {
4868
print("Over-large error in \(a)/\(b)")
4969
print("Expected: \(expected)\nObserved: \(observed)")
50-
print("Relative error was \(rel) (tolerance: \(allowed).")
70+
print("Relative error was \(rel) (tolerance: \(allowed)).")
5171
return true
5272
}
5373
return false
@@ -63,7 +83,6 @@ final class ArithmeticTests: XCTestCase {
6383
func testPolar<T>(_ type: T.Type)
6484
where T: BinaryFloatingPoint, T: Real,
6585
T.Exponent: FixedWidthInteger, T.RawSignificand: FixedWidthInteger {
66-
6786
// In order to support round-tripping from rectangular to polar coordinate
6887
// systems, as a special case phase can be non-finite when length is
6988
// either zero or infinity.
@@ -76,10 +95,9 @@ final class ArithmeticTests: XCTestCase {
7695
XCTAssertEqual(Complex<T>(length:-.infinity, phase: .infinity), .infinity)
7796
XCTAssertEqual(Complex<T>(length:-.infinity, phase:-.infinity), .infinity)
7897
XCTAssertEqual(Complex<T>(length:-.infinity, phase: .nan ), .infinity)
79-
98+
8099
let exponentRange =
81-
(T.leastNormalMagnitude.exponent + T.Exponent(T.significandBitCount)) ...
82-
T.greatestFiniteMagnitude.exponent
100+
T.leastNormalMagnitude.exponent ... T.greatestFiniteMagnitude.exponent
83101
let inputs = (0..<100).map { _ in
84102
Polar(length: T(
85103
sign: .plus,
@@ -136,20 +154,29 @@ final class ArithmeticTests: XCTestCase {
136154
// Now test multiplication and division using the polar inputs:
137155
for q in inputs {
138156
let w = Complex(length: q.length, phase: q.phase)
139-
let product = Complex(length: p.length * q.length, phase: p.phase + q.phase)
157+
var product = Complex(length: p.length, phase: p.phase + q.phase)
158+
product.real *= q.length
159+
product.imaginary *= q.length
140160
if checkMultiply(z, w, expected: product, ulps: 16) { XCTFail() }
141-
let quotient = Complex(length: p.length / q.length, phase: p.phase - q.phase)
161+
var quotient = Complex(length: p.length, phase: p.phase - q.phase)
162+
quotient.real /= q.length
163+
quotient.imaginary /= q.length
142164
if checkDivide(z, w, expected: quotient, ulps: 16) { XCTFail() }
143165
}
144166
}
145167
}
146168

147169
func testPolar() {
170+
#if !((os(macOS) || targetEnvironment(macCatalyst)) && arch(x86_64)) && LONG_TESTS
171+
if #available(macOS 11.0, iOS 14.0, tvOS 14.0, watchOS 7.0, *) {
172+
testPolar(Float16.self)
173+
}
174+
#endif
148175
testPolar(Float.self)
149176
testPolar(Double.self)
150-
#if (arch(i386) || arch(x86_64)) && !os(Windows) && !os(Android)
177+
#if (arch(i386) || arch(x86_64)) && !os(Windows) && !os(Android)
151178
testPolar(Float80.self)
152-
#endif
179+
#endif
153180
}
154181

155182
func testBaudinSmith() {
@@ -191,16 +218,38 @@ final class ArithmeticTests: XCTestCase {
191218
Complex(1.02951151789360578e-84, 6.97145987515076231e-220)),
192219
]
193220
for test in vectors {
194-
if checkDivide(test.a, test.b, expected: test.c, ulps: 0.5) { XCTFail() }
221+
if checkDivide(test.a, test.b, expected: test.c, ulps: 1.0) { XCTFail() }
195222
if checkDivide(test.a, test.c, expected: test.b, ulps: 1.0) { XCTFail() }
196223
if checkMultiply(test.b, test.c, expected: test.a, ulps: 1.0) { XCTFail() }
197224
}
198225
}
199-
226+
200227
func testDivisionByZero() {
201228
XCTAssertFalse((Complex(0, 0) / Complex(0, 0)).isFinite)
202229
XCTAssertFalse((Complex(1, 1) / Complex(0, 0)).isFinite)
203230
XCTAssertFalse((Complex.infinity / Complex(0, 0)).isFinite)
204231
XCTAssertFalse((Complex.i / Complex(0, 0)).isFinite)
232+
233+
}
234+
235+
#if !((os(macOS) || targetEnvironment(macCatalyst)) && arch(x86_64)) && LONG_TESTS
236+
@available(macOS 11.0, iOS 14.0, tvOS 14.0, watchOS 7.0, *)
237+
func testFloat16DivisionSemiExhaustive() {
238+
func complex(bitPattern: UInt32) -> Complex<Float16> {
239+
Complex(
240+
Float16(bitPattern: UInt16(truncatingIfNeeded: bitPattern)),
241+
Float16(bitPattern: UInt16(truncatingIfNeeded: bitPattern >> 16))
242+
)
243+
}
244+
for bits in 0 ... UInt32.max {
245+
let a = complex(bitPattern: bits)
246+
if bits & 0xfffff == 0 { print(a) }
247+
let b = complex(bitPattern: UInt32.random(in: 0 ... .max))
248+
var q = Complex<Float>(a)/Complex<Float>(b)
249+
if checkDivide(a, b, expected: Complex<Float16>(q), ulps: 4) { XCTFail() }
250+
q = Complex<Float>(b)/Complex<Float>(a)
251+
if checkDivide(b, a, expected: Complex<Float16>(q), ulps: 4) { XCTFail() }
252+
}
205253
}
254+
#endif
206255
}

0 commit comments

Comments
 (0)