@@ -603,6 +603,9 @@ sample(wv::AbstractWeights) = sample(default_rng(), wv)
603603sample(rng:: AbstractRNG , a:: AbstractArray , wv:: AbstractWeights ) = a[sample(rng, wv)]
604604sample(a:: AbstractArray , wv:: AbstractWeights ) = sample(default_rng(), a, wv)
605605
606+ # Specialization for `UnitWeights`
607+ sample(rng:: AbstractRNG , wv:: UnitWeights ) = rand(rng, 1 : length(wv))
608+
606609"""
607610 direct_sample!([rng], a::AbstractArray, wv::AbstractWeights, x::AbstractArray)
608611
633636direct_sample!(a:: AbstractArray , wv:: AbstractWeights , x:: AbstractArray ) =
634637 direct_sample!(default_rng(), a, wv, x)
635638
639+ # Specialization for `UnitWeights`
640+ function direct_sample!(
641+ rng:: AbstractRNG , a:: AbstractArray , wv:: UnitWeights , x:: AbstractArray ,
642+ )
643+ if length(a) != length(wv)
644+ throw(DimensionMismatch(LazyString(
645+ " Number of samples (" ,
646+ length(a),
647+ " ) and sample weights (" ,
648+ length(wv),
649+ " ) must be equal." ,
650+ )))
651+ end
652+ return direct_sample!(rng, a, x)
653+ end
654+
636655"""
637656 alias_sample!([rng], a::AbstractArray, wv::AbstractWeights, x::AbstractArray)
638657
@@ -741,7 +760,7 @@ function efraimidis_a_wsample_norep!(rng::AbstractRNG, a::AbstractArray,
741760 # calculate keys for all items
742761 keys = randexp(rng, n)
743762 for i in 1 : n
744- keys[i] = wv. values [i]/ keys[i]
763+ keys[i] = wv[i]/ keys[i]
745764 end
746765
747766 # return items with largest keys
@@ -787,7 +806,7 @@ function efraimidis_ares_wsample_norep!(rng::AbstractRNG, a::AbstractArray,
787806 s = 0
788807 for _s in 1 : n
789808 s = _s
790- w = wv. values [s]
809+ w = wv[s]
791810 w < 0 && error(" Negative weight found in weight vector at index $s " )
792811 if w > 0
793812 i += 1
@@ -802,7 +821,7 @@ function efraimidis_ares_wsample_norep!(rng::AbstractRNG, a::AbstractArray,
802821 threshold = pq[1 ]. first
803822
804823 for i in s+ 1 : n
805- w = wv. values [i]
824+ w = wv[i]
806825 w < 0 && error(" Negative weight found in weight vector at index $i " )
807826 w > 0 || continue
808827 key = w/ randexp(rng)
@@ -861,7 +880,7 @@ function efraimidis_aexpj_wsample_norep!(rng::AbstractRNG, a::AbstractArray,
861880 s = 0
862881 for _s in 1 : n
863882 s = _s
864- w = wv. values [s]
883+ w = wv[s]
865884 w < 0 && error(" Negative weight found in weight vector at index $s " )
866885 if w > 0
867886 i += 1
@@ -877,7 +896,7 @@ function efraimidis_aexpj_wsample_norep!(rng::AbstractRNG, a::AbstractArray,
877896 X = threshold* randexp(rng)
878897
879898 for i in s+ 1 : n
880- w = wv. values [i]
899+ w = wv[i]
881900 w < 0 && error(" Negative weight found in weight vector at index $i " )
882901 w > 0 || continue
883902 X -= w
@@ -958,6 +977,20 @@ sample(a::AbstractArray, wv::AbstractWeights, dims::Dims;
958977 replace:: Bool = true , ordered:: Bool = false ) =
959978 sample(default_rng(), a, wv, dims; replace= replace, ordered= ordered)
960979
980+ # Specialization for `UnitWeights`
981+ function sample!(rng:: AbstractRNG , a:: AbstractArray , wv:: UnitWeights , x:: AbstractArray ; replace:: Bool = true , ordered:: Bool = false )
982+ if length(a) != length(wv)
983+ throw(DimensionMismatch(LazyString(
984+ " Number of samples (" ,
985+ length(a),
986+ " ) and sample weights (" ,
987+ length(wv),
988+ " ) must be equal." ,
989+ )))
990+ end
991+ return sample!(rng, a, x; replace, ordered)
992+ end
993+
961994# wsample interface
962995
963996"""
0 commit comments