Skip to content

Commit a4e7efc

Browse files
authored
Update benchmark code
1 parent f075f97 commit a4e7efc

File tree

1 file changed

+115
-44
lines changed

1 file changed

+115
-44
lines changed
Lines changed: 115 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,82 +1,153 @@
1+
12
using StreamSampling, StatsBase
23
using Random, Printf, BenchmarkTools
3-
using CairoMakie
4+
5+
function samplesum(rng, stream, n, replace)
6+
pop = collect(stream)
7+
return sum(sample(rng, pop, n; replace))
8+
end
9+
function samplesum(rng, stream, wf, n, replace)
10+
pop = collect(stream)
11+
weights = wf.(pop)
12+
return sum(sample(rng, pop, Weights(weights), n; replace))
13+
end
14+
15+
function rsvsamplesum(rng, stream, wf, n, alg)
16+
rs = ReservoirSampler{Int}(rng, n, alg; mutable=false)
17+
if alg in (AlgL(), AlgRSWRSKIP())
18+
for i in stream
19+
rs = fit!(rs, i)
20+
end
21+
else
22+
for i in stream
23+
rs = fit!(rs, i, wf(i))
24+
end
25+
end
26+
return sum(value(rs))
27+
end
28+
29+
function strsamplesum(rng, stream, wf, n, alg, W=nothing)
30+
W == nothing && (W = sum(wf(x) for x in stream))
31+
st = if alg in (AlgD(), AlgORDSWR())
32+
StreamSampler{Int}(rng, stream, n, W, alg)
33+
else
34+
StreamSampler{Int}(rng, stream, w, n, W, alg)
35+
end
36+
return sum(st)
37+
end
438

539
rng = Xoshiro(42);
6-
stream = Iterators.filter(x -> x != 1, 1:10^8);
7-
pop = collect(stream);
8-
w(el) = Float64(el);
9-
weights = Weights(w.(stream));
10-
11-
algs = (AlgL(), AlgRSWRSKIP(), AlgAExpJ(), AlgWRSWRSKIP());
12-
algsweighted = (AlgAExpJ(), AlgWRSWRSKIP());
13-
algsreplace = (AlgRSWRSKIP(), AlgWRSWRSKIP());
40+
stream = Iterators.filter(x -> x != 0, 1:10^8);
41+
W = 10^8
42+
w(el) = 1.0;
43+
w2(el) = 1;
44+
45+
const algrsv = (AlgL(), AlgRSWRSKIP(), AlgAExpJ(), AlgWRSWRSKIP())
46+
const algstr = (AlgD(), AlgORDSWR(), nothing, AlgORDWSWR())
1447
sizes = (10^4, 10^5, 10^6, 10^7)
1548

16-
p = Dict((0, 0) => 1, (0, 1) => 2, (1, 0) => 3, (1, 1) => 4);
17-
m_times = Matrix{Vector{Float64}}(undef, (3, 4));
49+
m_times = Matrix{Vector{Float64}}(undef, (4, 4));
1850
for i in eachindex(m_times) m_times[i] = Float64[] end
19-
m_mems = Matrix{Vector{Float64}}(undef, (3, 4));
51+
m_mems = Matrix{Vector{Float64}}(undef, (4, 4));
2052
for i in eachindex(m_mems) m_mems[i] = Float64[] end
2153

22-
for m in algs
23-
for size in sizes
24-
replace = m in algsreplace
25-
weighted = m in algsweighted
26-
if weighted
27-
b1 = @benchmark itsample($rng, $stream, $w, $size, $m) seconds=20
28-
b2 = @benchmark sample($rng, collect($stream), Weights($w.($stream)), $size; replace = $replace) seconds=20
29-
b3 = @benchmark sample($rng, $pop, $weights, $size; replace = $replace) seconds=20
30-
else
31-
b1 = @benchmark itsample($rng, $stream, $size, $m) evals=1 seconds=20
32-
b2 = @benchmark sample($rng, collect($stream), $size; replace = $replace) seconds=20
33-
b3 = @benchmark sample($rng, $pop, $size; replace = $replace) seconds=20
54+
for size in sizes
55+
i = 0
56+
for weighted in (false, true)
57+
for replace in (false, true)
58+
if weighted
59+
b1 = @benchmark samplesum($rng, $stream, $w, $size, $replace) seconds=20
60+
else
61+
b1 = @benchmark samplesum($rng, $stream, $size, $replace) seconds=20
62+
end
63+
i += 1
64+
push!(m_times[1, i], median(b1.times) * 1e-6)
65+
push!(m_mems[1, i], b1.memory * 1e-6)
3466
end
35-
ts = [median(b1.times), median(b2.times), median(b3.times)] .* 1e-6
36-
ms = [b1.memory, b2.memory, b3.memory] .* 1e-6
37-
c = p[(weighted, replace)]
38-
for r in 1:3
39-
push!(m_times[r, c], ts[r])
40-
push!(m_mems[r, c], ms[r])
67+
end
68+
end
69+
for n in sizes
70+
i = 0
71+
for alg in algrsv
72+
b2 = @benchmark rsvsamplesum($rng, $stream, $w, $n, $alg) seconds=20
73+
i += 1
74+
push!(m_times[2, i], median(b2.times) * 1e-6)
75+
push!(m_mems[2, i], b2.memory * 1e-6)
76+
end
77+
end
78+
for n in sizes
79+
i = 0
80+
for alg in algstr
81+
i += 1
82+
alg == nothing && continue
83+
if alg in (AlgD(), AlgORDSWR())
84+
b3 = @benchmark strsamplesum($rng, $stream, $w2, $n, $alg) seconds=20
85+
b4 = @benchmark strsamplesum($rng, $stream, $w2, $n, $alg, $W) seconds=20
86+
else
87+
b3 = @benchmark strsamplesum($rng, $stream, $w, $n, $alg) seconds=20
88+
b4 = @benchmark strsamplesum($rng, $stream, $w, $n, $alg, $(Float64(W))) seconds=20
4189
end
42-
println("c")
90+
push!(m_times[3, i], median(b3.times) * 1e-6)
91+
push!(m_mems[3, i], b3.memory * 1e-6)
92+
push!(m_times[4, i], median(b4.times) * 1e-6)
93+
push!(m_mems[4, i], b4.memory * 1e-6)
4394
end
4495
end
4596

97+
using CairoMakie
98+
4699
f = Figure(fontsize = 9,);
47-
axs = [Axis(f[i, j], yscale = log10, xscale = log10) for i in 1:4 for j in 1:2];
100+
axs = [Axis(f[i, j], yscale = log10, xscale = log10, xgridstyle = :dot,
101+
ygridstyle = :dot) for i in 1:4 for j in 1:2];
48102

49-
labels = (
50-
"stream-based\n(StreamSampling.itsample)",
51-
"collection-based with setup\n(StatsBase.sample)",
52-
"collection-based\n(StatsBase.sample)"
53-
)
103+
labels = ("population", "reservoir", "stream", "stream - one pass" )
54104

55-
markers = (:circle, :rect, :utriangle)
105+
markers = (:circle, :rect, :utriangle, :xcross)
56106
a, b = 0, 0
57107

58108
for j in 1:8
59109
m = j in (3, 4, 7, 8) ? m_mems : m_times
60110
m == m_mems ? (a += 1) : (b += 1)
61111
s = m == m_mems ? a : b
62-
for i in 1:3
63-
scatterlines!(axs[j], [0.01, 0.1, 1, 10], m[i, s]; label = labels[i], marker = markers[i])
112+
for i in 1:4
113+
length(m[i, s]) != 4 && continue
114+
t = deepcopy(m[i, s])
115+
scatterlines!(axs[j], [0.01, 0.1, 1, 10], t; label = labels[i], marker = markers[i], linestyle=(:dash, :dense))
116+
end
117+
if j in (1,3,5,7)
118+
axs[j].ylabel = m == m_mems ? "memory (Mb)" : "time (ms)"
64119
end
65-
axs[j].ylabel = m == m_mems ? "memory (Mb)" : "time (ms)"
66120
axs[j].xtickformat = x -> string.(x) .* "%"
67-
j in (3, 4, 7, 8) && (axs[j].xlabel = "sample size")
121+
j in (7, 8) && (axs[j].xlabel = "sample size")
68122
pr = j in (1, 2) ? "un" : ""
69123
t = j in (1, 5) ? "out" : ""
70124
j in (1, 2, 5, 6) && (axs[j].title = pr * "weighted with" * t * " replacement")
71125
axs[j].titlegap = 8.0
72126
j in (1, 2, 5, 6) && hidexdecorations!(axs[j], grid = false)
73127
end
74128

129+
for i in 1:8
130+
axs[i].yticks = LogTicks(WilkinsonTicks(4, k_min=4, k_max=6))
131+
end
132+
133+
linkyaxes!((axs[i] for i in [1,2,5,6])...)
134+
linkyaxes!((axs[i] for i in [3,4,7,8])...)
135+
136+
for i in [2,4,6,8]
137+
axs[i].yticklabelsvisible = false
138+
end
139+
for i in [3,4]
140+
axs[i].xticklabelsvisible = false
141+
end
142+
143+
75144
f[5, 1] = Legend(f, axs[1], framevisible = false, orientation = :horizontal,
76145
halign = :center, padding=(248,0,0,0))
77146

78-
Label(f[0, :], "Comparison between stream-based and collection-based algorithms", fontsize = 13,
147+
Label(f[0, :], "Performance of Sampling Algorithms on Iterators", fontsize = 13,
79148
font=:bold)
80149

81-
save("comparison_stream_algs.png", f)
82150
f
151+
152+
save("comparison_stream_algs.pdf", f)
153+

0 commit comments

Comments
 (0)