|
| 1 | + |
1 | 2 | using StreamSampling, StatsBase |
2 | 3 | using Random, Printf, BenchmarkTools |
3 | | -using CairoMakie |
| 4 | + |
"""
    samplesum(rng, stream, n, replace)

Materialize `stream` into a vector, draw `n` unweighted elements from it with
`StatsBase.sample` (with or without replacement according to `replace`), and
return the sum of the drawn elements.
"""
function samplesum(rng, stream, n, replace)
    population = collect(stream)
    drawn = sample(rng, population, n; replace=replace)
    return sum(drawn)
end
"""
    samplesum(rng, stream, wf, n, replace)

Weighted counterpart of the four-argument method: materialize `stream`, compute
one weight per element with `wf`, draw `n` elements with `StatsBase.sample`
(with or without replacement according to `replace`), and return their sum.
"""
function samplesum(rng, stream, wf, n, replace)
    population = collect(stream)
    wts = Weights(map(wf, population))
    drawn = sample(rng, population, wts, n; replace=replace)
    return sum(drawn)
end
| 14 | + |
"""
    rsvsamplesum(rng, stream, wf, n, alg)

Feed every element of `stream` into a `ReservoirSampler{Int}` of size `n` built
with algorithm `alg`, then return the sum of the sampled values.

For `AlgL()` and `AlgRSWRSKIP()` elements are fitted without a weight and `wf`
is not called; for any other `alg` each element `el` is fitted with weight
`wf(el)`.
"""
function rsvsamplesum(rng, stream, wf, n, alg)
    sampler = ReservoirSampler{Int}(rng, n, alg; mutable=false)
    unweighted = alg in (AlgL(), AlgRSWRSKIP())
    # The sampler is created with `mutable=false`, so the value returned by
    # `fit!` is rebound on every step.
    if unweighted
        for el in stream
            sampler = fit!(sampler, el)
        end
    else
        for el in stream
            sampler = fit!(sampler, el, wf(el))
        end
    end
    return sum(value(sampler))
end
| 28 | + |
"""
    strsamplesum(rng, stream, wf, n, alg, W=nothing)

Build a `StreamSampler{Int}` over `stream` for a sample of size `n` using
algorithm `alg`, and return the sum of the elements it yields.

# Arguments
- `rng`: random number generator handed to the sampler.
- `stream`: iterator to sample from; it is traversed an extra time when `W` is
  not supplied.
- `wf`: weight function. It is used to compute the total weight when `W` is
  `nothing`, and it is passed to the sampler for every `alg` other than
  `AlgD()` and `AlgORDSWR()`.
- `n`: sample size.
- `alg`: sampling algorithm.
- `W`: total weight of the stream. When `nothing` (the default) it is computed
  as the sum of `wf(x)` over `stream`.
"""
function strsamplesum(rng, stream, wf, n, alg, W=nothing)
    total = isnothing(W) ? sum(wf, stream) : W
    sampler = if alg in (AlgD(), AlgORDSWR())
        StreamSampler{Int}(rng, stream, n, total, alg)
    else
        # Pass the caller's weight function; the previous code read the global
        # `w` here and silently ignored `wf`.
        StreamSampler{Int}(rng, stream, wf, n, total, alg)
    end
    return sum(sampler)
end
4 | 38 |
|
# Seeded generator shared by all benchmarks below.
rng = Xoshiro(42);

# The predicate never rejects an element of 1:10^8, so the stream yields all
# 10^8 integers; presumably the filter wrapper is kept so the samplers receive
# a generic lazy iterator instead of a range.
stream = Iterators.filter(!iszero, 1:10^8);

# Total weight of the stream: 10^8 elements, each given weight 1 by `w`/`w2`.
W = 10^8
w(_) = 1.0;  # unit weight as a Float64
w2(_) = 1;   # unit weight as an Int

# Entry k of each tuple is benchmarked into column k of the result matrices;
# `nothing` in `algstr` marks a column that the stream benchmark skips.
const algrsv = (AlgL(), AlgRSWRSKIP(), AlgAExpJ(), AlgWRSWRSKIP())
const algstr = (AlgD(), AlgORDSWR(), nothing, AlgORDWSWR())

# Sample sizes 10^4, 10^5, 10^6, 10^7.
sizes = ntuple(k -> 10^(k + 3), 4)
15 | 48 |
|
# 4×4 grids of initially empty result vectors. Each cell gets its own
# `Float64[]`, which the benchmark loops extend with one entry per sample size.
m_times = [Float64[] for _ in 1:4, _ in 1:4];
m_mems = [Float64[] for _ in 1:4, _ in 1:4];
21 | 53 |
|
# Row 1: collect-then-sample baseline (`samplesum`). Columns follow the order
# of `combos`: (weighted, replace) = (false, false), (false, true),
# (true, false), (true, true).
for n in sizes
    combos = ((false, false), (false, true), (true, false), (true, true))
    for (col, (weighted, replace)) in enumerate(combos)
        b1 = if weighted
            @benchmark samplesum($rng, $stream, $w, $n, $replace) seconds=20
        else
            @benchmark samplesum($rng, $stream, $n, $replace) seconds=20
        end
        # Scale by 1e-6: times are plotted as "time (ms)", memory as "memory (Mb)".
        push!(m_times[1, col], median(b1.times) * 1e-6)
        push!(m_mems[1, col], b1.memory * 1e-6)
    end
end
# Row 2: reservoir samplers (`rsvsamplesum`), one column per entry of `algrsv`.
for n in sizes
    for (col, alg) in enumerate(algrsv)
        b2 = @benchmark rsvsamplesum($rng, $stream, $w, $n, $alg) seconds=20
        push!(m_times[2, col], median(b2.times) * 1e-6)
        push!(m_mems[2, col], b2.memory * 1e-6)
    end
end
# Rows 3 and 4: stream samplers (`strsamplesum`), one column per entry of
# `algstr`. Row 3 lets the helper compute the total weight itself; row 4 passes
# the precomputed total `W`.
for n in sizes
    for (col, alg) in enumerate(algstr)
        # `nothing` marks a column with no stream algorithm; its result vectors
        # stay empty and the plotting code skips them.
        alg === nothing && continue
        if alg in (AlgD(), AlgORDSWR())
            # These two algorithms are run with the Int weight function and
            # the Int total.
            b3 = @benchmark strsamplesum($rng, $stream, $w2, $n, $alg) seconds=20
            b4 = @benchmark strsamplesum($rng, $stream, $w2, $n, $alg, $W) seconds=20
        else
            # The remaining algorithm is run with the Float64 weight function
            # and a Float64 total.
            b3 = @benchmark strsamplesum($rng, $stream, $w, $n, $alg) seconds=20
            b4 = @benchmark strsamplesum($rng, $stream, $w, $n, $alg, $(Float64(W))) seconds=20
        end
        push!(m_times[3, col], median(b3.times) * 1e-6)
        push!(m_mems[3, col], b3.memory * 1e-6)
        push!(m_times[4, col], median(b4.times) * 1e-6)
        push!(m_mems[4, col], b4.memory * 1e-6)
    end
end
45 | 96 |
|
| 97 | +using CairoMakie |
| 98 | + |
# Figure with a 4×2 grid of log-log axes stored row by row:
# axs[1:2] is grid row 1, axs[3:4] is row 2, and so on.
f = Figure(; fontsize = 9);
axs = [
    Axis(f[row, col]; yscale = log10, xscale = log10, xgridstyle = :dot,
        ygridstyle = :dot)
    for row in 1:4 for col in 1:2
];

# One label and one marker per row of the result matrices.
labels = ("population", "reservoir", "stream", "stream - one pass")

markers = (:circle, :rect, :utriangle, :xcross)
# Fill the eight panels. Panels 1, 2, 5, 6 show times and panels 3, 4, 7, 8
# show memory. Panels 1-4 use result columns 1-2, panels 5-8 use columns 3-4.
for j in 1:8
    # Decide the panel kind from its index. Comparing `m == m_mems` would
    # compare the matrices by value and misclassify panels whenever the two
    # matrices happen to be equal (for example, when both are still empty).
    ismem = j in (3, 4, 7, 8)
    m = ismem ? m_mems : m_times
    # Result column for this panel, computed directly from `j`: global
    # counters updated inside a top-level loop fail when the file is run
    # non-interactively.
    s = 2 * ((j - 1) ÷ 4) + (isodd(j) ? 1 : 2)
    for i in 1:4
        ys = m[i, s]
        # Skip series without one measurement per sample size (e.g. the
        # column that has no stream algorithm).
        length(ys) != 4 && continue
        scatterlines!(axs[j], [0.01, 0.1, 1, 10], ys; label = labels[i],
            marker = markers[i], linestyle = (:dash, :dense))
    end
    # Only the left-hand column of panels carries a y label.
    if j in (1, 3, 5, 7)
        axs[j].ylabel = ismem ? "memory (Mb)" : "time (ms)"
    end
    axs[j].xtickformat = x -> string.(x) .* "%"
    j in (7, 8) && (axs[j].xlabel = "sample size")
    pr = j in (1, 2) ? "un" : ""
    t = j in (1, 5) ? "out" : ""
    j in (1, 2, 5, 6) && (axs[j].title = pr * "weighted with" * t * " replacement")
    axs[j].titlegap = 8.0
    j in (1, 2, 5, 6) && hidexdecorations!(axs[j], grid = false)
end
74 | 128 |
|
# Same logarithmic tick layout on every axis.
for ax in axs
    ax.yticks = LogTicks(WilkinsonTicks(4; k_min = 4, k_max = 6))
end

# Share the y range among the time panels and among the memory panels.
linkyaxes!(axs[[1, 2, 5, 6]]...)
linkyaxes!(axs[[3, 4, 7, 8]]...)

# Right-hand column: hide y tick labels.
for k in 2:2:8
    axs[k].yticklabelsvisible = false
end
# Second grid row: hide x tick labels.
for k in 3:4
    axs[k].xticklabelsvisible = false
end
| 142 | + |
| 143 | + |
# Horizontal legend below the panels, built from the series of the first axis.
f[5, 1] = Legend(f, axs[1]; framevisible = false, orientation = :horizontal,
    halign = :center, padding = (248, 0, 0, 0))

# Figure title in the row above the panels.
Label(f[0, :], "Performance of Sampling Algorithms on Iterators"; fontsize = 13,
    font = :bold)

f

save("comparison_stream_algs.pdf", f)
| 153 | + |
0 commit comments