Skip to content

Commit 30605ae

Browse files
authored
Benchmark improvements (#109)
1 parent 9898f87 commit 30605ae

10 files changed

+156
-95
lines changed

benchmark/benchmark_comparison_non_stream_WWR.jl

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -125,43 +125,43 @@ end
125125
rng = Xoshiro(42);
126126
rngs = Tuple(Xoshiro(rand(rng, 1:10000)) for _ in 1:Threads.nthreads());
127127

128-
a = collect(1:10^7);
128+
a = collect(1:10^8);
129129
wsa = Float64.(a);
130130

131131
times_other_parallel = Float64[]
132-
for i in 0:6
133-
b = @benchmark sample_parallel_2_pass($rngs, $a, $wsa, 10^$i)
132+
for i in 0:7
133+
b = @benchmark sample_parallel_2_pass($rngs, $a, $wsa, 10^$i) seconds=20
134134
push!(times_other_parallel, median(b.times)/10^6)
135135
println("other $(10^i): $(median(b.times)/10^6) ms")
136136
end
137137

138138
times_other = Float64[]
139-
for i in 0:6
140-
b = @benchmark sample($rng, $a, Weights($wsa), 10^$i; replace = true)
139+
for i in 0:7
140+
b = @benchmark sample($rng, $a, Weights($wsa), 10^$i; replace = true) seconds=20
141141
push!(times_other, median(b.times)/10^6)
142142
println("other $(10^i): $(median(b.times)/10^6) ms")
143143
end
144144

145145
## single thread
146146
times_single_thread = Float64[]
147-
for i in 0:6
148-
b = @benchmark weighted_reservoir_sample($rng, $a, $wsa, 10^$i)
147+
for i in 0:7
148+
b = @benchmark weighted_reservoir_sample($rng, $a, $wsa, 10^$i) seconds=20
149149
push!(times_single_thread, median(b.times)/10^6)
150150
println("sequential $(10^i): $(median(b.times)/10^6) ms")
151151
end
152152

153153
# multi thread 1 pass - 6 threads
154154
times_multi_thread = Float64[]
155-
for i in 0:6
156-
b = @benchmark weighted_reservoir_sample_parallel_1_pass($rngs, $a, $wsa, 10^$i)
155+
for i in 0:7
156+
b = @benchmark weighted_reservoir_sample_parallel_1_pass($rngs, $a, $wsa, 10^$i) seconds=20
157157
push!(times_multi_thread, median(b.times)/10^6)
158158
println("parallel $(10^i): $(median(b.times)/10^6) ms")
159159
end
160160

161161
# multi thread 2 pass - 6 threads
162162
times_multi_thread_2 = Float64[]
163-
for i in 0:6
164-
b = @benchmark weighted_reservoir_sample_parallel_2_pass($rngs, $a, $wsa, 10^$i)
163+
for i in 0:7
164+
b = @benchmark weighted_reservoir_sample_parallel_2_pass($rngs, $a, $wsa, 10^$i) seconds=20
165165
push!(times_multi_thread_2, median(b.times)/10^6)
166166
println("parallel $(10^i): $(median(b.times)/10^6) ms")
167167
end
@@ -170,13 +170,13 @@ py"""
170170
import numpy as np
171171
import timeit
172172
173-
a = np.arange(1, 10**7+1, dtype=np.int64);
174-
wsa = np.arange(1, 10**7+1, dtype=np.float64)
173+
a = np.arange(1, 10**8+1, dtype=np.int64);
174+
wsa = np.arange(1, 10**8+1, dtype=np.float64)
175175
p = wsa/np.sum(wsa);
176176
177177
def sample_times_numpy():
178178
times_numpy = []
179-
for i in range(7):
179+
for i in range(8):
180180
ts = []
181181
for j in range(11):
182182
t = timeit.timeit("np.random.choice(a, size=10**i, replace=True, p=p)",
@@ -196,20 +196,20 @@ ax1 = Axis(f[1, 1], yscale=log10, xscale=log10,
196196
yminorticksvisible = true, yminorgridvisible = true,
197197
yminorticks = IntervalsBetween(10))
198198

199-
scatterlines!(ax1, [10^i/10^7 for i in 1:6], times_numpy[2:end], label = "numpy.choice sequential", marker = :circle, markersize = 12, linestyle = :dot)
200-
scatterlines!(ax1, [10^i/10^7 for i in 1:6], times_other[2:end], label = "StatsBase.sample sequential", marker = :rect, markersize = 12, linestyle = :dot)
201-
scatterlines!(ax1, [10^i/10^7 for i in 1:6], times_other_parallel[2:end], label = "StatsBase.sample parallel (2 passes)", marker = :diamond, markersize = 12, linestyle = :dot)
202-
scatterlines!(ax1, [10^i/10^7 for i in 1:6], times_single_thread[2:end], label = "WRSWR-SKIP sequential", marker = :hexagon, markersize = 12, linestyle = :dot)
203-
scatterlines!(ax1, [10^i/10^7 for i in 1:6], times_multi_thread[2:end], label = "WRSWR-SKIP parallel (1 pass)", marker = :cross, markersize = 12, linestyle = :dot)
204-
scatterlines!(ax1, [10^i/10^7 for i in 1:6], times_multi_thread_2[2:end], label = "WRSWR-SKIP parallel (2 passes)", marker = :xcross, markersize = 12, linestyle = :dot)
205-
Legend(f[1,2], ax1, labelsize=10, framevisible = false)
199+
scatterlines!(ax1, [10^i/10^8 for i in 2:7], times_numpy[3:end], label = "numpy.choice sequential", marker = :circle, markersize = 12, linestyle = :dot)
200+
scatterlines!(ax1, [10^i/10^8 for i in 2:7], times_other[3:end], label = "StatsBase.sample sequential", marker = :rect, markersize = 12, linestyle = :dot)
201+
scatterlines!(ax1, [10^i/10^8 for i in 2:7], times_other_parallel[3:end], label = "StatsBase.sample parallel (2 passes)", marker = :diamond, markersize = 12, linestyle = :dot)
202+
scatterlines!(ax1, [10^i/10^8 for i in 2:7], times_single_thread[3:end], label = "WRSWR-SKIP sequential", marker = :hexagon, markersize = 12, linestyle = :dot)
203+
scatterlines!(ax1, [10^i/10^8 for i in 2:7], times_multi_thread[3:end], label = "WRSWR-SKIP parallel (1 pass)", marker = :cross, markersize = 12, linestyle = :dot)
204+
scatterlines!(ax1, [10^i/10^8 for i in 2:7], times_multi_thread_2[3:end], label = "WRSWR-SKIP parallel (2 passes)", marker = :xcross, markersize = 12, linestyle = :dot)
205+
Legend(f[2,1], ax1, labelsize=10, framevisible = false, orientation = :horizontal)
206206

207207
ax1.xtickformat = x -> string.(round.(x.*100, digits=10)) .* "%"
208208
ax1.title = "Comparison between weighted sampling algorithms in a non-streaming context"
209-
ax1.xticks = [10^(i)/10^7 for i in 1:6]
209+
ax1.xticks = [10^(i)/10^8 for i in 2:7]
210210

211211
ax1.xlabel = "sample ratio"
212212
ax1.ylabel = "time (ms)"
213213

214214
f
215-
save("comparison_WRSWR_SKIP_alg.png", f)
215+
save("comparison_WRSWR_SKIP_alg_no_stream.png", f)

benchmark/benchmark_comparison_stream.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,15 @@ using Random, Printf, BenchmarkTools
33
using CairoMakie
44

55
rng = Xoshiro(42);
6-
stream = Iterators.filter(x -> x != 10, 1:10^7);
6+
stream = Iterators.filter(x -> x != 1, 1:10^8);
77
pop = collect(stream);
88
w(el) = Float64(el);
99
weights = Weights(w.(stream));
1010

1111
algs = (AlgL(), AlgRSWRSKIP(), AlgAExpJ(), AlgWRSWRSKIP());
1212
algsweighted = (AlgAExpJ(), AlgWRSWRSKIP());
1313
algsreplace = (AlgRSWRSKIP(), AlgWRSWRSKIP());
14-
sizes = (10^3, 10^4, 10^5, 10^6)
14+
sizes = (10^4, 10^5, 10^6, 10^7)
1515

1616
p = Dict((0, 0) => 1, (0, 1) => 2, (1, 0) => 3, (1, 1) => 4);
1717
m_times = Matrix{Vector{Float64}}(undef, (3, 4));
@@ -24,13 +24,13 @@ for m in algs
2424
replace = m in algsreplace
2525
weighted = m in algsweighted
2626
if weighted
27-
b1 = @benchmark itsample($rng, $stream, $w, $size, $m) evals=1
28-
b2 = @benchmark sample($rng, collect($stream), Weights($w.($stream)), $size; replace = $replace) evals=1
29-
b3 = @benchmark sample($rng, $pop, $weights, $size; replace = $replace) evals=1
27+
b1 = @benchmark itsample($rng, $stream, $w, $size, $m) seconds=20
28+
b2 = @benchmark sample($rng, collect($stream), Weights($w.($stream)), $size; replace = $replace) seconds=20
29+
b3 = @benchmark sample($rng, $pop, $weights, $size; replace = $replace) seconds=20
3030
else
31-
b1 = @benchmark itsample($rng, $stream, $size, $m) evals=1
32-
b2 = @benchmark sample($rng, collect($stream), $size; replace = $replace) evals=1
33-
b3 = @benchmark sample($rng, $pop, $size; replace = $replace) evals=1
31+
b1 = @benchmark itsample($rng, $stream, $size, $m) evals=1 seconds=20
32+
b2 = @benchmark sample($rng, collect($stream), $size; replace = $replace) seconds=20
33+
b3 = @benchmark sample($rng, $pop, $size; replace = $replace) seconds=20
3434
end
3535
ts = [median(b1.times), median(b2.times), median(b3.times)] .* 1e-6
3636
ms = [b1.memory, b2.memory, b3.memory] .* 1e-6
@@ -39,6 +39,7 @@ for m in algs
3939
push!(m_times[r, c], ts[r])
4040
push!(m_mems[r, c], ms[r])
4141
end
42+
println("c")
4243
end
4344
end
4445

benchmark/benchmark_comparison_stream_WWR.jl

Lines changed: 53 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,14 +65,61 @@ a = Iterators.filter(x -> x != 1, 1:10^8)
6565
wv_const(x) = 1.0
6666
wv_incr(x) = Float64(x)
6767
wv_decr(x) = 1/x
68-
wvs = (wv_decr, wv_const, wv_incr)
68+
wvs = ((:wv_decr, wv_decr),
69+
(:wv_const, wv_const),
70+
(:wv_incr, wv_incr))
6971

70-
for wv in wvs
71-
for m in (AlgWRSWRSKIP(), AlgAExpJWR())
72-
for sz in [10^i for i in 0:7]
73-
b = @benchmark itsample($a, $wv, $sz, $m) seconds=10
74-
println(wv, " ", m, " ", sz, " ", median(b.times))
72+
benchs = []
73+
for (wvn, wv) in wvs
74+
for m in (AlgAExpJWR(), AlgWRSWRSKIP())
75+
bs = []
76+
for sz in [10^i for i in 3:7]
77+
b = @benchmark itsample($a, $wv, $sz, $m) seconds=20
78+
push!(bs, median(b.times))
79+
println(median(b.times))
7580
end
81+
push!(benchs, (wvn, m, bs))
82+
println(benchs)
7683
end
7784
end
7885

86+
using CairoMakie
87+
88+
f = Figure(backgroundcolor = RGBf(0.98, 0.98, 0.98), size = (1100, 700));
89+
90+
f.title = "Comparison between AExpJ-WR and WRSWR-SKIP Algorithms"
91+
92+
ax1 = Axis(f[1, 1], yscale=log10, xscale=log10,
93+
yminorticksvisible = true, yminorgridvisible = true,
94+
yminorticks = IntervalsBetween(10))
95+
ax2 = Axis(f[1, 2], yscale=log10, xscale=log10,
96+
yminorticksvisible = true, yminorgridvisible = true,
97+
yminorticks = IntervalsBetween(10))
98+
ax3 = Axis(f[1, 3], yscale=log10, xscale=log10,
99+
yminorticksvisible = true, yminorgridvisible = true,
100+
yminorticks = IntervalsBetween(10))
101+
102+
#ax4 = Axis(f[2, 1])
103+
104+
for x in benchs
105+
label = x[1] == :wv_const ? (x[2] == AlgAExpJWR() ? "ExpJ-WR" : "WRSWR-SKIP") : ""
106+
ax = x[1] == :wv_decr ? ax1 : (x[1] == :wv_const ? ax2 : ax3)
107+
marker = x[2] == AlgAExpJWR() ? :circle : (:xcross)
108+
scatterlines!(ax, [10^i/10^8 for i in 3:7], x[3] ./ 10^6, marker = marker,
109+
label = label, markersize = 12, linestyle = :dot)
110+
end
111+
112+
Legend(ax4, labelsize=10, framevisible = false, orientation = :horizontal)
113+
114+
for ax in [ax1, ax2, ax3]
115+
ax.xtickformat = x -> string.(round.(x.*100, digits=10)) .* "%"
116+
#ax.ytickformat = y -> y .* "^"
117+
ax.title = ax == ax1 ? "decreasing weights" : (ax == ax2 ? "constant weights" : "increasing weights")
118+
ax.xticks = [10^(i)/10^8 for i in 3:7]
119+
ax.yticks = [10^i for i in 2:4]
120+
ax.xlabel = "sample ratio"
121+
ax == ax1 && (ax.ylabel = "time (ms)")
122+
end
123+
124+
save("comparison_WRSWR_SKIP_alg_stream.png", f)
125+
f
-166 KB
Binary file not shown.
155 KB
Loading
133 KB
Loading

benchmark/comparison_stream_algs.png

15.2 KB
Loading

src/SamplingUtils.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ Base.eltype(::SeqSampleIterWR) = Int
3434
Base.IteratorSize(::SeqSampleIterWR) = Base.HasLength()
3535
Base.length(s::SeqSampleIterWR) = s.n
3636

37+
# courtesy of StatsBase.jl for part of the implementation
3738
struct SeqSampleIter{R}
3839
rng::R
3940
N::Int

0 commit comments

Comments
 (0)