Skip to content

Commit 8c6afce

Browse files
committed
Algorithm D
1 parent 544bfdc commit 8c6afce

File tree

5 files changed

+148
-24
lines changed

5 files changed

+148
-24
lines changed

docs/src/index.md

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,11 @@ you can simply use the `fit!` function to update the reservoir:
2929
```julia
3030
julia> using StreamSampling
3131

32+
julia> st = 1:100;
33+
3234
julia> rs = ReservoirSample{Int}(5);
3335

34-
julia> for x in 1:100
36+
julia> for x in st
3537
fit!(rs, x)
3638
end
3739

@@ -50,11 +52,11 @@ also possible to iterate over a `StreamSample` like so
5052
```julia
5153
julia> using StreamSampling
5254

53-
julia> iter = 1:100;
55+
julia> st = 1:100;
5456

55-
julia> ss = StreamSample{Int}(iter, 5, 100);
57+
julia> ss = StreamSample{Int}(st, 5, 100);
5658

57-
julia> r = Int[];
59+
julia> m = Int[];
5860

5961
julia> for x in ss
6062
push!(r, x)
@@ -69,4 +71,8 @@ julia> r
6971
75
7072
```
7173

72-
Consult the [API page](https://juliadynamics.github.io/StreamSampling.jl/stable/api) for more information about the package interface.
74+
This has the advantage to require `O(1)` memory, while reservoir sample techniques requires `O(k)` memory where `k`
75+
is the number of elements in the sample.
76+
77+
Consult the [API page](https://juliadynamics.github.io/StreamSampling.jl/stable/api) for more information about the
78+
package interface.

src/SamplingInterface.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ function OnlineStatsBase.merge(::AbstractReservoirSample)
100100
end
101101

102102
"""
103-
StreamSample{T}([rng], iter, n, [N], method = AlgS())
103+
StreamSample{T}([rng], iter, n, [N], method = AlgD())
104104
105105
Initializes a stream sample, which can then be iterated over
106106
to return the sampling elements of the iterable `iter` which
@@ -183,7 +183,7 @@ Base.@constprop :aggressive function itsample(rng::AbstractRNG, iter, n::Int, me
183183
s = ReservoirSample{iter_type}(rng, n, method, ImmutSample(), ordered ? Ord() : Unord())
184184
return update_all!(s, iter, ordered)
185185
else
186-
m = method isa AlgL || method isa AlgR || method isa AlgORDS ? AlgORDS() : AlgORDSWR()
186+
m = method isa AlgL || method isa AlgR || method isa AlgD ? AlgD() : AlgORDSWR()
187187
s = collect(StreamSample{iter_type}(rng, iter, n, length(iter), m))
188188
return ordered ? s : shuffle!(rng, s)
189189
end

src/SamplingUtils.jl

Lines changed: 126 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,23 +13,138 @@ function infer_eltype(itr)
1313
ifelse(T2 !== Union{} && T2 <: T1, T2, T1)
1414
end
1515

16-
struct SortedRandRangeIter{R}
16+
struct SeqSampleIterWR{R}
1717
rng::R
18-
range::UnitRange{Int}
18+
N::Int
1919
n::Int
2020
end
2121

22-
@inline function Base.iterate(s::SortedRandRangeIter)
23-
curmax = -log(Float64(s.range.stop)) + randexp(s.rng)/s.n
24-
return (s.range.stop - ceil(Int, exp(-curmax)) + 1, (s.n-1, curmax))
22+
@inline function Base.iterate(s::SeqSampleIterWR)
23+
curmax = -log(Float64(s.N)) + randexp(s.rng)/s.n
24+
return (s.N - ceil(Int, exp(-curmax)) + 1, (s.n-1, curmax))
2525
end
26-
@inline function Base.iterate(s::SortedRandRangeIter, state)
26+
@inline function Base.iterate(s::SeqSampleIterWR, state)
2727
state[1] == 0 && return nothing
2828
curmax = state[2] + randexp(s.rng)/state[1]
29-
return (s.range.stop - ceil(Int, exp(-curmax)) + 1, (state[1]-1, curmax))
29+
return (s.N - ceil(Int, exp(-curmax)) + 1, (state[1]-1, curmax))
3030
end
3131

32-
Base.IteratorEltype(::SortedRandRangeIter) = Base.HasEltype()
33-
Base.eltype(::SortedRandRangeIter) = Int
34-
Base.IteratorSize(::SortedRandRangeIter) = Base.HasLength()
35-
Base.length(s::SortedRandRangeIter) = s.n
32+
Base.IteratorEltype(::SeqSampleIterWR) = Base.HasEltype()
33+
Base.eltype(::SeqSampleIterWR) = Int
34+
Base.IteratorSize(::SeqSampleIterWR) = Base.HasLength()
35+
Base.length(s::SeqSampleIterWR) = s.n
36+
37+
struct SeqSampleIter{R}
38+
rng::R
39+
N::Int
40+
n::Int
41+
alpha::Float64
42+
function SeqSampleIter(rng::R, N, n) where R
43+
alpha = 1/13
44+
new{R}(rng, N, n, alpha)
45+
end
46+
end
47+
48+
@inline function Base.iterate(it::SeqSampleIter)
49+
i = 0
50+
q1 = it.N - it.n + 1
51+
q2 = q1 / it.N
52+
vprime = exp(-randexp(it.rng)/it.n)
53+
threshold = it.alpha * it.n
54+
s, vprime = skip(it.rng, it.n, it.N, vprime, q1, q2)
55+
i, nv, Nv, q1, q2, threshold = new_state(it, s, i, it.n, it.N, q1, q2, threshold)
56+
return (i, (i, nv, Nv, q1, q2, threshold, vprime))
57+
end
58+
@inline function Base.iterate(it::SeqSampleIter, state)
59+
i, nv, Nv, q1, q2, threshold, vprime = state
60+
if nv > 1 && threshold < Nv
61+
s, vprime = skip(it.rng, nv, Nv, vprime, q1, q2)
62+
i, nv, Nv, q1, q2, threshold = new_state(it, s, i, nv, Nv, q1, q2, threshold)
63+
return (i, (i, nv, Nv, q1, q2, threshold, vprime))
64+
elseif nv > 1
65+
s = seqsample_a(it.rng, it.N - i, nv)
66+
nv -= 1
67+
i += s+1
68+
return (i, ((nv === 0 ? i : it.N+1), nv, Nv, q1, q2, threshold, vprime))
69+
else
70+
i === it.N+1 && return nothing
71+
s = trunc(Int, Nv * vprime)
72+
i += s+1
73+
return (i, (it.N+1, nv, Nv, q1, q2, threshold, vprime))
74+
end
75+
end
76+
77+
@inline function skip(rng, n, N, vprime, q1, q2)
78+
local s
79+
while true
80+
local X
81+
while true
82+
X = N*(1-vprime)
83+
s = trunc(Int, X)
84+
s < q1 && break
85+
vprime = exp(-randexp(rng)/n)
86+
end
87+
88+
y = rand(rng)/q2
89+
lhs = exp(log(y)/(n-1))
90+
rhs = ((q1-s)/q1) * (N/(N-X))
91+
92+
if lhs <= rhs
93+
vprime = lhs/rhs
94+
break
95+
end
96+
97+
if n-1 > s
98+
bottom = N-n
99+
limit = N-s
100+
else
101+
bottom = N-s-1
102+
limit = q1
103+
end
104+
105+
top = N-1
106+
107+
while top >= limit
108+
y *= top/bottom
109+
bottom -= 1
110+
top -= 1
111+
end
112+
113+
if log(y) < (n-1)*(log(N)-log(N-X))
114+
vprime = exp(-randexp(rng)/(n-1))
115+
break
116+
end
117+
vprime = exp(-randexp(rng)/n)
118+
end
119+
return s, vprime
120+
end
121+
122+
@inline function new_state(it, s, i, nv, Nv, q1, q2, threshold)
123+
i += s+1
124+
Nv -= s+1
125+
nv -= 1
126+
q1 -= s
127+
q2 = q1/Nv
128+
threshold -= it.alpha
129+
return i, nv, Nv, q1, q2, threshold
130+
end
131+
132+
@inline function seqsample_a!(rng::AbstractRNG, n, k)
133+
if k > 1
134+
i = 0
135+
q = (n-k)/n
136+
while q > rand(rng)
137+
i += 1
138+
n -= 1
139+
q *= (n-k)/n
140+
end
141+
return i
142+
else
143+
return trunc(Int, n * rand(rng))
144+
end
145+
end
146+
147+
Base.IteratorEltype(::SeqSampleIter) = Base.HasEltype()
148+
Base.eltype(::SeqSampleIter) = Int
149+
Base.IteratorSize(::SeqSampleIter) = Base.HasLength()
150+
Base.length(s::SeqSampleIter) = s.n

src/SortedSamplingMulti.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,11 @@ struct SampleMultiAlgORD{T,R,I,D} <: AbstractStreamSample
99
end
1010
end
1111

12-
function StreamSample{T}(rng::AbstractRNG, iter, n, N, ::AlgORDSWR) where T
13-
return SampleMultiAlgORD{T}(rng, iter, n, SortedRandRangeIter(rng, 1:N, n))
12+
function StreamSample{T}(rng::AbstractRNG, iter, n, N, ::AlgD) where T
13+
return SampleMultiAlgORD{T}(rng, iter, min(n, N), SeqSampleIter(rng, N, min(n, N)))
1414
end
15-
function StreamSample{T}(rng::AbstractRNG, iter, n, N, ::AlgORDS) where T
16-
return SampleMultiAlgORD{T}(rng, iter, min(n, N), sort!(sample(rng, 1:N, min(n, N); replace=false)))
15+
function StreamSample{T}(rng::AbstractRNG, iter, n, N, ::AlgORDSWR) where T
16+
return SampleMultiAlgORD{T}(rng, iter, n, SeqSampleIterWR(rng, N, n))
1717
end
1818

1919
@inline function Base.iterate(s::SampleMultiAlgORD)

src/StreamSampling.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ using StatsBase
1717

1818
export fit!, merge!, value, ordvalue, nobs, itsample
1919
export AbstractReservoirSample, ReservoirSample, StreamSample
20-
export AlgL, AlgR, AlgRSWRSKIP, AlgARes, AlgAExpJ, AlgWRSWRSKIP, AlgORDS, AlgORDSWR
20+
export AlgL, AlgR, AlgRSWRSKIP, AlgARes, AlgAExpJ, AlgWRSWRSKIP, AlgD, AlgORDSWR
2121

2222
struct ImmutSample end
2323
struct MutSample end
@@ -34,8 +34,11 @@ abstract type ReservoirAlgorithm <: StreamAlgorithm end
3434
"""
3535
Implements random sampling without replacement. To be used with [`StreamSample`](@ref)
3636
or [`itsample`](@ref).
37+
38+
Adapted from algorithm D described in "An Efficient Algorithm for Sequential Random Sampling,
39+
J. S. Vitter, 1987".
3740
"""
38-
struct AlgORDS <: StreamAlgorithm end
41+
struct AlgD <: StreamAlgorithm end
3942

4043
"""
4144
Implements random stream sampling with replacement. To be used with [`StreamSample`](@ref)

0 commit comments

Comments
 (0)