Skip to content

Commit bc20566

Browse files
authored
fix sequential sampling with replacement (#677)
1 parent 6ec4579 commit bc20566

File tree

3 files changed

+124
-42
lines changed

3 files changed

+124
-42
lines changed

src/sampling.jl

Lines changed: 83 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,49 @@ function direct_sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray)
4242
end
4343
direct_sample!(a::AbstractArray, x::AbstractArray) = direct_sample!(Random.GLOBAL_RNG, a, x)
4444

45+
# Decide whether an array with element type `T` can itself be used to hold
# the indices 1:n while drawing k samples.
#
# `_storeindices(n, k, T)` checks that every index up to `n` can be stored
# exactly in `T`, combined with a heuristic cutoff on `k` for the types where
# doing so is only considered beneficial for few samples.
# `storeindices(n, k, T)` is the entry point: it only consults the helper for
# hardware-supported numeric types (`Base.HWNumber`) and answers `false` for
# every other type.

# integers: the largest index must not exceed typemax(T)
_storeindices(n, k, ::Type{T}) where {T<:Integer} = n <= typemax(T)
# floats: the largest index must lie in the exactly-representable integer range
_storeindices(n, k, ::Type{T}) where {T<:Union{Float32,Float64}} =
    k < 22 && n <= maxintfloat(T)
# complex and rational types defer to their component type
_storeindices(n, k, ::Type{Complex{T}}) where {T} = _storeindices(n, k, T)
_storeindices(n, k, ::Type{Rational{T}}) where {T} = k < 16 && _storeindices(n, k, T)
# anything else cannot be used to store indices
_storeindices(n, k, T) = false
storeindices(n, k, ::Type{T}) where {T<:Base.HWNumber} = _storeindices(n, k, T)
storeindices(n, k, T) = false
55+
56+
# Fill `x` with a sample from `a` drawn by `sampler!` (a sampler that does not
# order its output), arranged in the order in which the items appear in `a`.
# Positions 1:n are sampled and sorted, then mapped back to the items of `a`.
function sample_ordered!(sampler!, rng::AbstractRNG, a::AbstractArray, x::AbstractArray)
    n, k = length(a), length(x)
    # todo: if eltype(x) <: Real && eltype(a) <: Real,
    #       in some cases it might be faster to check
    #       issorted(a) to see if we can just sort x
    if !storeindices(n, k, eltype(x))
        # `x` is not used for the positions: draw them into a temporary vector
        positions = Vector{Int}(undef, k)
        sort!(sampler!(rng, Base.OneTo(n), positions))
        @inbounds for i in 1:k
            x[i] = a[positions[i]]
        end
        return x
    end
    # `x` doubles as storage for the positions; they are compared through
    # `real` so that complex element types can be sorted
    sort!(sampler!(rng, Base.OneTo(n), x), by=real, lt=<)
    @inbounds for i in 1:k
        x[i] = a[Int(x[i])]
    end
    return x
end
76+
77+
# special case of a range can be done more efficiently: the values of a range
# are already monotonic, so the sampled values are sorted directly, in
# reverse when the step of the range is negative
function sample_ordered!(sampler!, rng::AbstractRNG, a::AbstractRange, x::AbstractArray)
    drawn = sampler!(rng, a, x)
    descending = step(a) < 0
    return sort!(drawn; rev=descending)
end
80+
81+
# weighted case: bind the weights `wv` into a three-argument sampler and
# defer to the unweighted methods of `sample_ordered!`
function sample_ordered!(sampler!, rng::AbstractRNG, a::AbstractArray,
                         wv::AbstractWeights, x::AbstractArray)
    weighted! = (r, population, out) -> sampler!(r, population, wv, out)
    return sample_ordered!(weighted!, rng, a, x)
end
87+
4588
### draw a pair of distinct integers in [1:n]
4689

4790
"""
@@ -396,21 +439,24 @@ Draw a random sample of `length(x)` elements from an array `a`
396439
and store the result in `x`. A polyalgorithm is used for sampling.
397440
Sampling probabilities are proportional to the weights given in `wv`,
398441
if provided. `replace` dictates whether sampling is performed with
399-
replacement and `order` dictates whether an ordered sample, also called
400-
a sequential sample, should be taken.
442+
replacement. `ordered` dictates whether
443+
an ordered sample (also called a sequential sample, i.e. a sample where
444+
items appear in the same order as in `a`) should be taken.
401445
402446
Optionally specify a random number generator `rng` as the first argument
403447
(defaults to `Random.GLOBAL_RNG`).
404448
"""
405449
function sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray;
406450
replace::Bool=true, ordered::Bool=false)
451+
1 == firstindex(a) == firstindex(x) ||
452+
throw(ArgumentError("non 1-based arrays are not supported"))
407453
n = length(a)
408454
k = length(x)
409455
k == 0 && return x
410456

411457
if replace # with replacement
412458
if ordered
413-
sort!(direct_sample!(rng, a, x))
459+
sample_ordered!(direct_sample!, rng, a, x)
414460
else
415461
direct_sample!(rng, a, x)
416462
end
@@ -448,8 +494,9 @@ sample!(a::AbstractArray, x::AbstractArray; replace::Bool=true, ordered::Bool=fa
448494
Select a random, optionally weighted sample of size `n` from an array `a`
449495
using a polyalgorithm. Sampling probabilities are proportional to the weights
450496
given in `wv`, if provided. `replace` dictates whether sampling is performed
451-
with replacement and `order` dictates whether an ordered sample, also called
452-
a sequential sample, should be taken.
497+
with replacement. `ordered` dictates whether
498+
an ordered sample (also called a sequential sample, i.e. a sample where
499+
items appear in the same order as in `a`) should be taken.
453500
454501
Optionally specify a random number generator `rng` as the first argument
455502
(defaults to `Random.GLOBAL_RNG`).
@@ -468,8 +515,9 @@ sample(a::AbstractArray, n::Integer; replace::Bool=true, ordered::Bool=false) =
468515
Select a random, optionally weighted sample from an array `a` specifying
469516
the dimensions `dims` of the output array. Sampling probabilities are
470517
proportional to the weights given in `wv`, if provided. `replace` dictates
471-
whether sampling is performed with replacement and `order` dictates whether
472-
an ordered sample, also called a sequential sample, should be taken.
518+
whether sampling is performed with replacement. `ordered` dictates whether
519+
an ordered sample (also called a sequential sample, i.e. a sample where
520+
items appear in the same order as in `a`) should be taken.
473521
474522
Optionally specify a random number generator `rng` as the first argument
475523
(defaults to `Random.GLOBAL_RNG`).
@@ -781,7 +829,8 @@ Noting `k=length(x)` and `n=length(a)`, this algorithm takes ``O(k \\log(k) \\lo
781829
processing time to draw ``k`` elements. It consumes ``O(k \\log(n / k))`` random numbers.
782830
"""
783831
function efraimidis_aexpj_wsample_norep!(rng::AbstractRNG, a::AbstractArray,
784-
wv::AbstractWeights, x::AbstractArray)
832+
wv::AbstractWeights, x::AbstractArray;
833+
ordered::Bool=false)
785834
n = length(a)
786835
length(wv) == n || throw(DimensionMismatch("a and wv must be of same length (got $n and $(length(wv)))."))
787836
k = length(x)
@@ -824,24 +873,36 @@ function efraimidis_aexpj_wsample_norep!(rng::AbstractRNG, a::AbstractArray,
824873
threshold = pq[1].first
825874
X = threshold * randexp(rng)
826875
end
827-
828-
# fill output array with items in descending order
829-
@inbounds for i in k:-1:1
830-
x[i] = a[heappop!(pq).second]
876+
if ordered
877+
# fill output array with items sorted as in a
878+
sort!(pq, by=last)
879+
@inbounds for i in 1:k
880+
x[i] = a[pq[i].second]
881+
end
882+
else
883+
# fill output array with items in descending order
884+
@inbounds for i in k:-1:1
885+
x[i] = a[heappop!(pq).second]
886+
end
831887
end
832888
return x
833889
end
834-
efraimidis_aexpj_wsample_norep!(a::AbstractArray, wv::AbstractWeights, x::AbstractArray) =
835-
efraimidis_aexpj_wsample_norep!(Random.GLOBAL_RNG, a, wv, x)
890+
# Convenience method: same as the method taking an `rng`, using
# `Random.GLOBAL_RNG` and forwarding the `ordered` keyword unchanged.
function efraimidis_aexpj_wsample_norep!(a::AbstractArray, wv::AbstractWeights,
                                         x::AbstractArray; ordered::Bool=false)
    return efraimidis_aexpj_wsample_norep!(Random.GLOBAL_RNG, a, wv, x; ordered=ordered)
end
836893

837894
function sample!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights, x::AbstractArray;
838895
replace::Bool=true, ordered::Bool=false)
896+
1 == firstindex(a) == firstindex(wv) == firstindex(x) ||
897+
throw(ArgumentError("non 1-based arrays are not supported"))
839898
n = length(a)
840899
k = length(x)
841900

842901
if replace
843902
if ordered
844-
sort!(direct_sample!(rng, a, wv, x))
903+
sample_ordered!(rng, a, wv, x) do rng, a, wv, x
904+
sample!(rng, a, wv, x; replace=true, ordered=false)
905+
end
845906
else
846907
if n < 40
847908
direct_sample!(rng, a, wv, x)
@@ -856,11 +917,7 @@ function sample!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights, x::Abs
856917
end
857918
else
858919
k <= n || error("Cannot draw $n samples from $k samples without replacement.")
859-
860-
efraimidis_aexpj_wsample_norep!(rng, a, wv, x)
861-
if ordered
862-
sort!(x)
863-
end
920+
efraimidis_aexpj_wsample_norep!(rng, a, wv, x; ordered=ordered)
864921
end
865922
return x
866923
end
@@ -889,8 +946,9 @@ sample(a::AbstractArray, wv::AbstractWeights, dims::Dims;
889946
890947
Select a weighted sample from an array `a` and store the result in `x`. Sampling
891948
probabilities are proportional to the weights given in `w`. `replace` dictates
892-
whether sampling is performed with replacement and `order` dictates whether an
893-
ordered sample, also called a sequential sample, should be taken.
949+
whether sampling is performed with replacement. `ordered` dictates whether
950+
an ordered sample (also called a sequential sample, i.e. a sample where
951+
items appear in the same order as in `a`) should be taken.
894952
895953
Optionally specify a random number generator `rng` as the first argument
896954
(defaults to `Random.GLOBAL_RNG`).
@@ -923,8 +981,9 @@ wsample(a::AbstractArray, w::RealVector) = wsample(Random.GLOBAL_RNG, a, w)
923981
Select a weighted random sample of size `n` from `a` with probabilities proportional
924982
to the weights given in `w` if `a` is present, otherwise select a random sample of size
925983
`n` of the weights given in `w`. `replace` dictates whether sampling is performed with
926-
replacement and `order` dictates whether an ordered sample, also called a sequential
927-
sample, should be taken.
984+
replacement. `ordered` dictates whether
985+
an ordered sample (also called a sequential sample, i.e. a sample where
986+
items appear in the same order as in `a`) should be taken.
928987
929988
Optionally specify a random number generator `rng` as the first argument
930989
(defaults to `Random.GLOBAL_RNG`).

test/sampling.jl

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,19 +27,19 @@ end
2727

2828
#### sample with replacement
2929

30-
function check_sample_wrep(a::AbstractArray, vrgn, ptol::Real; ordered::Bool=false)
30+
function check_sample_wrep(a::AbstractArray, vrgn, ptol::Real; ordered::Bool=false, rev::Bool=false)
3131
vmin, vmax = vrgn
3232
(amin, amax) = extrema(a)
3333
@test vmin <= amin <= amax <= vmax
3434
n = vmax - vmin + 1
3535
p0 = fill(1/n, n)
3636
if ordered
37-
@test issorted(a)
37+
@test issorted(a; rev=rev)
3838
if ptol > 0
3939
@test isapprox(proportions(a, vmin:vmax), p0, atol=ptol)
4040
end
4141
else
42-
@test !issorted(a)
42+
@test !issorted(a; rev=rev)
4343
ncols = size(a,2)
4444
if ncols == 1
4545
@test isapprox(proportions(a, vmin:vmax), p0, atol=ptol)
@@ -68,11 +68,17 @@ test_rng_use(direct_sample!, 1:10, zeros(Int, 6))
6868
a = sample(3:12, n)
6969
check_sample_wrep(a, (3, 12), 5.0e-3; ordered=false)
7070

71-
a = sample(3:12, n; ordered=true)
72-
check_sample_wrep(a, (3, 12), 5.0e-3; ordered=true)
71+
# ordered sampling with replacement, for both directions of the source range
# and for several element types
for rev in (true, false), T in (Int, Int16, Float64, Float16, BigInt, ComplexF64, Rational{Int})
    span = rev ? reverse(3:12) : (3:12)
    items = T === Int ? span : T.(span)

    # many draws: order and proportions are checked
    drawn = Int.(sample(items, n; ordered=true))
    check_sample_wrep(drawn, (3, 12), 5.0e-3; ordered=true, rev=rev)

    # few draws: only the order is checked (ptol = 0)
    drawn = Int.(sample(items, 10; ordered=true))
    check_sample_wrep(drawn, (3, 12), 0; ordered=true, rev=rev)
end
80+
81+
# BigFloat must be rejected by both the entry point `storeindices` and the
# internal helper `_storeindices` (each falls through to its `false` fallback)
@test StatsBase.storeindices(1, 1, BigFloat) == StatsBase._storeindices(1, 1, BigFloat) == false
7682

7783
test_rng_use(sample, 1:10, 10)
7884

@@ -91,7 +97,7 @@ test_rng_use(samplepair, 1000)
9197

9298
#### sample without replacement
9399

94-
function check_sample_norep(a::AbstractArray, vrgn, ptol::Real; ordered::Bool=false)
100+
function check_sample_norep(a::AbstractArray, vrgn, ptol::Real; ordered::Bool=false, rev::Bool=false)
95101
# each column of a for one run
96102

97103
vmin, vmax = vrgn
@@ -103,7 +109,7 @@ function check_sample_norep(a::AbstractArray, vrgn, ptol::Real; ordered::Bool=fa
103109
aj = view(a,:,j)
104110
@assert allunique(aj)
105111
if ordered
106-
@assert issorted(aj)
112+
@assert issorted(aj, rev=rev)
107113
end
108114
end
109115

@@ -178,6 +184,9 @@ check_sample_norep(a, (3, 12), 0; ordered=false)
178184
a = sample(3:12, 5; replace=false, ordered=true)
179185
check_sample_norep(a, (3, 12), 0; ordered=true)
180186

187+
a = sample(reverse(3:12), 5; replace=false, ordered=true)
188+
check_sample_norep(a, (3, 12), 0; ordered=true, rev=true)
189+
181190
# tests of multidimensional sampling
182191

183192
a = sample(3:12, (2, 2); replace=false)

test/wsampling.jl

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,21 @@ Random.seed!(1234)
55

66
#### weighted sample with replacement
77

8-
function check_wsample_wrep(a::AbstractArray, vrgn, wv::AbstractWeights, ptol::Real; ordered::Bool=false)
8+
function check_wsample_wrep(a::AbstractArray, vrgn, wv::AbstractWeights, ptol::Real;
9+
ordered::Bool=false, rev::Bool=false)
910
K = length(wv)
1011
(vmin, vmax) = vrgn
1112
(amin, amax) = extrema(a)
1213
@test vmin <= amin <= amax <= vmax
1314
p0 = wv ./ sum(wv)
15+
rev && reverse!(p0)
1416
if ordered
15-
@test issorted(a)
17+
@test issorted(a; rev=rev)
1618
if ptol > 0
1719
@test isapprox(proportions(a, vmin:vmax), p0, atol=ptol)
1820
end
1921
else
20-
@test !issorted(a)
22+
@test !issorted(a; rev=rev)
2123
ncols = size(a,2)
2224
if ncols == 1
2325
@test isapprox(proportions(a, vmin:vmax), p0, atol=ptol)
@@ -45,13 +47,20 @@ check_wsample_wrep(a, (4, 7), wv, 5.0e-3; ordered=false)
4547
a = sample(4:7, wv, n; ordered=false)
4648
check_wsample_wrep(a, (4, 7), wv, 5.0e-3; ordered=false)
4749

48-
a = sample(4:7, wv, n; ordered=true)
49-
check_wsample_wrep(a, (4, 7), wv, 5.0e-3; ordered=true)
50+
# ordered weighted sampling with replacement, for both directions of the
# source range and for several element types
for rev in (true, false), T in (Int, Int16, Float64, Float16, BigInt, ComplexF64, Rational{Int})
    span = rev ? reverse(4:7) : (4:7)
    items = T === Int ? span : T.(span)

    # many draws: order and proportions are checked
    drawn = Int.(sample(items, wv, n; ordered=true))
    check_wsample_wrep(drawn, (4, 7), wv, 5.0e-3; ordered=true, rev=rev)

    # few draws: only the order is checked (negative ptol)
    drawn = Int.(sample(items, wv, 10; ordered=true))
    check_wsample_wrep(drawn, (4, 7), wv, -1; ordered=true, rev=rev)
end
5058

5159

5260
#### weighted sampling without replacement
5361

54-
function check_wsample_norep(a::AbstractArray, vrgn, wv::AbstractWeights, ptol::Real; ordered::Bool=false)
62+
function check_wsample_norep(a::AbstractArray, vrgn, wv::AbstractWeights, ptol::Real;
63+
ordered::Bool=false, rev::Bool=false)
5564
# each column of a for one run
5665

5766
vmin, vmax = vrgn
@@ -63,12 +72,13 @@ function check_wsample_norep(a::AbstractArray, vrgn, wv::AbstractWeights, ptol::
6372
aj = view(a,:,j)
6473
@assert allunique(aj)
6574
if ordered
66-
@assert issorted(aj)
75+
@assert issorted(aj; rev=rev)
6776
end
6877
end
6978

7079
if ptol > 0
7180
p0 = wv ./ sum(wv)
81+
rev && reverse!(p0)
7282
@test isapprox(proportions(a[1,:], vmin:vmax), p0, atol=ptol)
7383
end
7484
end
@@ -110,5 +120,9 @@ test_rng_use(efraimidis_aexpj_wsample_norep!, 4:7, wv, zeros(Int, 2))
110120
a = sample(4:7, wv, 3; replace=false, ordered=false)
111121
check_wsample_norep(a, (4, 7), wv, -1; ordered=false)
112122

113-
a = sample(4:7, wv, 3; replace=false, ordered=true)
114-
check_wsample_norep(a, (4, 7), wv, -1; ordered=true)
123+
# ordered weighted sampling without replacement, for both directions of the
# source range and for several element types
for rev in (true, false), T in (Int, Int16, Float64, Float16, BigInt, ComplexF64, Rational{Int})
    span = rev ? reverse(4:7) : (4:7)
    items = T === Int ? span : T.(span)
    drawn = Int.(sample(items, wv, 3; replace=false, ordered=true))
    check_wsample_norep(drawn, (4, 7), wv, -1; ordered=true, rev=rev)
end

0 commit comments

Comments
 (0)