Skip to content

Commit 66ad2d5

Browse files
authored
Improve sorted sampling (#103)
1 parent d9ef3f0 commit 66ad2d5

File tree

5 files changed

+38
-36
lines changed

5 files changed

+38
-36
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ jobs:
1515
fail-fast: false
1616
matrix:
1717
version:
18-
- '1.8'
18+
- '1.10'
1919
- '1'
2020
- 'nightly'
2121
os:

src/SamplingInterface.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,8 +157,8 @@ Base.@constprop :aggressive function itsample(rng::AbstractRNG, iter, n::Int, me
157157
s = ReservoirSample(rng, iter_type, n, method, ImmutSample(), ordered ? Ord() : Unord())
158158
return update_all!(s, iter, ordered)
159159
else
160-
replace = method isa AlgL || method isa AlgR ? false : true
161-
sortedindices_sample(rng, iter, n; iter_type, replace, ordered)
160+
replace = method isa AlgL || method isa AlgR ? NoReplace() : Replace()
161+
sortedindices_sample(rng, iter, n, replace; iter_type, ordered)
162162
end
163163
end
164164
function itsample(rng::AbstractRNG, iter, wv::Function, method = AlgWRSWRSKIP(); iter_type = infer_eltype(iter))

src/SamplingUtils.jl

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,26 @@ function infer_eltype(itr)
1313
ifelse(T2 !== Union{} && T2 <: T1, T2, T1)
1414
end
1515

16-
function sortedrandrange(rng, range, n)
17-
s = Vector{Int}(undef, n)
18-
curmax = -log(Float64(range.stop))
19-
for i in n:-1:1
20-
curmax += randexp(rng)/i
21-
@inbounds s[i] = ceil(Int, exp(-curmax))
22-
end
23-
return s
16+
get_sorted_indices(rng, n, N, ::Replace) = SortedRandRangeIter(rng, 1:N, n)
17+
get_sorted_indices(rng, n, N, ::NoReplace) = sort!(sample(rng, 1:N, n; replace=false))
18+
19+
struct SortedRandRangeIter{R}
20+
rng::R
21+
range::UnitRange{Int}
22+
n::Int
2423
end
2524

26-
function get_sorted_indices(rng, n, N, replace)
27-
replace == true && return sortedrandrange(rng, 1:N, n)
28-
return sort!(sample(rng, 1:N, n; replace=replace))
25+
@inline function Base.iterate(s::SortedRandRangeIter)
26+
curmax = -log(Float64(s.range.stop)) + randexp(s.rng)/s.n
27+
return (s.range.stop - ceil(Int, exp(-curmax)) + 1, (s.n-1, curmax))
28+
end
29+
@inline function Base.iterate(s::SortedRandRangeIter, state)
30+
state[1] == 0 && return nothing
31+
curmax = state[2] + randexp(s.rng)/state[1]
32+
return (s.range.stop - ceil(Int, exp(-curmax)) + 1, (state[1]-1, curmax))
2933
end
34+
35+
Base.IteratorEltype(::SortedRandRangeIter) = Base.HasEltype()
36+
Base.eltype(::SortedRandRangeIter) = Int
37+
Base.IteratorSize(::SortedRandRangeIter) = Base.HasLength()
38+
Base.length(s::SortedRandRangeIter) = s.n

src/SortedSamplingMulti.jl

Lines changed: 12 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,39 +7,29 @@ Algorithm which generates sorted random indices used to retrieve the sample
77
from the iterable. The number of elements in the iterable needs to be known
88
before starting the sampling.
99
"""
10-
function sortedindices_sample(rng, iter, n::Int;
11-
iter_type = infer_eltype(iter), replace = false, ordered = false)
10+
function sortedindices_sample(rng, iter, n::Int, replace;
11+
iter_type = infer_eltype(iter), ordered = false)
1212
N = length(iter)
1313
if N <= n
1414
reservoir = collect(iter)
15-
replace && return sample(rng, reservoir, n, ordered=ordered)
15+
replace isa Replace && return sample(rng, reservoir, n, ordered=ordered)
1616
return ordered ? reservoir : shuffle!(rng, reservoir)
1717
end
1818
reservoir = Vector{iter_type}(undef, n)
1919
indices = get_sorted_indices(rng, n, N, replace)
20-
first_idx = indices[1]
21-
el, state = iterate(iter)::Tuple
22-
if first_idx != 1
23-
el, state = skip_ahead_no_end(iter, state, first_idx - 2)
20+
curr_idx, state_idx = iterate(indices)
21+
el, state_el = iterate(iter)::Tuple
22+
for _ in 1:curr_idx-1
23+
el, state_el = iterate(iter, state_el)::Tuple
2424
end
2525
reservoir[1] = el
26-
i = 2
27-
@inbounds while i <= n
28-
skip_k = indices[i] - indices[i-1] - 1
29-
if skip_k >= 0
30-
el, state = skip_ahead_no_end(iter, state, skip_k)
26+
@inbounds for i in 2:n
27+
next_idx, state_idx = iterate(indices, state_idx)
28+
for _ in 1:next_idx-curr_idx
29+
el, state_el = iterate(iter, state_el)::Tuple
3130
end
3231
reservoir[i] = el
33-
i += 1
32+
curr_idx = next_idx
3433
end
3534
return ordered ? reservoir : shuffle!(rng, reservoir)
3635
end
37-
38-
function skip_ahead_no_end(iter, state, n)
39-
for _ in 1:n
40-
it = iterate(iter, state)::Tuple
41-
state = it[2]
42-
end
43-
it = iterate(iter, state)::Tuple
44-
return it
45-
end

src/StreamSampling.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ struct MutSample end
2525
struct Ord end
2626
struct Unord end
2727

28+
struct Replace end
29+
struct NoReplace end
30+
2831
abstract type AbstractReservoirSample <: OnlineStat{Any} end
2932

3033
abstract type ReservoirAlgorithm end

0 commit comments

Comments
 (0)