Skip to content

Commit 5f8ef16

Browse files
bkaminsnalimilan
authored andcommitted
Fix weighted sampling without replacement (#239)
The original article assumes strictly positive weights, so this skips zero weights. Additionally it is now checked if there are not less positive weights in `wv` as required sample size.
1 parent e200a72 commit 5f8ef16

File tree

2 files changed

+55
-7
lines changed

2 files changed

+55
-7
lines changed

src/sampling.jl

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -515,16 +515,28 @@ function efraimidis_ares_wsample_norep!(a::AbstractArray, wv::WeightVec, x::Abst
515515

516516
# initialize priority queue
517517
pq = Vector{Pair{Float64,Int}}(k)
518-
@inbounds for i in 1:k
519-
pq[i] = (wv.values[i]/randexp() => i)
518+
i = 0
519+
s = 0
520+
@inbounds for s in 1:n
521+
w = wv.values[s]
522+
w < 0 && error("Negative weight found in weight vector at index $s")
523+
if w > 0
524+
i += 1
525+
pq[i] = (w/randexp() => s)
526+
end
527+
i >= k && break
520528
end
529+
i < k && throw(DimensionMismatch("wv must have at least $k strictly positive entries (got $i)"))
521530
heapify!(pq)
522531

523532
# set threshold
524533
@inbounds threshold = pq[1].first
525534

526-
@inbounds for i in k+1:n
527-
key = wv.values[i]/randexp()
535+
@inbounds for i in s+1:n
536+
w = wv.values[i]
537+
w < 0 && error("Negative weight found in weight vector at index $i")
538+
w > 0 || continue
539+
key = w/randexp()
528540

529541
# if key is larger than the threshold
530542
if key > threshold
@@ -561,17 +573,28 @@ function efraimidis_aexpj_wsample_norep!(a::AbstractArray, wv::WeightVec, x::Abs
561573

562574
# initialize priority queue
563575
pq = Vector{Pair{Float64,Int}}(k)
564-
@inbounds for i in 1:k
565-
pq[i] = (wv.values[i]/randexp() => i)
576+
i = 0
577+
s = 0
578+
@inbounds for s in 1:n
579+
w = wv.values[s]
580+
w < 0 && error("Negative weight found in weight vector at index $s")
581+
if w > 0
582+
i += 1
583+
pq[i] = (w/randexp() => s)
584+
end
585+
i >= k && break
566586
end
587+
i < k && throw(DimensionMismatch("wv must have at least $k strictly positive entries (got $i)"))
567588
heapify!(pq)
568589

569590
# set threshold
570591
@inbounds threshold = pq[1].first
571592
X = threshold*randexp()
572593

573-
@inbounds for i in k+1:n
594+
@inbounds for i in s+1:n
574595
w = wv.values[i]
596+
w < 0 && error("Negative weight found in weight vector at index $i")
597+
w > 0 || continue
575598
X -= w
576599
X <= 0 || continue
577600

test/sampling.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,3 +149,28 @@ check_sample_norep(a, (3, 12), 0; ordered=false)
149149

150150
a = sample(3:12, 5; replace=false, ordered=true)
151151
check_sample_norep(a, (3, 12), 0; ordered=true)
152+
153+
# test of weighted sampling without replacement
154+
a = [1:10;]
155+
wv = WeightVec([zeros(6); 1:4])
156+
x = vcat([sample(a, wv, 1, replace=false) for j in 1:100000]...)
157+
@test minimum(x) == 7
158+
@test maximum(x) == 10
159+
@test maximum(abs(proportions(x) - (1:4)/10)) < 0.01
160+
161+
x = vcat([sample(a, wv, 2, replace=false) for j in 1:50000]...)
162+
exact2 = [0.117261905, 0.220634921, 0.304166667, 0.357936508]
163+
@test minimum(x) == 7
164+
@test maximum(x) == 10
165+
@test maximum(abs(proportions(x) - exact2)) < 0.01
166+
167+
x = vcat([sample(a, wv, 4, replace=false) for j in 1:10000]...)
168+
@test minimum(x) == 7
169+
@test maximum(x) == 10
170+
@test maximum(abs(proportions(x) - 0.25)) == 0
171+
172+
@test_throws DimensionMismatch sample(a, wv, 5, replace=false)
173+
174+
wv = WeightVec([zeros(5); 1:4; -1])
175+
@test_throws ErrorException sample(a, wv, 1, replace=false)
176+

0 commit comments

Comments
 (0)