Skip to content

Commit 97bab25

Browse files
committed
WIP
1 parent ab63ebd commit 97bab25

File tree

2 files changed

+203
-121
lines changed

2 files changed

+203
-121
lines changed

Sources/ComplexModule/Complex+AlgebraicField.swift

Lines changed: 78 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -26,53 +26,96 @@ extension Complex: AlgebraicField {
2626
Complex(x, -y)
2727
}
2828

29+
@_transparent
30+
public static func /=(z: inout Complex, w: Complex) {
31+
z = z / w
32+
}
33+
2934
@_transparent
3035
public static func /(z: Complex, w: Complex) -> Complex {
31-
// Try the naive expression z/w = z*conj(w) / |w|^2; if we can compute
32-
// this without over/underflow, everything is fine and the result is
33-
// correct. If not, we have to rescale and do the computation carefully.
36+
// Try the naive expression z/w = z * (conj(w) / |w|^2); if we can
37+
// compute this without over/underflow, everything is fine and the
38+
// result is correct. If not, we have to rescale and do the
39+
// computation carefully (see below).
3440
let lenSq = w.lengthSquared
3541
guard lenSq.isNormal else { return rescaledDivide(z, w) }
3642
return z * (w.conjugate.divided(by: lenSq))
3743
}
3844

39-
@_transparent
40-
public static func /=(z: inout Complex, w: Complex) {
41-
z = z / w
42-
}
43-
44-
@usableFromInline @_alwaysEmitIntoClient @inline(never)
45+
@inline(never)
46+
@_specialize(exported: true, where RealType == Float)
47+
@_specialize(exported: true, where RealType == Double)
48+
@usableFromInline
4549
internal static func rescaledDivide(_ z: Complex, _ w: Complex) -> Complex {
4650
if w.isZero { return .infinity }
47-
if z.isZero || !w.isFinite { return .zero }
48-
// TODO: detect when RealType is Float and just promote to Double, then
49-
// use the naive algorithm.
50-
let zScale = z.magnitude
51-
let wScale = w.magnitude
52-
let zNorm = z.divided(by: zScale)
53-
let wNorm = w.divided(by: wScale)
54-
let r = (zNorm * wNorm.conjugate).divided(by: wNorm.lengthSquared)
55-
// At this point, the result is (r * zScale)/wScale computed without
56-
// undue overflow or underflow. We know that r is close to unity, so
57-
// the question is simply what order in which to do this computation
58-
// to avoid spurious overflow or underflow. There are three options
59-
// to choose from:
51+
if !w.isFinite { return .zero }
52+
// Scaling algorithm adapted from Doug Priest's "Efficient Scaling for
53+
// Complex Division":
6054
//
61-
// - r * (zScale / wScale)
62-
// - (r * zScale) / wScale
63-
// - (r / wScale) * zScale
55+
// 1. Choose real scale s ≅ |w|^(-¾), an exact power of the radix.
56+
// 2. wʹ ← sw
57+
// 3. zʹ ← sz
58+
// 4. return zʹ * (wʹ.conjugate / wʹ.lengthSquared)
6459
//
65-
// The simplest case is when zScale / wScale is normal:
66-
if (zScale / wScale).isNormal {
67-
return r.multiplied(by: zScale / wScale)
68-
}
69-
// Otherwise, we need to compute either rNorm * zScale or rNorm / wScale
70-
// first. Choose the first if the first scaling behaves well, otherwise
71-
// choose the other one.
72-
if (r.magnitude * zScale).isNormal {
73-
return r.multiplied(by: zScale).divided(by: wScale)
60+
// Why is this safe and accurate? First, observe that wʹ and zʹ are both
61+
// computed exactly because:
62+
//
63+
// - s is an exact power of the radix.¹
64+
// - wʹ ~ |w|^(¼), and hence cannot overflow or underflow.
65+
// - zʹ can overflow or underflow, but only if the final result also
66+
// overflows or underflows (this is more subtle than it might
67+
// appear at first; Priest has to be very careful about it
68+
// because you get into trouble precisely in the case where
69+
// |w| is very close to 1. However, if we were in that case, we would
70+
// have just handled the division inline and never would have ended
71+
// up here).
72+
//
73+
// Next observe that |wʹ.lengthSquared| ~ |w|^(½), so again this cannot
74+
// overflow or underflow, and neither can
75+
// (wʹ.conjugate / wʹ.lengthSquared) =: wʺ, since |wʺ| ~ |w|^(-¼).
76+
//
77+
// Finally, the two factors that scale z (s and wʺ)
78+
// are of comparable
79+
// magnitude, and in particular the exponents of their magnitudes have the
80+
// same sign, so either both are a contraction or both are an expansion,
81+
// so any intermediate overflow or underflow is deserved.²
82+
//
83+
// Note that because the scale factor is always a power of the radix,
84+
// the rescaling does not affect rounding, and so this algorithm is scale-
85+
// invariant compared to the mainline `/` implementation, up to the
86+
// underflow boundary.
87+
//
88+
// ¹ This falls apart for formats where the number of significand bits is
89+
// comparable to the exponent range (in particular Float16), because then
90+
// the desired s is not representable. E.g. if w ~ .leastNonzeroMagnitude
91+
// in Float16 (0x1p-24), we want to have s = 0x1p18, which is outside the
92+
// range of representable values. This does not occur for any other types,
93+
// so we just carry a special-case implementation for Float16 to fix it.
94+
//
95+
// Priest never had to worry about this because Float16 didn't really exist
96+
// yet when he published and he was interested in double anyway.
97+
//
98+
// ² This WOULD NOT BE TRUE if we hadn't already handled well-scaled
99+
// divisors in the mainline path for the `/` operator above; it only
100+
// holds for sufficiently badly-scaled `w`. If the well-scaled cases
101+
// were not already eliminated, it would be possible to have |wʹ| a
102+
// little bigger than one and |wʺ| a bit smaller than one (or vice-versa), so
103+
// that intermediate undeserved overflow or underflow might occur. Priest
104+
// has to worry about this, but we do not.
105+
if w.magnitude < RealType.leastNormalMagnitude {
106+
let z = z.divided(by: RealType.leastNormalMagnitude)
107+
let w = w.divided(by: RealType.leastNormalMagnitude)
108+
return rescaledDivide(z, w)
74109
}
75-
return r.divided(by: wScale).multiplied(by: zScale)
110+
var exponent = -3 * w.magnitude.exponent / 4
111+
let s = RealType(
112+
sign: .plus,
113+
exponent: exponent,
114+
significand: 1
115+
)
116+
let = w.multiplied(by: s)
117+
let = z.multiplied(by: s)
118+
return/
76119
}
77120

78121
/// A normalized complex number with the same phase as this value.

Tests/ComplexTests/ArithmeticTests.swift

Lines changed: 125 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,10 @@ func checkMultiply<T>(
3030
) -> Bool {
3131
let observed = a*b
3232
let rel = relativeError(observed, expected)
33-
if rel > allowed {
33+
guard rel <= allowed else {
3434
print("Over-large error in \(a)*\(b)")
3535
print("Expected: \(expected)\nObserved: \(observed)")
36-
print("Relative error was \(rel) (tolerance: \(allowed).")
36+
print("Relative error was \(rel) (tolerance: \(allowed)).")
3737
return true
3838
}
3939
return false
@@ -44,10 +44,10 @@ func checkDivide<T>(
4444
) -> Bool {
4545
let observed = a/b
4646
let rel = relativeError(observed, expected)
47-
if rel > allowed {
47+
guard rel <= allowed else {
4848
print("Over-large error in \(a)/\(b)")
4949
print("Expected: \(expected)\nObserved: \(observed)")
50-
print("Relative error was \(rel) (tolerance: \(allowed).")
50+
print("Relative error was \(rel) (tolerance: \(allowed)).")
5151
return true
5252
}
5353
return false
@@ -63,93 +63,99 @@ final class ArithmeticTests: XCTestCase {
6363
func testPolar<T>(_ type: T.Type)
6464
where T: BinaryFloatingPoint, T: Real,
6565
T.Exponent: FixedWidthInteger, T.RawSignificand: FixedWidthInteger {
66-
67-
// In order to support round-tripping from rectangular to polar coordinate
68-
// systems, as a special case phase can be non-finite when length is
69-
// either zero or infinity.
70-
XCTAssertEqual(Complex<T>(length: .zero, phase: .infinity), .zero)
71-
XCTAssertEqual(Complex<T>(length: .zero, phase:-.infinity), .zero)
72-
XCTAssertEqual(Complex<T>(length: .zero, phase: .nan ), .zero)
73-
XCTAssertEqual(Complex<T>(length: .infinity, phase: .infinity), .infinity)
74-
XCTAssertEqual(Complex<T>(length: .infinity, phase:-.infinity), .infinity)
75-
XCTAssertEqual(Complex<T>(length: .infinity, phase: .nan ), .infinity)
76-
XCTAssertEqual(Complex<T>(length:-.infinity, phase: .infinity), .infinity)
77-
XCTAssertEqual(Complex<T>(length:-.infinity, phase:-.infinity), .infinity)
78-
XCTAssertEqual(Complex<T>(length:-.infinity, phase: .nan ), .infinity)
7966

80-
let exponentRange =
81-
(T.leastNormalMagnitude.exponent + T.Exponent(T.significandBitCount)) ...
82-
T.greatestFiniteMagnitude.exponent
83-
let inputs = (0..<100).map { _ in
84-
Polar(length: T(
85-
sign: .plus,
86-
exponent: T.Exponent.random(in: exponentRange),
87-
significand: T.random(in: 1 ..< 2)
88-
), phase: T.random(in: -.pi ... .pi))
89-
}
90-
for p in inputs {
91-
// first test that each value can round-trip between rectangular and
92-
// polar coordinates with reasonable accuracy. We'll probably need to
93-
// relax this for some platforms (currently we're using the default
94-
// RNG, which means we don't get the same sequence of values each time;
95-
// this is good--more test coverage!--and bad, because without tight
96-
// bounds on every platform's libm, we can't get tight bounds on the
97-
// accuracy of these operations, so we need to relax them gradually).
98-
let z = Complex(length: p.length, phase: p.phase)
99-
if !closeEnough(z.length, p.length, ulps: 16) {
100-
print("p = \(p)\nz = \(z)\nz.length = \(z.length)")
101-
XCTFail()
102-
}
103-
if !closeEnough(z.phase, p.phase, ulps: 16) {
104-
print("p = \(p)\nz = \(z)\nz.phase = \(z.phase)")
105-
XCTFail()
106-
}
107-
// Complex(length: -r, phase: θ) = -Complex(length: r, phase: θ).
108-
let w = Complex(length: -p.length, phase: p.phase)
109-
if w != -z {
110-
print("p = \(p)\nw = \(w)\nz = \(z)")
111-
XCTFail()
112-
}
113-
XCTAssertEqual(w, -z)
114-
// if length*length is normal, it should be lengthSquared, up
115-
// to small error.
116-
if (p.length*p.length).isNormal {
117-
if !closeEnough(z.lengthSquared, p.length*p.length, ulps: 16) {
118-
print("p = \(p)\nz = \(z)\nz.lengthSquared = \(z.lengthSquared)")
119-
XCTFail()
120-
}
121-
}
122-
// Test reciprocal and normalized:
123-
let r = Complex(length: 1/p.length, phase: -p.phase)
124-
if r.isNormal {
125-
if relativeError(r, z.reciprocal!) > 16 {
126-
print("p = \(p)\nz = \(z)\nz.reciprocal = \(r)")
127-
XCTFail()
67+
// In order to support round-tripping from rectangular to polar coordinate
68+
// systems, as a special case phase can be non-finite when length is
69+
// either zero or infinity.
70+
XCTAssertEqual(Complex<T>(length: .zero, phase: .infinity), .zero)
71+
XCTAssertEqual(Complex<T>(length: .zero, phase:-.infinity), .zero)
72+
XCTAssertEqual(Complex<T>(length: .zero, phase: .nan ), .zero)
73+
XCTAssertEqual(Complex<T>(length: .infinity, phase: .infinity), .infinity)
74+
XCTAssertEqual(Complex<T>(length: .infinity, phase:-.infinity), .infinity)
75+
XCTAssertEqual(Complex<T>(length: .infinity, phase: .nan ), .infinity)
76+
XCTAssertEqual(Complex<T>(length:-.infinity, phase: .infinity), .infinity)
77+
XCTAssertEqual(Complex<T>(length:-.infinity, phase:-.infinity), .infinity)
78+
XCTAssertEqual(Complex<T>(length:-.infinity, phase: .nan ), .infinity)
79+
80+
let exponentRange =
81+
T.leastNormalMagnitude.exponent ... T.greatestFiniteMagnitude.exponent
82+
let inputs = (0..<100).map { _ in
83+
Polar(length: T(
84+
sign: .plus,
85+
exponent: T.Exponent.random(in: exponentRange),
86+
significand: T.random(in: 1 ..< 2)
87+
), phase: T.random(in: -.pi ... .pi))
88+
}
89+
for p in inputs {
90+
// first test that each value can round-trip between rectangular and
91+
// polar coordinates with reasonable accuracy. We'll probably need to
92+
// relax this for some platforms (currently we're using the default
93+
// RNG, which means we don't get the same sequence of values each time;
94+
// this is good--more test coverage!--and bad, because without tight
95+
// bounds on every platform's libm, we can't get tight bounds on the
96+
// accuracy of these operations, so we need to relax them gradually).
97+
let z = Complex(length: p.length, phase: p.phase)
98+
if !closeEnough(z.length, p.length, ulps: 16) {
99+
print("p = \(p)\nz = \(z)\nz.length = \(z.length)")
100+
XCTFail()
101+
}
102+
if !closeEnough(z.phase, p.phase, ulps: 16) {
103+
print("p = \(p)\nz = \(z)\nz.phase = \(z.phase)")
104+
XCTFail()
105+
}
106+
// Complex(length: -r, phase: θ) = -Complex(length: r, phase: θ).
107+
let w = Complex(length: -p.length, phase: p.phase)
108+
if w != -z {
109+
print("p = \(p)\nw = \(w)\nz = \(z)")
110+
XCTFail()
111+
}
112+
XCTAssertEqual(w, -z)
113+
// if length*length is normal, it should be lengthSquared, up
114+
// to small error.
115+
if (p.length*p.length).isNormal {
116+
if !closeEnough(z.lengthSquared, p.length*p.length, ulps: 16) {
117+
print("p = \(p)\nz = \(z)\nz.lengthSquared = \(z.lengthSquared)")
118+
XCTFail()
119+
}
120+
}
121+
// Test reciprocal and normalized:
122+
let r = Complex(length: 1/p.length, phase: -p.phase)
123+
if r.isNormal {
124+
if relativeError(r, z.reciprocal!) > 16 {
125+
print("p = \(p)\nz = \(z)\nz.reciprocal = \(r)")
126+
XCTFail()
127+
}
128+
} else { XCTAssertNil(z.reciprocal) }
129+
let n = Complex(length: 1, phase: p.phase)
130+
if relativeError(n, z.normalized!) > 16 {
131+
print("p = \(p)\nz = \(z)\nz.normalized = \(n)")
132+
XCTFail()
133+
}
134+
135+
// Now test multiplication and division using the polar inputs:
136+
for q in inputs {
137+
let w = Complex(length: q.length, phase: q.phase)
138+
var product = Complex(length: p.length, phase: p.phase + q.phase)
139+
product.real *= q.length
140+
product.imaginary *= q.length
141+
if checkMultiply(z, w, expected: product, ulps: 16) { XCTFail() }
142+
var quotient = Complex(length: p.length, phase: p.phase - q.phase)
143+
quotient.real /= q.length
144+
quotient.imaginary /= q.length
145+
if checkDivide(z, w, expected: quotient, ulps: 16) { XCTFail() }
146+
}
147+
}
128148
}
129-
} else { XCTAssertNil(z.reciprocal) }
130-
let n = Complex(length: 1, phase: p.phase)
131-
if relativeError(n, z.normalized!) > 16 {
132-
print("p = \(p)\nz = \(z)\nz.normalized = \(n)")
133-
XCTFail()
134-
}
135-
136-
// Now test multiplication and division using the polar inputs:
137-
for q in inputs {
138-
let w = Complex(length: q.length, phase: q.phase)
139-
let product = Complex(length: p.length * q.length, phase: p.phase + q.phase)
140-
if checkMultiply(z, w, expected: product, ulps: 16) { XCTFail() }
141-
let quotient = Complex(length: p.length / q.length, phase: p.phase - q.phase)
142-
if checkDivide(z, w, expected: quotient, ulps: 16) { XCTFail() }
143-
}
144-
}
145-
}
146149

147150
func testPolar() {
151+
#if (arch(arm64))
152+
// testPolar(Float16.self)
153+
#endif
148154
testPolar(Float.self)
149155
testPolar(Double.self)
150-
#if (arch(i386) || arch(x86_64)) && !os(Windows) && !os(Android)
156+
#if (arch(i386) || arch(x86_64)) && !os(Windows) && !os(Android)
151157
testPolar(Float80.self)
152-
#endif
158+
#endif
153159
}
154160

155161
func testBaudinSmith() {
@@ -191,16 +197,49 @@ final class ArithmeticTests: XCTestCase {
191197
Complex(1.02951151789360578e-84, 6.97145987515076231e-220)),
192198
]
193199
for test in vectors {
194-
if checkDivide(test.a, test.b, expected: test.c, ulps: 0.5) { XCTFail() }
200+
if checkDivide(test.a, test.b, expected: test.c, ulps: 1.0) { XCTFail() }
195201
if checkDivide(test.a, test.c, expected: test.b, ulps: 1.0) { XCTFail() }
196202
if checkMultiply(test.b, test.c, expected: test.a, ulps: 1.0) { XCTFail() }
197203
}
198204
}
199-
205+
200206
func testDivisionByZero() {
201207
XCTAssertFalse((Complex(0, 0) / Complex(0, 0)).isFinite)
202208
XCTAssertFalse((Complex(1, 1) / Complex(0, 0)).isFinite)
203209
XCTAssertFalse((Complex.infinity / Complex(0, 0)).isFinite)
204210
XCTAssertFalse((Complex.i / Complex(0, 0)).isFinite)
211+
212+
}
213+
214+
#if !((os(macOS) || targetEnvironment(macCatalyst)) && arch(x86_64))
215+
216+
/*
217+
@available(macOS 11.0, iOS 14.0, tvOS 14.0, watchOS 7.0, *)
218+
func testFloat16DivisionSemiExhaustive() {
219+
func complex(bitPattern: UInt32) -> Complex<Float16> {
220+
Complex(
221+
Float16(bitPattern: UInt16(truncatingIfNeeded: bitPattern)),
222+
Float16(bitPattern: UInt16(truncatingIfNeeded: bitPattern >> 16))
223+
)
224+
}
225+
for bits in 0 ... UInt32.max {
226+
let a = complex(bitPattern: bits)
227+
if bits & 0xfffff == 0 { print(a) }
228+
let b = complex(bitPattern: UInt32.random(in: 0 ... .max))
229+
var q = Complex<Float>(a)/Complex<Float>(b)
230+
if checkDivide(a, b, expected: Complex<Float16>(q), ulps: 32) { XCTFail() }
231+
q = Complex<Float>(b)/Complex<Float>(a)
232+
if checkDivide(b, a, expected: Complex<Float16>(q), ulps: 32) { XCTFail() }
233+
}
234+
}
235+
*/
236+
237+
@available(macOS 11.0, iOS 14.0, tvOS 14.0, watchOS 7.0, *)
238+
func testSpecificFloat16Value() {
239+
let a = Complex<Float16>(4.66, 3e-07)
240+
let b = Complex<Float16>(-4.32e-05, 4.977e-05)
241+
let q = a / b
242+
XCTAssertEqual(q, Complex<Float16>(-46368.0, -53376.0))
205243
}
244+
#endif
206245
}

0 commit comments

Comments
 (0)