@@ -13,23 +13,138 @@ function infer_eltype(itr)
13
13
ifelse (T2 != = Union{} && T2 <: T1 , T2, T1)
14
14
end
15
15
16
- struct SortedRandRangeIter {R}
16
+ struct SeqSampleIterWR {R}
17
17
rng:: R
18
- range :: UnitRange{ Int}
18
+ N :: Int
19
19
n:: Int
20
20
end
21
21
22
- @inline function Base. iterate (s:: SortedRandRangeIter )
23
- curmax = - log (Float64 (s. range . stop )) + randexp (s. rng)/ s. n
24
- return (s. range . stop - ceil (Int, exp (- curmax)) + 1 , (s. n- 1 , curmax))
22
+ @inline function Base. iterate (s:: SeqSampleIterWR )
23
+ curmax = - log (Float64 (s. N )) + randexp (s. rng)/ s. n
24
+ return (s. N - ceil (Int, exp (- curmax)) + 1 , (s. n- 1 , curmax))
25
25
end
26
- @inline function Base. iterate (s:: SortedRandRangeIter , state)
26
+ @inline function Base. iterate (s:: SeqSampleIterWR , state)
27
27
state[1 ] == 0 && return nothing
28
28
curmax = state[2 ] + randexp (s. rng)/ state[1 ]
29
- return (s. range . stop - ceil (Int, exp (- curmax)) + 1 , (state[1 ]- 1 , curmax))
29
+ return (s. N - ceil (Int, exp (- curmax)) + 1 , (state[1 ]- 1 , curmax))
30
30
end
31
31
32
- Base. IteratorEltype (:: SortedRandRangeIter ) = Base. HasEltype ()
33
- Base. eltype (:: SortedRandRangeIter ) = Int
34
- Base. IteratorSize (:: SortedRandRangeIter ) = Base. HasLength ()
35
- Base. length (s:: SortedRandRangeIter ) = s. n
32
+ Base. IteratorEltype (:: SeqSampleIterWR ) = Base. HasEltype ()
33
+ Base. eltype (:: SeqSampleIterWR ) = Int
34
+ Base. IteratorSize (:: SeqSampleIterWR ) = Base. HasLength ()
35
+ Base. length (s:: SeqSampleIterWR ) = s. n
36
+
37
+ struct SeqSampleIter{R}
38
+ rng:: R
39
+ N:: Int
40
+ n:: Int
41
+ alpha:: Float64
42
+ function SeqSampleIter (rng:: R , N, n) where R
43
+ alpha = 1 / 13
44
+ new {R} (rng, N, n, alpha)
45
+ end
46
+ end
47
+
48
+ @inline function Base. iterate (it:: SeqSampleIter )
49
+ i = 0
50
+ q1 = it. N - it. n + 1
51
+ q2 = q1 / it. N
52
+ vprime = exp (- randexp (it. rng)/ it. n)
53
+ threshold = it. alpha * it. n
54
+ s, vprime = skip (it. rng, it. n, it. N, vprime, q1, q2)
55
+ i, nv, Nv, q1, q2, threshold = new_state (it, s, i, it. n, it. N, q1, q2, threshold)
56
+ return (i, (i, nv, Nv, q1, q2, threshold, vprime))
57
+ end
58
+ @inline function Base. iterate (it:: SeqSampleIter , state)
59
+ i, nv, Nv, q1, q2, threshold, vprime = state
60
+ if nv > 1 && threshold < Nv
61
+ s, vprime = skip (it. rng, nv, Nv, vprime, q1, q2)
62
+ i, nv, Nv, q1, q2, threshold = new_state (it, s, i, nv, Nv, q1, q2, threshold)
63
+ return (i, (i, nv, Nv, q1, q2, threshold, vprime))
64
+ elseif nv > 1
65
+ s = seqsample_a (it. rng, it. N - i, nv)
66
+ nv -= 1
67
+ i += s+ 1
68
+ return (i, ((nv === 0 ? i : it. N+ 1 ), nv, Nv, q1, q2, threshold, vprime))
69
+ else
70
+ i === it. N+ 1 && return nothing
71
+ s = trunc (Int, Nv * vprime)
72
+ i += s+ 1
73
+ return (i, (it. N+ 1 , nv, Nv, q1, q2, threshold, vprime))
74
+ end
75
+ end
76
+
77
+ @inline function skip (rng, n, N, vprime, q1, q2)
78
+ local s
79
+ while true
80
+ local X
81
+ while true
82
+ X = N* (1 - vprime)
83
+ s = trunc (Int, X)
84
+ s < q1 && break
85
+ vprime = exp (- randexp (rng)/ n)
86
+ end
87
+
88
+ y = rand (rng)/ q2
89
+ lhs = exp (log (y)/ (n- 1 ))
90
+ rhs = ((q1- s)/ q1) * (N/ (N- X))
91
+
92
+ if lhs <= rhs
93
+ vprime = lhs/ rhs
94
+ break
95
+ end
96
+
97
+ if n- 1 > s
98
+ bottom = N- n
99
+ limit = N- s
100
+ else
101
+ bottom = N- s- 1
102
+ limit = q1
103
+ end
104
+
105
+ top = N- 1
106
+
107
+ while top >= limit
108
+ y *= top/ bottom
109
+ bottom -= 1
110
+ top -= 1
111
+ end
112
+
113
+ if log (y) < (n- 1 )* (log (N)- log (N- X))
114
+ vprime = exp (- randexp (rng)/ (n- 1 ))
115
+ break
116
+ end
117
+ vprime = exp (- randexp (rng)/ n)
118
+ end
119
+ return s, vprime
120
+ end
121
+
122
+ @inline function new_state (it, s, i, nv, Nv, q1, q2, threshold)
123
+ i += s+ 1
124
+ Nv -= s+ 1
125
+ nv -= 1
126
+ q1 -= s
127
+ q2 = q1/ Nv
128
+ threshold -= it. alpha
129
+ return i, nv, Nv, q1, q2, threshold
130
+ end
131
+
132
+ @inline function seqsample_a! (rng:: AbstractRNG , n, k)
133
+ if k > 1
134
+ i = 0
135
+ q = (n- k)/ n
136
+ while q > rand (rng)
137
+ i += 1
138
+ n -= 1
139
+ q *= (n- k)/ n
140
+ end
141
+ return i
142
+ else
143
+ return trunc (Int, n * rand (rng))
144
+ end
145
+ end
146
+
147
+ Base. IteratorEltype (:: SeqSampleIter ) = Base. HasEltype ()
148
+ Base. eltype (:: SeqSampleIter ) = Int
149
+ Base. IteratorSize (:: SeqSampleIter ) = Base. HasLength ()
150
+ Base. length (s:: SeqSampleIter ) = s. n
0 commit comments