Skip to content
29 changes: 28 additions & 1 deletion src/sampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,29 @@ sample_ordered!(sampler!, rng::AbstractRNG, a::AbstractArray,
sampler!(rng, a, wv, x)
end

"""
This algorithm generates a sorted sample with replacement by
adapting the classic result that the cumulative sum of n+1
exponentially-distributed random numbers divided by the overall sum
(dropping the last) is a sorted sample from a uniform[0,1]
"""
function uniform_orderstat_sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray)
1 == firstindex(a) == firstindex(x) ||
throw(ArgumentError("non 1-based arrays are not supported"))
Base.mightalias(a, x) &&
throw(ArgumentError("output array x must not share memory with input array a"))
n = length(a)
k = length(x)
exp_rands = randexp(rng, k)
sorted_rands = cumsum(exp_rands)
cum_step = (sorted_rands[end] + randexp(rng)) / n
@inbounds for i in eachindex(x)
j = ceil(Int, sorted_rands[i] / cum_step)
x[i] = a[j]
end
return x
end

### draw a pair of distinct integers in [1:n]

"""
Expand Down Expand Up @@ -500,7 +523,11 @@ function sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray;

if replace # with replacement
if ordered
sample_ordered!(direct_sample!, rng, a, x)
if k <= 10
sample_ordered!(direct_sample!, rng, a, x)
else
uniform_orderstat_sample!(rng, a, x)
end
else
direct_sample!(rng, a, x)
end
Expand Down
11 changes: 8 additions & 3 deletions test/sampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,19 @@ test_rng_use(direct_sample!, 1:10, zeros(Int, 6))
a = sample(3:12, n)
check_sample_wrep(a, (3, 12), 5.0e-3; ordered=false)

rng = StableRNG(1)
for rev in (true, false), T in (Int, Int16, Float64, Float16, BigInt, ComplexF64, Rational{Int})
r = rev ? reverse(3:12) : (3:12)
r = T===Int ? r : T.(r)
aa = Int.(sample(r, n; ordered=true))
aa = Int.(sample(rng, r, n; ordered=true))
check_sample_wrep(aa, (3, 12), 5.0e-3; ordered=true, rev=rev)

aa = Int.(sample(r, 10; ordered=true))
check_sample_wrep(aa, (3, 12), 0; ordered=true, rev=rev)
aa = Int[]
for i in 1:Int(n/10)
bb = Int.(sample(rng, r, 10; ordered=true))
append!(aa, bb)
end
check_sample_wrep(sort!(aa, rev=rev), (3, 12), 5.0e-3; ordered=true, rev=rev)
end

@test StatsBase._storeindices(1, 1, BigFloat) == StatsBase._storeindices(1, 1, BigFloat) == false
Expand Down