@@ -2,10 +2,11 @@ import RealModule
22
33/// Binomial distribution implementation with optimizations for large n values
44@frozen public struct Binomial {
5- private static var iterations : Int { 200 }
5+ @inlinable static var thresholdNormal : Double { 10_000 }
6+ @inlinable static var thresholdBTPE : Double { 30 }
7+ @inlinable static var thresholdRare : Double { 0.05 }
68
7- /// TODO: fine tune
8- @inlinable static var normalApproximationThreshold : Double { 100_000 }
9+ private static var iterations : Int { 200 }
910
1011 public let n : Int64
1112 public let p : Double
@@ -20,23 +21,107 @@ extension Binomial {
2021}
extension Binomial {
    /// Draws one sample from this binomial distribution.
    ///
    /// The sampling strategy is selected from the variance `n · p · (1 − p)`:
    ///
    /// -   above `thresholdNormal`, a normal approximation is rounded and clamped to
    ///     `0 ... n`;
    /// -   at or above `thresholdBTPE`, the BTPE rejection sampler is used;
    /// -   otherwise, geometric jumps are used when the (smaller) success probability is
    ///     below `thresholdRare`, and CDF inversion is used for everything else.
    ///
    /// The non-normal strategies always run on the smaller of `p` and `1 − p`; when the
    /// roles were swapped, the result is reflected using B(n, p) = n − B(n, 1 − p).
    ///
    /// - Parameter generator: The source of randomness.
    /// - Returns: The sampled number of successes, or 0 if `n` is not positive.
    @inlinable public func sample(using generator: inout some RandomNumberGenerator) -> Int64 {
        // A non-positive trial count can only ever produce zero successes. This has to be
        // checked before the `p >= 1` shortcut, which would otherwise hand back a negative
        // `n` verbatim.
        if self.n <= 0 { return 0 }
        if self.p <= 0 { return 0 }
        if self.p >= 1 { return self.n }

        let n: Double = Double.init(self.n)
        let μ: Double = n * self.p
        let q: Double = 1 - self.p
        let σ²: Double = μ * q

        if σ² > Self.thresholdNormal {
            // The uniform draw may be exactly 0 or 1, which maps to an infinite quantile;
            // the comparisons below clamp those cases before converting to `Int64`.
            let σ: Double = .sqrt(σ²)
            let u: Double = .random(in: 0 ... 1, using: &generator)
            let z: Double = Normal.cdfInverse(u)
            let x: Double = (μ + z * σ).rounded()

            if x >= n {
                return self.n
            }
            if x <= 0 {
                return 0
            }
            return Int64.init(x)
        }

        // Work with whichever outcome is less likely, and mirror the result afterwards.
        let flipped: Bool = q < self.p
        let (rare, common): (Double, Double) = flipped ? (q, self.p) : (self.p, q)

        let m: Int64
        if σ² >= Self.thresholdBTPE {
            m = Self.sampleBTPE(
                n: self.n,
                μ: n * rare,
                σ: .sqrt(σ²),
                p: rare,
                q: common,
                using: &generator
            )
        } else if rare < Self.thresholdRare {
            m = Self.sampleGeometric(n: self.n, p: rare, using: &generator)
        } else {
            m = Self.cdfInverse(
                n: self.n,
                μ: n * rare,
                σ²: σ²,
                p: rare,
                q: common,
                u: .random(in: 0 ... 1, using: &generator)
            )
        }

        return flipped ? self.n - m : m
    }
}
extension Binomial {
    /// Counts successes among `n` trials by jumping directly from one success to the next.
    ///
    /// Each iteration draws a uniform variate and converts it into the number of failures
    /// that precede the next success, `log(u) / log(1 − p)`. Sampling stops as soon as a
    /// jump would run past the trials that are still available.
    ///
    /// - Parameters:
    ///   - n: The number of trials.
    ///   - p: The success probability of a single trial.
    ///   - generator: The source of randomness; one uniform variate is drawn per iteration.
    /// - Returns: The number of successes that were counted.
    @inlinable static func sampleGeometric(
        n: Int64,
        p: Double,
        using generator: inout some RandomNumberGenerator
    ) -> Int64 {
        /// reciprocal of `log(1 − p)`, hoisted out of the loop
        let reciprocal: Double = 1 / Double.log(onePlus: -p)

        var hits: Int64 = 0
        var trialsLeft: Int64 = n

        while true {
            let draw: Double = .random(in: 0 ... 1, using: &generator)
            // Failures before the next success. The value can be enormous (or infinite), so
            // it stays a `Double` until it is known to fit inside the remaining trials.
            let failures: Double = Double.log(draw) * reciprocal

            if failures >= Double.init(trialsLeft) {
                return hits
            }

            hits += 1
            // one trial for the success itself, then all of the failures that preceded it
            trialsLeft -= 1
            trialsLeft -= Int64.init(failures)

            if trialsLeft <= 0 {
                return hits
            }
        }
    }
}
124+ extension Binomial {
40125 // Theoretical binomial probability.
41126 @inlinable public func pdf( _ k: Int64 ) -> Double {
42127 if self . p <= 0 {
@@ -61,33 +146,165 @@ extension Binomial {
61146 return Double . exp ( nCk + k * Double. log ( self . p) + l * Double. log ( q) )
62147 }
63148}
extension Binomial {
    /// Samples with the BTPE (Binomial, Triangle, Parallelogram, Exponential)
    /// acceptance–rejection algorithm of Kachitvichyanukul and Schmeiser.
    ///
    /// A candidate is drawn from an envelope made of four regions: a central triangle, a
    /// pair of parallelograms on top of it, and an exponential tail on each side. Points
    /// that fall in the triangle are accepted immediately; all others are compared against
    /// the logarithm of the binomial probability relative to its value at the mode.
    ///
    /// - Parameters:
    ///   - n: The number of trials.
    ///   - μ: The mean, `n · p`.
    ///   - σ: The standard deviation, `sqrt(n · p · q)`.
    ///   - p: The success probability.
    ///   - q: The failure probability, `1 − p`.
    ///   - generator: The source of randomness; two uniform variates are drawn per attempt.
    /// - Returns: The accepted number of successes.
    @inlinable static func sampleBTPE(
        n: Int64,
        μ: Double,
        σ: Double,
        p: Double,
        q: Double,
        using generator: inout some RandomNumberGenerator
    ) -> Int64 {
        let trials: Double = Double.init(n)

        /// continuous mode, and the discrete mode directly below it
        let peak: Double = μ + p
        let mode: Double = peak.rounded(.down)

        /// half of the base of the triangle; the envelope is centered on `mode + 0.5`
        let halfWidth: Double = Double.init(Int64.init(2.195 * σ - 4.6 * q)) + 0.5
        let center: Double = mode + 0.5
        let left: Double = center - halfWidth
        let right: Double = center + halfWidth

        /// height of the parallelograms (region 2) stacked on the triangle (region 1)
        let c: Double = 0.134 + 20.5 / (15.3 + mode)

        /// tangent slopes of the two exponential tails, and the decay rates derived from them
        let slopeLeft: Double = (peak - left) / (peak - left * p)
        let slopeRight: Double = (right - peak) / (right * q)
        let rateLeft: Double = slopeLeft * (1 + 0.5 * slopeLeft)
        let rateRight: Double = slopeRight * (1 + 0.5 * slopeRight)

        // cumulative areas: triangle + parallelograms, then + left tail, then + right tail
        let areaBody: Double = halfWidth * (1 + 2 * c)
        let areaThroughLeft: Double = areaBody + c / rateLeft
        let areaTotal: Double = areaThroughLeft + c / rateRight

        /// log-gamma terms at the mode and the log-odds; filled in on first use
        var cached: (scale: Double, odds: Double)? = nil

        while true {
            // `v` lies in (0, 1], which keeps `log(v)` finite
            let v: Double = 1 - Double.random(in: 0 ..< 1, using: &generator)
            let u: Double = areaTotal * Double.random(in: 0 ..< 1, using: &generator)

            // region 1: triangle, accepted unconditionally
            if u <= halfWidth {
                return Int64.init(center - halfWidth * v + u)
            }

            let candidate: Int64
            let height: Double

            if u <= areaBody {
                // region 2: parallelograms
                let x: Double = left + (u - halfWidth) / c
                if x < 0 {
                    continue
                }

                candidate = Int64.init(x)
                if candidate > n {
                    // outside the support, cannot be accepted
                    continue
                }

                // height of the triangle at `x`, plus a uniform offset into the parallelogram
                let roof: Double = 1 - abs(center - x) / halfWidth
                height = roof + v * c
                if !(height > 0) {
                    continue
                }
            } else if u <= areaThroughLeft {
                // region 3: left exponential tail
                let x: Double = left + Double.log(v) / rateLeft
                if x < 0 {
                    continue
                }

                candidate = Int64.init(x)
                height = v * (u - areaBody) * rateLeft
            } else {
                // region 4: right exponential tail
                let x: Double = right - Double.log(v) / rateRight
                if x >= Double.init(Int64.max) {
                    // not representable as `Int64`
                    continue
                }

                candidate = Int64.init(x)
                if candidate > n {
                    continue
                }

                height = v * (u - areaThroughLeft) * rateRight
            }

            // The terms that depend only on the mode are expensive and are not needed when
            // a triangle point is accepted, so they are computed on demand and then reused.
            let terms: (scale: Double, odds: Double)
            if let known: (scale: Double, odds: Double) = cached {
                terms = known
            } else {
                let atMode: Double = Double.logGamma(mode + 1)
                let beyondMode: Double = Double.logGamma(trials - mode + 1)
                terms = (scale: atMode + beyondMode, odds: Double.log(p / q))
                cached = terms
            }

            // log of the binomial probability at the candidate, relative to the mode
            let x: Double = Double.init(candidate)
            let logRatio: Double = terms.scale
                + (x - mode) * terms.odds
                - Double.logGamma(x + 1)
                - Double.logGamma(trials - x + 1)

            if logRatio >= Double.log(height) {
                return candidate
            }
        }
    }
}
64285extension Binomial {
65286 /// Find the binomial value using binary search on the CDF
66- /// For large n, uses direct normal approximation for significant performance improvement
67287 @usableFromInline static func cdfInverse(
68- u: Double ,
69288 n: Int64 ,
289+ μ: Double ,
290+ σ²: Double ,
70291 p: Double ,
71- q: Double
292+ q: Double ,
293+ u: Double ,
72294 ) -> Int64 {
73295 let n : ( i: Int64 , f: Double ) = ( n, Double . init ( n) )
74296
75- // Fast path for extreme cases
76- if u <= Double . pow ( q, n. f) { return 0 }
77- if u >= 1 - Double. pow ( p, n. f) { return n. i }
78-
79- // Get approximate starting point using normal approximation
80- let μ : Double = n. f * p
81- let σ : Double = Double . sqrt ( μ * q)
82-
83297 // Use quantile function of normal distribution
84- let z : Double = Normal . cdfInverse ( u)
85- let guess : Int64 = min ( max ( 0 , Int64 . init ( ( μ + z * σ) . rounded ( ) ) ) , n. i)
298+ let guess : Int64
86299
87- // For very large n, if n*p*q > threshold, we can use the normal approximation directly
88- // This is a significant optimization for large n values!
89- if μ * q > Self . normalApproximationThreshold {
90- return guess
300+ let z : Double = Normal . cdfInverse ( u)
301+ let x : Double = ( μ + z * Double . sqrt( σ²) ) . rounded ( )
302+ if x >= n. f {
303+ guess = n. i
304+ } else if x <= 0 {
305+ guess = 0
306+ } else {
307+ guess = Int64 . init ( x)
91308 }
92309
93310 // For smaller n, continue with binary search for greater accuracy
0 commit comments