Skip to content

Commit 2c480a7

Browse files
committed
Improve tests
1 parent ac67239 commit 2c480a7

8 files changed

+93
-32
lines changed

src/UnweightedSamplingMulti.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ function OnlineStatsBase.value(s::Union{SampleMultiAlgR, SampleMultiAlgL})
226226
end
227227
function OnlineStatsBase.value(s::SampleMultiAlgRSWRSKIP)
228228
if nobs(s) < length(s.value)
229-
return sample(s.rng, s.value[1:nobs(s)], length(s.value))
229+
return nobs(s) == 0 ? s.value[1:0] : sample(s.rng, s.value[1:nobs(s)], length(s.value))
230230
else
231231
return s.value
232232
end

src/WeightedSamplingMulti.jl

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ const OrdWeighted = BinaryHeap{Tuple{T, Int64, Float64}, Base.Order.By{typeof(la
55
seen_k::Int
66
n::Int
77
const rng::R
8-
const value::BH
8+
value::BH
99
end
1010
const SampleMultiOrdAlgARes = Union{SampleMultiAlgARes_Immut{<:OrdWeighted}, SampleMultiAlgARes_Mut{<:OrdWeighted}}
1111

@@ -15,7 +15,7 @@ const SampleMultiOrdAlgARes = Union{SampleMultiAlgARes_Immut{<:OrdWeighted}, Sam
1515
seen_k::Int
1616
const n::Int
1717
const rng::R
18-
const value::BH
18+
value::BH
1919
end
2020
const SampleMultiOrdAlgAExpJ = Union{SampleMultiAlgAExpJ_Immut{<:OrdWeighted}, SampleMultiAlgAExpJ_Mut{<:OrdWeighted}}
2121

@@ -157,15 +157,23 @@ end
157157

158158
function Base.empty!(s::SampleMultiAlgARes_Mut)
159159
s.seen_k = 0
160-
empty!(s.value)
160+
if s isa SampleMultiAlgWRSWRSKIP_Mut{<:Vector}
161+
s.value = BinaryHeap(Base.By(last, DataStructures.FasterForward()), extract_T(s.value)[])
162+
else
163+
s.value = BinaryHeap(Base.By(last, DataStructures.FasterForward()), extract_T(s.value)[])
164+
end
161165
sizehint!(s.value, s.n)
162166
return s
163167
end
164168
function Base.empty!(s::SampleMultiAlgAExpJ_Mut)
165169
s.state = 0.0
166170
s.min_priority = 0.0
167171
s.seen_k = 0
168-
empty!(s.value)
172+
if s isa SampleMultiAlgWRSWRSKIP_Mut{<:Vector}
173+
s.value = BinaryHeap(Base.By(last, DataStructures.FasterForward()), extract_T(s.value)[])
174+
else
175+
s.value = BinaryHeap(Base.By(last, DataStructures.FasterForward()), extract_T(s.value)[])
176+
end
169177
sizehint!(s.value, s.n)
170178
return s
171179
end
@@ -176,6 +184,8 @@ function Base.empty!(s::SampleMultiAlgWRSWRSKIP_Mut)
176184
return s
177185
end
178186

187+
extract_T(::DataStructures.BinaryHeap{T}) where T = T
188+
179189
function Base.merge(ss::SampleMultiAlgWRSWRSKIP...)
180190
newvalue = reduce_samples(TypeUnion(), ss...)
181191
skip_w = sum(getfield(s, :skip_w) for s in ss)
@@ -256,7 +266,7 @@ function OnlineStatsBase.value(s::Union{SampleMultiAlgARes, SampleMultiAlgAExpJ}
256266
end
257267
function OnlineStatsBase.value(s::SampleMultiAlgWRSWRSKIP)
258268
if nobs(s) < length(s.value)
259-
return sample(s.rng, s.value[1:nobs(s)], weights(s.weights[1:nobs(s)]), length(s.value))
269+
return nobs(s) == 0 ? s.value[1:0] : sample(s.rng, s.value[1:nobs(s)], weights(s.weights[1:nobs(s)]), length(s.value))
260270
else
261271
return s.value
262272
end

test/benchmark_tests.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,23 @@
33
iter_no_f = (x for x in 1:10^2)
44
iter = Iterators.filter(x -> x != 10, 1:10^2)
55
wv(el) = 1.0
6-
for m in (:(AlgS()), AlgR(), AlgL(), AlgRSWRSKIP())
7-
for size in (nothing, 10)
6+
println("\nUnweighted Methods\n")
7+
for size in (nothing, 10)
8+
for m in (AlgORDSWR(), AlgD(), AlgR(), AlgL(), AlgRSWRSKIP())
9+
size == nothing && m === AlgD() && continue
810
size == nothing && m === AlgL() && continue
911
size == nothing && m === AlgR() && continue
1012
s = size == nothing ? () : (size,)
11-
b = @benchmark itsample($rng, $(m == :(AlgS()) ? iter_no_f : iter), $s..., $m) evals=1
13+
b = @benchmark itsample($rng, $(m == AlgD() || m == AlgORDSWR() ? iter_no_f : iter), $s..., $m) evals=1
1214
mstr = "$m $(size == nothing ? :single : :multi)"
1315
print(mstr * repeat(" ", 35-length(mstr)))
1416
print(" --> Time: $(median(b.times)) ns |")
1517
println(" Memory: $(b.memory) bytes")
1618
end
1719
end
18-
for m in (AlgARes(), AlgAExpJ(), AlgWRSWRSKIP())
19-
for size in (nothing, 10)
20+
println("\nWeighted Methods\n")
21+
for size in (nothing, 10)
22+
for m in (AlgARes(), AlgAExpJ(), AlgWRSWRSKIP())
2023
size == nothing && m === AlgARes() && continue
2124
size == nothing && m === AlgAExpJ() && continue
2225
s = size == nothing ? () : (size,)
@@ -27,4 +30,5 @@
2730
println(" Memory: $(b.memory) bytes")
2831
end
2932
end
33+
println()
3034
end

test/empty_tests.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
2+
@testset "empty! tests" begin
3+
rng = StableRNG(43)
4+
rs = ReservoirSample{Int}(AlgRSWRSKIP())
5+
fit!(rs, 1)
6+
empty!(rs)
7+
@test value(rs) === nothing
8+
rs = ReservoirSample{Int}(AlgWRSWRSKIP())
9+
fit!(rs, 1, 1.0)
10+
empty!(rs)
11+
@test value(rs) === nothing
12+
for m in (AlgR(), AlgL(), AlgRSWRSKIP())
13+
rs = ReservoirSample{Int}(1, m)
14+
fit!(rs, 1)
15+
empty!(rs)
16+
@test value(rs) == Int64[]
17+
end
18+
for m in (AlgARes(), AlgAExpJ(), AlgWRSWRSKIP())
19+
rs = ReservoirSample{Int}(1, m)
20+
fit!(rs, 1, 1.0)
21+
empty!(rs)
22+
@test value(rs) == Int64[]
23+
end
24+
end

test/merge_tests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11

2-
@testset "merge tests" begin
2+
@testset "merge/merge! tests" begin
33
rng = StableRNG(43)
44
iters = (1:2, 3:10)
55
reps = 10^5

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,6 @@ using StreamSampling
1616
include("weighted_sampling_single_tests.jl")
1717
include("weighted_sampling_multi_tests.jl")
1818
include("merge_tests.jl")
19+
include("empty_tests.jl")
1920
include("benchmark_tests.jl")
2021
end

test/unweighted_sampling_multi_tests.jl

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -49,27 +49,37 @@
4949
@test all(x -> a <= x <= b, value(rs))
5050
@test nobs(rs) == 10
5151

52-
rng = StableRNG(42)
53-
iters = (a:b, Iterators.filter(x -> x != b + 1, a:b+1))
52+
rngs = (StableRNG(42), StableRNG(43))
53+
iters = (a:b, Iterators.filter(x -> x != b + 1, a:b+1), (a:floor(Int, b/2), (floor(Int, b/2)+1):b))
5454
sizes = (2, 3)
5555
for it in iters
5656
for size in sizes
5757
reps = 10^(size+2)
5858
dict_res = Dict{Vector, Int}()
5959
for _ in 1:reps
60-
s = shuffle!(rng, itsample(rng, it, size, method; ordered=ordered))
60+
if typeof(it) <: Tuple
61+
if method == AlgRSWRSKIP() && ordered == false
62+
s = shuffle!(rngs[1], itsample(rngs, it, size))
63+
else
64+
break
65+
end
66+
else
67+
s = shuffle!(rngs[1], itsample(rngs[1], it, size, method; ordered=ordered))
68+
end
6169
if s in keys(dict_res)
6270
dict_res[s] += 1
6371
else
6472
dict_res[s] = 1
6573
end
6674
end
67-
cases = method == AlgRSWRSKIP() ? 10^size : factorial(10)/factorial(10-size)
68-
ps_exact = [1/cases for _ in 1:cases]
69-
count_est = collect(values(dict_res))
75+
if !(typeof(it) <: Tuple) || (method == AlgRSWRSKIP() && ordered == false)
76+
cases = method == AlgRSWRSKIP() ? 10^size : factorial(10)/factorial(10-size)
77+
ps_exact = [1/cases for _ in 1:cases]
78+
count_est = collect(values(dict_res))
7079

71-
chisq_test = ChisqTest(count_est, ps_exact)
72-
@test pvalue(chisq_test) > 0.05
80+
chisq_test = ChisqTest(count_est, ps_exact)
81+
@test pvalue(chisq_test) > 0.05
82+
end
7383
end
7484
end
7585
end

test/weighted_sampling_multi_tests.jl

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -76,31 +76,43 @@ end
7676
@test nobs(rs) == 10
7777

7878
weight2(el) = el <= 5 ? 1.0 : 2.0
79-
rng = StableRNG(41)
80-
iters = (a:b, Iterators.filter(x -> x != b+1, a:b+1))
79+
weight3(el) = el <= 5 ? 1.0 : 2.0
80+
wfuncs = (weight2, weight3)
81+
rngs = (StableRNG(41), StableRNG(42))
82+
iters = (a:b, Iterators.filter(x -> x != b+1, a:b+1), (a:floor(Int, b/2), (floor(Int, b/2)+1):b))
8183
sizes = (1, 2)
8284
for it in iters
8385
for size in sizes
8486
reps = 10^(size+3)
8587
dict_res = Dict{Vector, Int}()
8688
for _ in 1:reps
87-
s = shuffle!(rng, itsample(rng, it, weight2, size, method; ordered=ordered))
89+
if typeof(it) <: Tuple
90+
if method == AlgWRSWRSKIP() && ordered == false
91+
s = shuffle!(rngs[1], itsample(rngs, it, wfuncs, size))
92+
else
93+
break
94+
end
95+
else
96+
s = shuffle!(rngs[1], itsample(rngs[1], it, wfuncs[1], size, method; ordered=ordered))
97+
end
8898
if s in keys(dict_res)
8999
dict_res[s] += 1
90100
else
91101
dict_res[s] = 1
92102
end
93103
end
94-
cases = method == AlgWRSWRSKIP() ? 10^size : factorial(10)/factorial(10-size)
95-
pairs_dict = collect(pairs(dict_res))
96-
if method == AlgWRSWRSKIP()
97-
ps_exact = [prob_replace(k) for (k, v) in pairs_dict]
98-
else
99-
ps_exact = [prob_no_replace(k) for (k, v) in pairs_dict if length(unique(k)) == size]
104+
if !(typeof(it) <: Tuple) || (method == AlgWRSWRSKIP() && ordered == false)
105+
cases = method == AlgWRSWRSKIP() ? 10^size : factorial(10)/factorial(10-size)
106+
pairs_dict = collect(pairs(dict_res))
107+
if method == AlgWRSWRSKIP()
108+
ps_exact = [prob_replace(k) for (k, v) in pairs_dict]
109+
else
110+
ps_exact = [prob_no_replace(k) for (k, v) in pairs_dict if length(unique(k)) == size]
111+
end
112+
count_est = [v for (k, v) in pairs_dict]
113+
chisq_test = ChisqTest(count_est, ps_exact)
114+
@test pvalue(chisq_test) > 0.05
100115
end
101-
count_est = [v for (k, v) in pairs_dict]
102-
chisq_test = ChisqTest(count_est, ps_exact)
103-
@test pvalue(chisq_test) > 0.05
104116
end
105117
end
106118
end

0 commit comments

Comments
 (0)