Skip to content

Commit 627129f

Browse files
devmotionnalimilan
andauthored
Fix sampling with UnitWeights (#963)
* Fix sampling with `UnitWeights` * Apply suggestions from code review Co-authored-by: Milan Bouchet-Valat <[email protected]> --------- Co-authored-by: Milan Bouchet-Valat <[email protected]>
1 parent 24c21a9 commit 627129f

File tree

2 files changed

+86
-5
lines changed

2 files changed

+86
-5
lines changed

src/sampling.jl

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -603,6 +603,9 @@ sample(wv::AbstractWeights) = sample(default_rng(), wv)
603603
sample(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights) = a[sample(rng, wv)]
604604
sample(a::AbstractArray, wv::AbstractWeights) = sample(default_rng(), a, wv)
605605

606+
# Specialization for `UnitWeights`
607+
sample(rng::AbstractRNG, wv::UnitWeights) = rand(rng, 1:length(wv))
608+
606609
"""
607610
direct_sample!([rng], a::AbstractArray, wv::AbstractWeights, x::AbstractArray)
608611
@@ -633,6 +636,22 @@ end
633636
direct_sample!(a::AbstractArray, wv::AbstractWeights, x::AbstractArray) =
634637
direct_sample!(default_rng(), a, wv, x)
635638

639+
# Specialization for `UnitWeights`
640+
function direct_sample!(
641+
rng::AbstractRNG, a::AbstractArray, wv::UnitWeights, x::AbstractArray,
642+
)
643+
if length(a) != length(wv)
644+
throw(DimensionMismatch(LazyString(
645+
"Number of samples (",
646+
length(a),
647+
") and sample weights (",
648+
length(wv),
649+
") must be equal.",
650+
)))
651+
end
652+
return direct_sample!(rng, a, x)
653+
end
654+
636655
"""
637656
alias_sample!([rng], a::AbstractArray, wv::AbstractWeights, x::AbstractArray)
638657
@@ -741,7 +760,7 @@ function efraimidis_a_wsample_norep!(rng::AbstractRNG, a::AbstractArray,
741760
# calculate keys for all items
742761
keys = randexp(rng, n)
743762
for i in 1:n
744-
keys[i] = wv.values[i]/keys[i]
763+
keys[i] = wv[i]/keys[i]
745764
end
746765

747766
# return items with largest keys
@@ -787,7 +806,7 @@ function efraimidis_ares_wsample_norep!(rng::AbstractRNG, a::AbstractArray,
787806
s = 0
788807
for _s in 1:n
789808
s = _s
790-
w = wv.values[s]
809+
w = wv[s]
791810
w < 0 && error("Negative weight found in weight vector at index $s")
792811
if w > 0
793812
i += 1
@@ -802,7 +821,7 @@ function efraimidis_ares_wsample_norep!(rng::AbstractRNG, a::AbstractArray,
802821
threshold = pq[1].first
803822

804823
for i in s+1:n
805-
w = wv.values[i]
824+
w = wv[i]
806825
w < 0 && error("Negative weight found in weight vector at index $i")
807826
w > 0 || continue
808827
key = w/randexp(rng)
@@ -861,7 +880,7 @@ function efraimidis_aexpj_wsample_norep!(rng::AbstractRNG, a::AbstractArray,
861880
s = 0
862881
for _s in 1:n
863882
s = _s
864-
w = wv.values[s]
883+
w = wv[s]
865884
w < 0 && error("Negative weight found in weight vector at index $s")
866885
if w > 0
867886
i += 1
@@ -877,7 +896,7 @@ function efraimidis_aexpj_wsample_norep!(rng::AbstractRNG, a::AbstractArray,
877896
X = threshold*randexp(rng)
878897

879898
for i in s+1:n
880-
w = wv.values[i]
899+
w = wv[i]
881900
w < 0 && error("Negative weight found in weight vector at index $i")
882901
w > 0 || continue
883902
X -= w
@@ -958,6 +977,20 @@ sample(a::AbstractArray, wv::AbstractWeights, dims::Dims;
958977
replace::Bool=true, ordered::Bool=false) =
959978
sample(default_rng(), a, wv, dims; replace=replace, ordered=ordered)
960979

980+
# Specialization for `UnitWeights`
981+
function sample!(rng::AbstractRNG, a::AbstractArray, wv::UnitWeights, x::AbstractArray; replace::Bool=true, ordered::Bool=false)
982+
if length(a) != length(wv)
983+
throw(DimensionMismatch(LazyString(
984+
"Number of samples (",
985+
length(a),
986+
") and sample weights (",
987+
length(wv),
988+
") must be equal.",
989+
)))
990+
end
991+
return sample!(rng, a, x; replace, ordered)
992+
end
993+
961994
# wsample interface
962995

963996
"""

test/sampling.jl

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,3 +297,51 @@ end
297297
end
298298
end
299299
end
300+
301+
# Custom weights without `values` field
302+
struct YAUnitWeights <: AbstractWeights{Int, Int, Vector{Int}}
303+
n::Int
304+
end
305+
Base.sum(wv::YAUnitWeights) = wv.n
306+
Base.length(wv::YAUnitWeights) = wv.n
307+
Base.isempty(wv::YAUnitWeights) = iszero(wv.n)
308+
Base.size(wv::YAUnitWeights) = (wv.n,)
309+
Base.axes(wv::YAUnitWeights) = (Base.OneTo(wv.n),)
310+
function Base.getindex(wv::YAUnitWeights, i::Int)
311+
@boundscheck checkbounds(wv, i)
312+
return 1
313+
end
314+
315+
@testset "issue #950" begin
316+
# Sampling with unit weights behaves the same as sampling without weights
317+
Random.seed!(123)
318+
xs = sample(1:100, uweights(100), 10; replace=false)
319+
Random.seed!(123)
320+
@test xs == sample(1:100, 10; replace=false)
321+
322+
Random.seed!(123)
323+
x = sample(uweights(100))
324+
Random.seed!(123)
325+
@test x == sample(1:100)
326+
327+
Random.seed!(123)
328+
xs = direct_sample!(1:100, uweights(100), Vector{Int}(undef, 10))
329+
Random.seed!(123)
330+
@test xs == direct_sample!(1:100, Vector{Int}(undef, 10))
331+
332+
# Errors
333+
@test_throws DimensionMismatch("Number of samples (100) and sample weights (99) must be equal.") sample(1:100, uweights(99), 10; replace=false)
334+
@test_throws DimensionMismatch("Number of samples (80) and sample weights (53) must be equal.") direct_sample!(1:80, uweights(53), Vector{Int}(undef, 10))
335+
336+
# Custom unit weights don't error and behave the same as sampling with `Weights`
337+
Random.seed!(123)
338+
xs = sample(1:100, YAUnitWeights(100), 10; replace=false)
339+
Random.seed!(123)
340+
@test xs == sample(1:100, weights(ones(Int, 100)), 10; replace=false)
341+
for f in (StatsBase.efraimidis_a_wsample_norep!, StatsBase.efraimidis_ares_wsample_norep!, StatsBase.efraimidis_aexpj_wsample_norep!)
342+
Random.seed!(123)
343+
xs = f(1:100, YAUnitWeights(100), Vector{Int}(undef, 10))
344+
Random.seed!(123)
345+
@test xs == f(1:100, weights(ones(Int, 100)), Vector{Int}(undef, 10))
346+
end
347+
end

0 commit comments

Comments
 (0)