Skip to content

Commit 1815e1e

Browse files
authored
Check that inputs and outputs do not share memory in sample! (#793)
Otherwise the result is incorrect. Unfortunately, there is no completely reliable way to check whether two arrays alias, but this is the best effort we can make. For custom array types like `AbstractWeights`, `Base.mightalias` falls back to using `Base.dataids`, which is less smart than the `SubArray` method, generating false positives when passing two views with the same parent but disjoint indices. In practice this is unlikely to matter, as using the same argument for values and weights is weird.
1 parent 43880cf commit 1815e1e

File tree

4 files changed

+29
-1
lines changed

4 files changed

+29
-1
lines changed

src/sampling.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -445,9 +445,14 @@ items appear in the same order as in `a`) should be taken.
445445
446446
Optionally specify a random number generator `rng` as the first argument
447447
(defaults to `Random.GLOBAL_RNG`).
448+
449+
Output array `a` must not be the same object as `x` or `wv`
450+
nor share memory with them, or the result may be incorrect.
448451
"""
449452
function sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray;
450453
replace::Bool=true, ordered::Bool=false)
454+
Base.mightalias(a, x) &&
455+
throw(ArgumentError("output array a must not share memory with input array x"))
451456
1 == firstindex(a) == firstindex(x) ||
452457
throw(ArgumentError("non 1-based arrays are not supported"))
453458
n = length(a)
@@ -893,6 +898,10 @@ efraimidis_aexpj_wsample_norep!(a::AbstractArray, wv::AbstractWeights, x::Abstra
893898

894899
function sample!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights, x::AbstractArray;
895900
replace::Bool=true, ordered::Bool=false)
901+
Base.mightalias(a, x) &&
902+
throw(ArgumentError("output array a must not share memory with input array x"))
903+
Base.mightalias(a, wv) &&
904+
throw(ArgumentError("output array a must not share memory with weights array wv"))
896905
1 == firstindex(a) == firstindex(wv) == firstindex(x) ||
897906
throw(ArgumentError("non 1-based arrays are not supported"))
898907
n = length(a)

src/weights.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ sum(wv::AbstractWeights) = wv.sum
2222
isempty(wv::AbstractWeights) = isempty(wv.values)
2323
size(wv::AbstractWeights) = size(wv.values)
2424

25+
Base.dataids(wv::AbstractWeights) = Base.dataids(wv.values)
26+
2527
Base.convert(::Type{Vector}, wv::AbstractWeights) = convert(Vector, wv.values)
2628

2729
@propagate_inbounds function Base.getindex(wv::AbstractWeights, i::Integer)
@@ -347,7 +349,7 @@ uweights(::Type{T}, s::Int) where {T<:Real} = UnitWeights{T}(s)
347349
"""
348350
varcorrection(w::UnitWeights, corrected=false)
349351
350-
* `corrected=true`: ``\\frac{1}{n - 1}``, where ``n`` is the length of the weight vector
352+
* `corrected=true`: ``\\frac{n}{n - 1}``, where ``n`` is the length of the weight vector
351353
* `corrected=false`: ``\\frac{1}{n}``, where ``n`` is the length of the weight vector
352354
353355
This definition is equivalent to the correction applied to unweighted data.

test/sampling.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,3 +245,18 @@ test_same(replace=true, ordered=true)
245245
test_same(replace=false, ordered=true)
246246
test_same(replace=true, ordered=false)
247247
test_same(replace=false, ordered=false)
248+
249+
# Test that sample! throws when output shares memory with inputs
250+
x = rand(10)
251+
y = rand(10)
252+
@test_throws ArgumentError sample!(x, x)
253+
@test_throws ArgumentError sample!(x, weights(y), x)
254+
@test_throws ArgumentError sample!(x, weights(x), y)
255+
@test_throws ArgumentError sample!(x, weights(x), x)
256+
@test_throws ArgumentError sample!(view(x, 2:4), view(x, 3:5))
257+
@test_throws ArgumentError sample!(view(x, 2:4), weights(view(x, 3:5)), y)
258+
@test_throws ArgumentError sample!(view(x, 2:4), weights(view(x, 3:5)), view(x, 1:2))
259+
# These corner cases should theoretically succeed
260+
# but the second currently fails as Base.mightalias is not smart enough
261+
sample!(view(x, 2:4), view(x, 5:6))
262+
@test_broken sample!(view(x, 2:4), weights(view(x, 5:6)), y)

test/weights.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ weight_funcs = (weights, aweights, fweights, pweights)
2121
@test wv == w
2222
@test sum(wv) === 6.0
2323
@test !isempty(wv)
24+
@test Base.mightalias(w, wv)
25+
@test !Base.mightalias([1], wv)
2426

2527
b = trues(3)
2628
bv = f(b)

0 commit comments

Comments
 (0)