Skip to content

Commit f7129c9

Browse files
authored
Throw an error for OffsetArrays and aliased arrays when sampling (#833)
Add checks to all sampling methods, as they are mentioned in the manual even if not exported. Remove redundant check for aliasing from main method as it is potentially costly. Also fix existing check, as it confused the input and output array (the order of argument is not standard).
1 parent 390607a commit f7129c9

File tree

3 files changed

+124
-22
lines changed

3 files changed

+124
-22
lines changed

src/sampling.jl

Lines changed: 72 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ using Random: Sampler, Random.GLOBAL_RNG
1010
### Algorithms for sampling with replacement
1111

1212
function direct_sample!(rng::AbstractRNG, a::UnitRange, x::AbstractArray)
13+
1 == firstindex(a) == firstindex(x) ||
14+
throw(ArgumentError("non 1-based arrays are not supported"))
1315
s = Sampler(rng, 1:length(a))
1416
b = a[1] - 1
1517
if b == 0
@@ -34,6 +36,10 @@ and set `x[j] = a[i]`, with `n=length(a)` and `k=length(x)`.
3436
This algorithm consumes `k` random numbers.
3537
"""
3638
function direct_sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray)
39+
1 == firstindex(a) == firstindex(x) ||
40+
throw(ArgumentError("non 1-based arrays are not supported"))
41+
Base.mightalias(a, x) &&
42+
throw(ArgumentError("output array x must not share memory with input array a"))
3743
s = Sampler(rng, 1:length(a))
3844
for i = 1:length(x)
3945
@inbounds x[i] = a[rand(rng, s)]
@@ -55,6 +61,10 @@ storeindices(n, k, T) = false
5561

5662
# order results of a sampler that does not order automatically
5763
function sample_ordered!(sampler!, rng::AbstractRNG, a::AbstractArray, x::AbstractArray)
64+
1 == firstindex(a) == firstindex(x) ||
65+
throw(ArgumentError("non 1-based arrays are not supported"))
66+
Base.mightalias(a, x) &&
67+
throw(ArgumentError("output array x must not share memory with input array a"))
5868
n, k = length(a), length(x)
5969
# todo: if eltype(x) <: Real && eltype(a) <: Real,
6070
# in some cases it might be faster to check
@@ -130,6 +140,10 @@ memory space. Suitable for the case where memory is tight.
130140
"""
131141
function knuths_sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray;
132142
initshuffle::Bool=true)
143+
1 == firstindex(a) == firstindex(x) ||
144+
throw(ArgumentError("non 1-based arrays are not supported"))
145+
Base.mightalias(a, x) &&
146+
throw(ArgumentError("output array x must not share memory with input array a"))
133147
n = length(a)
134148
k = length(x)
135149
k <= n || error("length(x) should not exceed length(a)")
@@ -186,6 +200,10 @@ faster than Knuth's algorithm especially when `n` is greater than `k`.
186200
It is ``O(n)`` for initialization, plus ``O(k)`` for random shuffling
187201
"""
188202
function fisher_yates_sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray)
203+
1 == firstindex(a) == firstindex(x) ||
204+
throw(ArgumentError("non 1-based arrays are not supported"))
205+
Base.mightalias(a, x) &&
206+
throw(ArgumentError("output array x must not share memory with input array a"))
189207
n = length(a)
190208
k = length(x)
191209
k <= n || error("length(x) should not exceed length(a)")
@@ -222,6 +240,10 @@ However, if `k` is large and approaches ``n``, the rejection rate would increase
222240
drastically, resulting in poorer performance.
223241
"""
224242
function self_avoid_sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray)
243+
1 == firstindex(a) == firstindex(x) ||
244+
throw(ArgumentError("non 1-based arrays are not supported"))
245+
Base.mightalias(a, x) &&
246+
throw(ArgumentError("output array x must not share memory with input array a"))
225247
n = length(a)
226248
k = length(x)
227249
k <= n || error("length(x) should not exceed length(a)")
@@ -260,6 +282,10 @@ This algorithm consumes ``O(n)`` random numbers, with `n=length(a)`.
260282
The outputs are ordered.
261283
"""
262284
function seqsample_a!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray)
285+
1 == firstindex(a) == firstindex(x) ||
286+
throw(ArgumentError("non 1-based arrays are not supported"))
287+
Base.mightalias(a, x) &&
288+
throw(ArgumentError("output array x must not share memory with input array a"))
263289
n = length(a)
264290
k = length(x)
265291
k <= n || error("length(x) should not exceed length(a)")
@@ -298,6 +324,10 @@ This algorithm consumes ``O(k^2)`` random numbers, with `k=length(x)`.
298324
The outputs are ordered.
299325
"""
300326
function seqsample_c!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray)
327+
1 == firstindex(a) == firstindex(x) ||
328+
throw(ArgumentError("non 1-based arrays are not supported"))
329+
Base.mightalias(a, x) &&
330+
throw(ArgumentError("output array x must not share memory with input array a"))
301331
n = length(a)
302332
k = length(x)
303333
k <= n || error("length(x) should not exceed length(a)")
@@ -340,6 +370,10 @@ This algorithm consumes ``O(k)`` random numbers, with `k=length(x)`.
340370
The outputs are ordered.
341371
"""
342372
function seqsample_d!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray)
373+
1 == firstindex(a) == firstindex(x) ||
374+
throw(ArgumentError("non 1-based arrays are not supported"))
375+
Base.mightalias(a, x) &&
376+
throw(ArgumentError("output array x must not share memory with input array a"))
343377
N = length(a)
344378
n = length(x)
345379
n <= N || error("length(x) should not exceed length(a)")
@@ -451,8 +485,6 @@ nor share memory with them, or the result may be incorrect.
451485
"""
452486
function sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray;
453487
replace::Bool=true, ordered::Bool=false)
454-
Base.mightalias(a, x) &&
455-
throw(ArgumentError("output array a must not share memory with input array x"))
456488
1 == firstindex(a) == firstindex(x) ||
457489
throw(ArgumentError("non 1-based arrays are not supported"))
458490
n = length(a)
@@ -550,6 +582,8 @@ Optionally specify a random number generator `rng` as the first argument
550582
(defaults to `Random.GLOBAL_RNG`).
551583
"""
552584
function sample(rng::AbstractRNG, wv::AbstractWeights)
585+
1 == firstindex(wv) ||
586+
throw(ArgumentError("non 1-based arrays are not supported"))
553587
t = rand(rng) * sum(wv)
554588
n = length(wv)
555589
i = 1
@@ -579,6 +613,12 @@ Noting `k=length(x)` and `n=length(a)`, this algorithm:
579613
"""
580614
function direct_sample!(rng::AbstractRNG, a::AbstractArray,
581615
wv::AbstractWeights, x::AbstractArray)
616+
Base.mightalias(a, x) &&
617+
throw(ArgumentError("output array x must not share memory with input array a"))
618+
Base.mightalias(x, wv) &&
619+
throw(ArgumentError("output array x must not share memory with weights array wv"))
620+
1 == firstindex(a) == firstindex(wv) == firstindex(x) ||
621+
throw(ArgumentError("non 1-based arrays are not supported"))
582622
n = length(a)
583623
length(wv) == n || throw(DimensionMismatch("Inconsistent lengths."))
584624
for i = 1:length(x)
@@ -662,6 +702,12 @@ Noting `k=length(x)` and `n=length(a)`, this algorithm takes ``O(n \\log n)`` ti
662702
for building the alias table, and then ``O(1)`` to draw each sample. It consumes ``2 k`` random numbers.
663703
"""
664704
function alias_sample!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights, x::AbstractArray)
705+
Base.mightalias(a, x) &&
706+
throw(ArgumentError("output array x must not share memory with input array a"))
707+
Base.mightalias(x, wv) &&
708+
throw(ArgumentError("output array x must not share memory with weights array wv"))
709+
1 == firstindex(a) == firstindex(wv) == firstindex(x) ||
710+
throw(ArgumentError("non 1-based arrays are not supported"))
665711
n = length(a)
666712
length(wv) == n || throw(DimensionMismatch("Inconsistent lengths."))
667713

@@ -694,6 +740,12 @@ and has overall time complexity ``O(n k)``.
694740
"""
695741
function naive_wsample_norep!(rng::AbstractRNG, a::AbstractArray,
696742
wv::AbstractWeights, x::AbstractArray)
743+
Base.mightalias(a, x) &&
744+
throw(ArgumentError("output array x must not share memory with input array a"))
745+
Base.mightalias(x, wv) &&
746+
throw(ArgumentError("output array x must not share memory with weights array wv"))
747+
1 == firstindex(a) == firstindex(wv) == firstindex(x) ||
748+
throw(ArgumentError("non 1-based arrays are not supported"))
697749
n = length(a)
698750
length(wv) == n || throw(DimensionMismatch("Inconsistent lengths."))
699751
k = length(x)
@@ -734,6 +786,12 @@ processing time to draw ``k`` elements. It consumes ``n`` random numbers.
734786
"""
735787
function efraimidis_a_wsample_norep!(rng::AbstractRNG, a::AbstractArray,
736788
wv::AbstractWeights, x::AbstractArray)
789+
Base.mightalias(a, x) &&
790+
throw(ArgumentError("output array x must not share memory with input array a"))
791+
Base.mightalias(x, wv) &&
792+
throw(ArgumentError("output array x must not share memory with weights array wv"))
793+
1 == firstindex(a) == firstindex(wv) == firstindex(x) ||
794+
throw(ArgumentError("non 1-based arrays are not supported"))
737795
n = length(a)
738796
length(wv) == n || throw(DimensionMismatch("a and wv must be of same length (got $n and $(length(wv)))."))
739797
k = length(x)
@@ -769,6 +827,12 @@ processing time to draw ``k`` elements. It consumes ``n`` random numbers.
769827
"""
770828
function efraimidis_ares_wsample_norep!(rng::AbstractRNG, a::AbstractArray,
771829
wv::AbstractWeights, x::AbstractArray)
830+
Base.mightalias(a, x) &&
831+
throw(ArgumentError("output array x must not share memory with input array a"))
832+
Base.mightalias(x, wv) &&
833+
throw(ArgumentError("output array x must not share memory with weights array wv"))
834+
1 == firstindex(a) == firstindex(wv) == firstindex(x) ||
835+
throw(ArgumentError("non 1-based arrays are not supported"))
772836
n = length(a)
773837
length(wv) == n || throw(DimensionMismatch("a and wv must be of same length (got $n and $(length(wv)))."))
774838
k = length(x)
@@ -836,6 +900,12 @@ processing time to draw ``k`` elements. It consumes ``O(k \\log(n / k))`` random
836900
function efraimidis_aexpj_wsample_norep!(rng::AbstractRNG, a::AbstractArray,
837901
wv::AbstractWeights, x::AbstractArray;
838902
ordered::Bool=false)
903+
Base.mightalias(a, x) &&
904+
throw(ArgumentError("output array x must not share memory with input array a"))
905+
Base.mightalias(x, wv) &&
906+
throw(ArgumentError("output array x must not share memory with weights array wv"))
907+
1 == firstindex(a) == firstindex(wv) == firstindex(x) ||
908+
throw(ArgumentError("non 1-based arrays are not supported"))
839909
n = length(a)
840910
length(wv) == n || throw(DimensionMismatch("a and wv must be of same length (got $n and $(length(wv)))."))
841911
k = length(x)
@@ -898,10 +968,6 @@ efraimidis_aexpj_wsample_norep!(a::AbstractArray, wv::AbstractWeights, x::Abstra
898968

899969
function sample!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights, x::AbstractArray;
900970
replace::Bool=true, ordered::Bool=false)
901-
Base.mightalias(a, x) &&
902-
throw(ArgumentError("output array a must not share memory with input array x"))
903-
Base.mightalias(a, wv) &&
904-
throw(ArgumentError("output array a must not share memory with weights array wv"))
905971
1 == firstindex(a) == firstindex(wv) == firstindex(x) ||
906972
throw(ArgumentError("non 1-based arrays are not supported"))
907973
n = length(a)

test/sampling.jl

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using StatsBase
2-
using Test, Random, StableRNGs
2+
using Test, Random, StableRNGs, OffsetArrays
33

44
Random.seed!(1234)
55

@@ -246,17 +246,23 @@ test_same(replace=false, ordered=true)
246246
test_same(replace=true, ordered=false)
247247
test_same(replace=false, ordered=false)
248248

249-
# Test that sample! throws when output shares memory with inputs
250-
x = rand(10)
251-
y = rand(10)
252-
@test_throws ArgumentError sample!(x, x)
253-
@test_throws ArgumentError sample!(x, weights(y), x)
254-
@test_throws ArgumentError sample!(x, weights(x), y)
255-
@test_throws ArgumentError sample!(x, weights(x), x)
256-
@test_throws ArgumentError sample!(view(x, 2:4), view(x, 3:5))
257-
@test_throws ArgumentError sample!(view(x, 2:4), weights(view(x, 3:5)), y)
258-
@test_throws ArgumentError sample!(view(x, 2:4), weights(view(x, 3:5)), view(x, 1:2))
259-
# These corner cases should theoretically succeed
260-
# but the second currently fails as Base.mightalias is not smart enough
261-
sample!(view(x, 2:4), view(x, 5:6))
262-
@test_broken sample!(view(x, 2:4), weights(view(x, 5:6)), y)
249+
@testset "validation of inputs" begin
250+
for f in (sample!, knuths_sample!, fisher_yates_sample!, self_avoid_sample!,
251+
seqsample_a!, seqsample_c!, seqsample_d!)
252+
x = rand(10)
253+
y = rand(10)
254+
ox = OffsetArray(x, -4:5)
255+
oy = OffsetArray(y, -4:5)
256+
257+
# Test that offset arrays throw an error
258+
@test_throws ArgumentError f(ox, y)
259+
@test_throws ArgumentError f(x, oy)
260+
@test_throws ArgumentError f(ox, oy)
261+
262+
# Test that an error is thrown when output shares memory with inputs
263+
@test_throws ArgumentError f(x, x)
264+
@test_throws ArgumentError f(view(x, 2:4), view(x, 3:5))
265+
# This corner case should succeed
266+
f(view(x, 2:4), view(x, 5:6))
267+
end
268+
end

test/wsampling.jl

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using StatsBase
2-
using Random, Test
2+
using Random, Test, OffsetArrays
33

44
Random.seed!(1234)
55

@@ -133,3 +133,33 @@ for rev in (true, false), T in (Int, Int16, Float64, Float16, BigInt, ComplexF64
133133
aa = Int.(sample(r, wv, 3; replace=false, ordered=true))
134134
check_wsample_norep(aa, (4, 7), wv, -1; ordered=true, rev=rev)
135135
end
136+
137+
@testset "validation of inputs" begin
138+
x = rand(10)
139+
y = rand(10)
140+
z = rand(10)
141+
ox = OffsetArray(x, -4:5)
142+
oy = OffsetArray(y, -4:5)
143+
oz = OffsetArray(z, -4:5)
144+
145+
@test_throws ArgumentError sample(weights(ox))
146+
147+
for f in (sample!, wsample!, naive_wsample_norep!, efraimidis_a_wsample_norep!,
148+
efraimidis_ares_wsample_norep!, efraimidis_aexpj_wsample_norep!)
149+
# Test that offset arrays throw an error
150+
@test_throws ArgumentError f(ox, weights(y), z)
151+
@test_throws ArgumentError f(x, weights(oy), z)
152+
@test_throws ArgumentError f(x, weights(y), oz)
153+
@test_throws ArgumentError f(ox, weights(oy), oz)
154+
155+
# Test that an error is thrown when output shares memory with inputs
156+
@test_throws ArgumentError f(x, weights(y), x)
157+
@test_throws ArgumentError f(y, weights(x), x)
158+
@test_throws ArgumentError f(x, weights(x), x)
159+
@test_throws ArgumentError f(y, weights(view(x, 3:5)), view(x, 2:4))
160+
@test_throws ArgumentError f(view(x, 2:4), weights(view(x, 3:5)), view(x, 1:2))
161+
# This corner case should theoretically succeed
162+
# but it currently fails as Base.mightalias is not smart enough
163+
@test_broken f(y, weights(view(x, 5:6)), view(x, 2:4))
164+
end
165+
end

0 commit comments

Comments
 (0)