Skip to content

Commit 34404f2

Browse files
authored
Better performance for sampling with replacement (#107)
1 parent 8e3b897 commit 34404f2

File tree

3 files changed

+33
-41
lines changed

3 files changed

+33
-41
lines changed

src/UnweightedSamplingMulti.jl

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -107,22 +107,13 @@ end
107107
end
108108
elseif s.skip_k < s.seen_k
109109
p = 1/s.seen_k
110-
z = (1-p)^(n-3)
111-
q = rand(s.rng, Uniform(z*(1-p)*(1-p)*(1-p),1.0))
112-
k = choose(n, p, q, z)
113-
@inbounds begin
114-
if k == 1
115-
r = rand(s.rng, 1:n)
116-
s.value[r] = el
117-
update_order_single!(s, r)
118-
else
119-
for j in 1:k
120-
r = rand(s.rng, j:n)
121-
s.value[r] = el
122-
s.value[r], s.value[j] = s.value[j], s.value[r]
123-
update_order_multi!(s, r, j)
124-
end
125-
end
110+
z = exp((n-4)*log1p(-p))
111+
q = rand(s.rng, Uniform(z*(1-p)*(1-p)*(1-p)*(1-p),1.0))
112+
k = @inline choose(n, p, q, z)
113+
@inbounds for j in 1:k
114+
r = rand(s.rng, j:n)
115+
s.value[r], s.value[j] = s.value[j], el
116+
update_order_multi!(s, r, j)
126117
end
127118
s = recompute_skip!(s, n)
128119
end
@@ -164,20 +155,22 @@ function recompute_skip!(s::SampleMultiAlgL, n)
164155
return s
165156
end
166157
function recompute_skip!(s::SampleMultiAlgRSWRSKIP, n)
167-
q = rand(s.rng)^(1/n)
158+
q = exp(-randexp(s.rng)/n)
168159
@update s.skip_k = ceil(Int, s.seen_k/q)-1
169160
return s
170161
end
171162

172163
function choose(n, p, q, z)
173164
m = 1-p
174165
s = z
175-
z = s*m*m*(m + n*p)
166+
z = s*m*m*m*(m + n*p)
176167
z > q && return 1
177-
z += n*p*(n-1)*p*s*m/2
168+
z += n*p*(n-1)*p*s*m*m/2
178169
z > q && return 2
179-
z += n*p*(n-1)*p*(n-2)*p*s/6
170+
z += n*p*(n-1)*p*(n-2)*p*s*m/6
180171
z > q && return 3
172+
z += n*p*(n-1)*p*(n-2)*p*(n-3)*p*s/24
173+
z > q && return 4
181174
b = Binomial(n, p)
182175
return quantile(b, q)
183176
end
@@ -226,7 +219,11 @@ function OnlineStatsBase.value(s::Union{SampleMultiAlgR, SampleMultiAlgL})
226219
end
227220
function OnlineStatsBase.value(s::SampleMultiAlgRSWRSKIP)
228221
if nobs(s) < length(s.value)
229-
return nobs(s) == 0 ? s.value[1:0] : sample(s.rng, s.value[1:nobs(s)], length(s.value))
222+
if nobs(s) == 0
223+
return s.value[1:0]
224+
else
225+
return sample(s.rng, s.value[1:nobs(s)], length(s.value))
226+
end
230227
else
231228
return s.value
232229
end
@@ -241,7 +238,11 @@ function ordvalue(s::Union{SampleMultiOrdAlgR, SampleMultiOrdAlgL})
241238
end
242239
function ordvalue(s::SampleMultiOrdAlgRSWRSKIP)
243240
if nobs(s) < length(s.value)
244-
return sample(s.rng, s.value[1:nobs(s)], length(s.value); ordered=true)
241+
if nobs(s) == 0
242+
return s.value[1:0]
243+
else
244+
return sample(s.rng, s.value[1:nobs(s)], length(s.value); ordered=true)
245+
end
245246
else
246247
return s.value[sortperm(s.ord)]
247248
end

src/WeightedSamplingMulti.jl

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ end
124124
@inbounds s.value[s.seen_k] = el
125125
@inbounds s.weights[s.seen_k] = w
126126
if s.seen_k == n
127-
new_values = sample(s.rng, s.value, weights(s.weights), n; ordered = is_ordered(s))
127+
new_values = sample(s.rng, s.value, Weights(s.weights, s.state), n; ordered = is_ordered(s))
128128
@inbounds for i in 1:n
129129
s.value[i] = new_values[i]
130130
end
@@ -133,22 +133,13 @@ end
133133
end
134134
elseif s.skip_w <= s.state
135135
p = w/s.state
136-
z = (1-p)^(n-3)
137-
q = rand(s.rng, Uniform(z*(1-p)*(1-p)*(1-p),1.0))
138-
k = choose(n, p, q, z)
139-
@inbounds begin
140-
if k == 1
141-
r = rand(s.rng, 1:n)
142-
s.value[r] = el
143-
update_order_single!(s, r)
144-
else
145-
for j in 1:k
146-
r = rand(s.rng, j:n)
147-
s.value[r] = el
148-
s.value[r], s.value[j] = s.value[j], s.value[r]
149-
update_order_multi!(s, r, j)
150-
end
151-
end
136+
z = exp((n-4)*log1p(-p))
137+
q = rand(s.rng, Uniform(z*(1-p)*(1-p)*(1-p)*(1-p),1.0))
138+
k = @inline choose(n, p, q, z)
139+
@inbounds for j in 1:k
140+
r = rand(s.rng, j:n)
141+
s.value[r], s.value[j] = s.value[j], el
142+
update_order_multi!(s, r, j)
152143
end
153144
s = @inline recompute_skip!(s, n)
154145
end
@@ -233,7 +224,7 @@ function recompute_skip!(s::SampleMultiAlgAExpJ)
233224
return s
234225
end
235226
function recompute_skip!(s::SampleMultiAlgWRSWRSKIP, n)
236-
q = rand(s.rng)^(1/n)
227+
q = exp(-randexp(s.rng)/n)
237228
@update s.skip_w = s.state/q
238229
return s
239230
end

test/unweighted_sampling_multi_tests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
@test all(x -> a <= x <= b, value(rs))
5050
@test nobs(rs) == 10
5151

52-
rngs = (StableRNG(42), StableRNG(43))
52+
rngs = (StableRNG(46), StableRNG(47))
5353
iters = (a:b, Iterators.filter(x -> x != b + 1, a:b+1), (a:floor(Int, b/2), (floor(Int, b/2)+1):b))
5454
sizes = (2, 3)
5555
for it in iters

0 commit comments

Comments
 (0)