Skip to content

Commit 2c26a20

Browse files
committed
Improve interface
1 parent 66ad2d5 commit 2c26a20

18 files changed

+190
-118
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "StreamSampling"
22
uuid = "ff63dad9-3335-55d8-95ec-f8139d39e468"
3-
version = "0.5.2"
3+
version = "0.6.0"
44

55
[deps]
66
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"

docs/src/api.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ empty!
1414
value
1515
ordvalue
1616
nobs
17+
StreamSample
1718
itsample
1819
```
1920

docs/src/index.md

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,4 +44,23 @@ julia> value(rs)
4444
74
4545
```
4646

47+
If the total number of elements in the stream is known beforehand and the sampling is unweighted, it is
48+
also possible to iterate over a `StreamSample` like so:
49+
50+
```julia
51+
julia> using StreamSampling
52+
53+
julia> iter = 1:100
54+
55+
julia> ss = StreamSample{Int}(iter, 5, 100);
56+
57+
julia> collect(ss)
58+
5-element Vector{Int64}:
59+
7
60+
9
61+
20
62+
49
63+
74
64+
```
65+
4766
Consult the [API page](https://juliadynamics.github.io/StreamSampling.jl/stable/api) for more information about the package interface.

src/SamplingInterface.jl

Lines changed: 49 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11

22
"""
3-
ReservoirSample([rng], T, method = AlgRSWRSKIP())
4-
ReservoirSample([rng], T, n::Int, method = AlgL(); ordered = false)
3+
ReservoirSample{T}([rng], method = AlgRSWRSKIP())
4+
ReservoirSample{T}([rng], n::Int, method = AlgL(); ordered = false)
55
66
Initializes a reservoir sample which can then be fitted with [`fit!`](@ref).
77
The first signature represents a sample where only a single element is collected.
@@ -10,19 +10,21 @@ they were collected with [`ordvalue`](@ref).
1010
1111
Look at the [`Sampling Algorithms`](@ref) section for the supported methods.
1212
"""
13-
function ReservoirSample(T, method::ReservoirAlgorithm = AlgRSWRSKIP())
14-
return ReservoirSample(Random.default_rng(), T, method, MutSample())
13+
struct ReservoirSample{T} 1 === 1 end
14+
15+
function ReservoirSample{T}(method::ReservoirAlgorithm = AlgRSWRSKIP()) where T
16+
return ReservoirSample{T}(Random.default_rng(), method, MutSample())
1517
end
16-
function ReservoirSample(rng::AbstractRNG, T, method::ReservoirAlgorithm = AlgRSWRSKIP())
17-
return ReservoirSample(rng, T, method, MutSample())
18+
function ReservoirSample{T}(rng::AbstractRNG, method::ReservoirAlgorithm = AlgRSWRSKIP()) where T
19+
return ReservoirSample{T}(rng, method, MutSample())
1820
end
19-
Base.@constprop :aggressive function ReservoirSample(T, n::Integer, method::ReservoirAlgorithm=AlgL();
20-
ordered = false)
21-
return ReservoirSample(Random.default_rng(), T, n, method, MutSample(), ordered ? Ord() : Unord())
21+
Base.@constprop :aggressive function ReservoirSample{T}(n::Integer, method::ReservoirAlgorithm=AlgL();
22+
ordered = false) where T
23+
return ReservoirSample{T}(Random.default_rng(), n, method, MutSample(), ordered ? Ord() : Unord())
2224
end
23-
Base.@constprop :aggressive function ReservoirSample(rng::AbstractRNG, T, n::Integer,
24-
method::ReservoirAlgorithm=AlgL(); ordered = false)
25-
return ReservoirSample(rng, T, n, method, MutSample(), ordered ? Ord() : Unord())
25+
Base.@constprop :aggressive function ReservoirSample{T}(rng::AbstractRNG, n::Integer,
26+
method::ReservoirAlgorithm=AlgL(); ordered = false) where T
27+
return ReservoirSample{T}(rng, n, method, MutSample(), ordered ? Ord() : Unord())
2628
end
2729

2830
"""
@@ -86,7 +88,6 @@ function Base.merge!(::AbstractReservoirSample)
8688
error("Abstract Version")
8789
end
8890

89-
9091
"""
9192
Base.merge(rs::AbstractReservoirSample...)
9293
@@ -98,6 +99,31 @@ function OnlineStatsBase.merge(::AbstractReservoirSample)
9899
error("Abstract Version")
99100
end
100101

102+
"""
103+
StreamSample{T}([rng], iter, n, [N], method = AlgORDS())
104+
105+
Initializes a stream sample, which can then be iterated over
106+
to return the sampled elements of the iterable `iter`, which
107+
is assumed to have an eltype of `T`. The methods implemented in
108+
[`StreamSample`](@ref) require the knowledge of the total number
109+
of elements in the stream `N`; if not provided, it is assumed to be
110+
available by calling `length(iter)`.
111+
"""
112+
struct StreamSample{T} 1 === 1 end
113+
114+
function StreamSample{T}(iter, n, N, method::StreamAlgorithm = AlgORDS()) where T
115+
return StreamSample{T}(Random.default_rng(), iter, n, N, method)
116+
end
117+
function StreamSample{T}(iter, n, method::StreamAlgorithm = AlgORDS()) where T
118+
return StreamSample{T}(Random.default_rng(), iter, n, length(iter), method)
119+
end
120+
function StreamSample{T}(rng::AbstractRNG, iter, n, method::StreamAlgorithm = AlgORDS()) where T
121+
return StreamSample{T}(rng, iter, n, length(iter), method)
122+
end
123+
function StreamSample{T}(rng::AbstractRNG, iter, n, N, method::StreamAlgorithm = AlgORDS()) where T
124+
return StreamSample{T}(rng, iter, n, N, method)
125+
end
126+
101127
"""
102128
itsample([rng], iter, method = AlgRSWRSKIP())
103129
itsample([rng], iter, wfunc, method = AlgWRSWRSKIP())
@@ -145,37 +171,38 @@ end
145171
Base.@constprop :aggressive function itsample(rng::AbstractRNG, iter, method = AlgRSWRSKIP();
146172
iter_type = infer_eltype(iter))
147173
if Base.IteratorSize(iter) isa Base.SizeUnknown
148-
s = ReservoirSample(rng, iter_type, method, ImmutSample())
174+
s = ReservoirSample{iter_type}(rng, method, ImmutSample())
149175
return update_all!(s, iter)
150176
else
151-
return sortedindices_sample(rng, iter)
177+
return sorted_sample_single(rng, iter)
152178
end
153179
end
154180
Base.@constprop :aggressive function itsample(rng::AbstractRNG, iter, n::Int, method = AlgL();
155181
iter_type = infer_eltype(iter), ordered = false)
156182
if Base.IteratorSize(iter) isa Base.SizeUnknown
157-
s = ReservoirSample(rng, iter_type, n, method, ImmutSample(), ordered ? Ord() : Unord())
183+
s = ReservoirSample{iter_type}(rng, n, method, ImmutSample(), ordered ? Ord() : Unord())
158184
return update_all!(s, iter, ordered)
159185
else
160-
replace = method isa AlgL || method isa AlgR ? NoReplace() : Replace()
161-
sortedindices_sample(rng, iter, n, replace; iter_type, ordered)
186+
m = method isa AlgL || method isa AlgR ? AlgORDS() : AlgORDSWR()
187+
s = collect(StreamSample{iter_type}(rng, iter, n, length(iter), m))
188+
return ordered ? s : shuffle!(rng, s)
162189
end
163190
end
164191
function itsample(rng::AbstractRNG, iter, wv::Function, method = AlgWRSWRSKIP(); iter_type = infer_eltype(iter))
165-
s = ReservoirSample(rng, iter_type, method, ImmutSample())
192+
s = ReservoirSample{iter_type}(rng, method, ImmutSample())
166193
return update_all!(s, iter, wv)
167194
end
168195
Base.@constprop :aggressive function itsample(rng::AbstractRNG, iter, wv::Function, n::Int, method = AlgAExpJ();
169196
iter_type = infer_eltype(iter), ordered = false)
170-
s = ReservoirSample(rng, iter_type, n, method, ImmutSample(), ordered ? Ord() : Unord())
197+
s = ReservoirSample{iter_type}(rng, n, method, ImmutSample(), ordered ? Ord() : Unord())
171198
return update_all!(s, iter, ordered, wv)
172199
end
173200
function itsample(rngs::Tuple, iters::Tuple, n::Int,; iter_types = infer_eltype.(iters))
174201
n_it = length(iters)
175202
vs = Vector{Vector{Union{iter_types...}}}(undef, n_it)
176203
ps = Vector{Float64}(undef, n_it)
177204
Threads.@threads for i in 1:n_it
178-
s = ReservoirSample(rngs[i], iter_types[i], n, AlgRSWRSKIP(), ImmutSample(), Unord())
205+
s = ReservoirSample{iter_types[i]}(rngs[i], n, AlgRSWRSKIP(), ImmutSample(), Unord())
179206
vs[i], ps[i] = update_all_p!(s, iters[i])
180207
end
181208
ps /= sum(ps)
@@ -186,7 +213,7 @@ function itsample(rngs::Tuple, iters::Tuple, wfuncs::Tuple, n::Int; iter_types =
186213
vs = Vector{Vector{Union{iter_types...}}}(undef, n_it)
187214
ps = Vector{Float64}(undef, n_it)
188215
Threads.@threads for i in 1:n_it
189-
s = ReservoirSample(rngs[i], iter_types[i], n, AlgWRSWRSKIP(), ImmutSample(), Unord())
216+
s = ReservoirSample{iter_types[i]}(rngs[i], n, AlgWRSWRSKIP(), ImmutSample(), Unord())
190217
vs[i], ps[i] = update_all_p!(s, iters[i], wfuncs[i])
191218
end
192219
ps /= sum(ps)

src/SamplingUtils.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,6 @@ function infer_eltype(itr)
1313
ifelse(T2 !== Union{} && T2 <: T1, T2, T1)
1414
end
1515

16-
get_sorted_indices(rng, n, N, ::Replace) = SortedRandRangeIter(rng, 1:N, n)
17-
get_sorted_indices(rng, n, N, ::NoReplace) = sort!(sample(rng, 1:N, n; replace=false))
18-
1916
struct SortedRandRangeIter{R}
2017
rng::R
2118
range::UnitRange{Int}

src/SortedSamplingMulti.jl

Lines changed: 35 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,43 @@
11

2-
"""
3-
sortedindices_sample(rng, iter)
4-
sortedindices_sample(rng, iter, n; replace = false, ordered = false)
5-
6-
Algorithm which generates sorted random indices used to retrieve the sample
7-
from the iterable. The number of elements in the iterable needs to be known
8-
before starting the sampling.
9-
"""
10-
function sortedindices_sample(rng, iter, n::Int, replace;
11-
iter_type = infer_eltype(iter), ordered = false)
12-
N = length(iter)
13-
if N <= n
14-
reservoir = collect(iter)
15-
replace isa Replace && return sample(rng, reservoir, n, ordered=ordered)
16-
return ordered ? reservoir : shuffle!(rng, reservoir)
2+
struct SampleMultiAlgORD{T,R,I,D} <: AbstractStreamSample
3+
rng::R
4+
it::I
5+
n::Int
6+
inds::D
7+
function SampleMultiAlgORD{T}(rng::R, it::I, n, inds::D) where {T,R,I,D}
8+
return new{T,R,I,D}(rng, it, n, inds)
179
end
18-
reservoir = Vector{iter_type}(undef, n)
19-
indices = get_sorted_indices(rng, n, N, replace)
20-
curr_idx, state_idx = iterate(indices)
10+
end
11+
12+
function StreamSample{T}(rng::AbstractRNG, iter, n, N, ::AlgORDSWR) where T
13+
return SampleMultiAlgORD{T}(rng, iter, n, SortedRandRangeIter(rng, 1:N, n))
14+
end
15+
function StreamSample{T}(rng::AbstractRNG, iter, n, N, ::AlgORDS) where T
16+
return SampleMultiAlgORD{T}(rng, iter, min(n, N), sort!(sample(rng, 1:N, min(n, N); replace=false)))
17+
end
18+
19+
@inline function Base.iterate(s::SampleMultiAlgORD)
20+
indices, iter = s.inds, s.it
21+
curr_idx, state_idx = iterate(indices)::Tuple
2122
el, state_el = iterate(iter)::Tuple
2223
for _ in 1:curr_idx-1
2324
el, state_el = iterate(iter, state_el)::Tuple
2425
end
25-
reservoir[1] = el
26-
@inbounds for i in 2:n
27-
next_idx, state_idx = iterate(indices, state_idx)
28-
for _ in 1:next_idx-curr_idx
29-
el, state_el = iterate(iter, state_el)::Tuple
30-
end
31-
reservoir[i] = el
32-
curr_idx = next_idx
26+
return (el, (el, state_el, curr_idx, state_idx))
27+
end
28+
@inline function Base.iterate(s::SampleMultiAlgORD, state)
29+
el, state_el, curr_idx, state_idx = state
30+
indices, iter = s.inds, s.it
31+
it_indices = iterate(indices, state_idx)
32+
it_indices === nothing && return nothing
33+
next_idx, state_idx = it_indices
34+
for _ in 1:next_idx-curr_idx
35+
el, state_el = iterate(iter, state_el)::Tuple
3336
end
34-
return ordered ? reservoir : shuffle!(rng, reservoir)
37+
return (el, (el, state_el, next_idx, state_idx))
3538
end
39+
40+
Base.IteratorEltype(::SampleMultiAlgORD) = Base.HasEltype()
41+
Base.eltype(::SampleMultiAlgORD{T}) where T = T
42+
Base.IteratorSize(::SampleMultiAlgORD) = Base.HasLength()
43+
Base.length(s::SampleMultiAlgORD) = s.n

src/SortedSamplingSingle.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11

2-
function sortedindices_sample(rng, iter; kwargs...)
2+
function sorted_sample_single(rng, iter)
33
k = rand(rng, 1:length(iter))
44
for (i, el) in enumerate(iter)
55
i == k && return el

src/StreamSampling.jl

Lines changed: 38 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ using Random
1616
using StatsBase
1717

1818
export fit!, merge!, value, ordvalue, nobs, itsample
19-
export AbstractReservoirSample, ReservoirSample
19+
export AbstractReservoirSample, ReservoirSample, StreamSample
2020
export AlgL, AlgR, AlgRSWRSKIP, AlgARes, AlgAExpJ, AlgWRSWRSKIP
2121

2222
struct ImmutSample end
@@ -25,56 +25,76 @@ struct MutSample end
2525
struct Ord end
2626
struct Unord end
2727

28-
struct Replace end
29-
struct NoReplace end
30-
28+
abstract type AbstractStreamSample end
3129
abstract type AbstractReservoirSample <: OnlineStat{Any} end
3230

33-
abstract type ReservoirAlgorithm end
31+
abstract type StreamAlgorithm end
32+
abstract type ReservoirAlgorithm <: StreamAlgorithm end
33+
34+
"""
35+
Implements random sampling without replacement. To be used with [`StreamSample`](@ref)
36+
or [`itsample`](@ref).
37+
"""
38+
struct AlgORDS <: StreamAlgorithm end
39+
40+
"""
41+
Implements random stream sampling with replacement. To be used with [`StreamSample`](@ref)
42+
or [`itsample`](@ref).
43+
44+
Adapted from algorithm 4 described in "Generating Sorted Lists of Random Numbers, J. L. Bentley
45+
et al., 1980".
46+
"""
47+
struct AlgORDSWR <: StreamAlgorithm end
3448

3549
"""
36-
Implements random sampling without replacement.
50+
Implements random reservoir sampling without replacement. To be used with [`ReservoirSample`](@ref)
51+
or [`itsample`](@ref).
3752
3853
Adapted from algorithm R described in "Random sampling with a reservoir, J. S. Vitter, 1985".
3954
"""
4055
struct AlgR <: ReservoirAlgorithm end
4156

4257
"""
43-
Implements random sampling without replacement.
58+
Implements random reservoir sampling without replacement. To be used with [`ReservoirSample`](@ref)
59+
or [`itsample`](@ref).
4460
4561
Adapted from algorithm L described in "Random sampling with a reservoir, J. S. Vitter, 1985".
4662
"""
4763
struct AlgL <: ReservoirAlgorithm end
4864

4965
"""
50-
Implements random sampling with replacement.
66+
Implements random reservoir sampling with replacement. To be used with [`ReservoirSample`](@ref)
67+
or [`itsample`](@ref).
5168
52-
Adapted fron algorithm RSWR_SKIP described in "Reservoir-based Random Sampling with Replacement from
69+
Adapted from algorithm RSWR-SKIP described in "Reservoir-based Random Sampling with Replacement from
5370
Data Stream, B. Park et al., 2008".
5471
"""
5572
struct AlgRSWRSKIP <: ReservoirAlgorithm end
5673

5774
"""
58-
Implements weighted random sampling without replacement.
75+
Implements weighted random reservoir sampling without replacement. To be used with [`ReservoirSample`](@ref)
76+
or [`itsample`](@ref).
5977
60-
Adapted from algorithm A-Res described in "Weighted random sampling with a reservoir,
61-
P. S. Efraimidis et al., 2006".
78+
Adapted from algorithm A-Res described in "Weighted random sampling with a reservoir, P. S. Efraimidis
79+
et al., 2006".
6280
"""
6381
struct AlgARes <: ReservoirAlgorithm end
6482

6583
"""
66-
Implements weighted random sampling without replacement.
84+
Implements weighted random reservoir sampling without replacement. To be used with [`ReservoirSample`](@ref)
85+
or [`itsample`](@ref).
6786
68-
Adapted from algorithm A-ExpJ described in "Weighted random sampling with a reservoir,
69-
P. S. Efraimidis et al., 2006".
87+
Adapted from algorithm A-ExpJ described in "Weighted random sampling with a reservoir, P. S. Efraimidis
88+
et al., 2006".
7089
"""
7190
struct AlgAExpJ <: ReservoirAlgorithm end
7291

7392
"""
74-
Implements weighted random sampling with replacement.
93+
Implements weighted random reservoir sampling with replacement. To be used with [`ReservoirSample`](@ref)
94+
or [`itsample`](@ref).
7595
76-
Adapted from algorithm WRSWR_SKIP described in "A Skip-based Algorithm for Weighted Reservoir
77-
Sampling with Replacement, A. Meligrana, 2024".
96+
Adapted from algorithm WRSWR-SKIP described in "Weighted Reservoir Sampling with Replacement from Multiple
97+
Data Streams, A. Meligrana, 2024".
7898
"""
7999
struct AlgWRSWRSKIP <: ReservoirAlgorithm end
80100

0 commit comments

Comments
 (0)