Skip to content

Commit 87f372c

Browse files
authored
functions that only support finite weights now throw errors for non-finites (#914)
* throw errors when only finite weights are supported
* remove extra calls to sum(wv)
* typo
* typo
* add a minimal test for custom weights implementations
* fix new test on 1.0
1 parent c022f82 commit 87f372c

File tree

4 files changed

+28
-2
lines changed

4 files changed

+28
-2
lines changed

src/sampling.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -591,7 +591,9 @@ Optionally specify a random number generator `rng` as the first argument
591591
function sample(rng::AbstractRNG, wv::AbstractWeights)
592592
1 == firstindex(wv) ||
593593
throw(ArgumentError("non 1-based arrays are not supported"))
594-
t = rand(rng) * sum(wv)
594+
wsum = sum(wv)
595+
isfinite(wsum) || throw(ArgumentError("only finite weights are supported"))
596+
t = rand(rng) * wsum
595597
n = length(wv)
596598
i = 1
597599
cw = wv[1]
@@ -654,6 +656,7 @@ function alias_sample!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights,
654656
throw(ArgumentError("output array x must not share memory with input array a"))
655657
1 == firstindex(a) == firstindex(wv) ||
656658
throw(ArgumentError("non 1-based arrays are not supported"))
659+
isfinite(sum(wv)) || throw(ArgumentError("only finite weights are supported"))
657660
length(wv) == length(a) || throw(DimensionMismatch("Inconsistent lengths."))
658661

659662
# create alias table
@@ -688,13 +691,14 @@ function naive_wsample_norep!(rng::AbstractRNG, a::AbstractArray,
688691
throw(ArgumentError("output array x must not share memory with weights array wv"))
689692
1 == firstindex(a) == firstindex(wv) == firstindex(x) ||
690693
throw(ArgumentError("non 1-based arrays are not supported"))
694+
wsum = sum(wv)
695+
isfinite(wsum) || throw(ArgumentError("only finite weights are supported"))
691696
n = length(a)
692697
length(wv) == n || throw(DimensionMismatch("Inconsistent lengths."))
693698
k = length(x)
694699

695700
w = Vector{Float64}(undef, n)
696701
copyto!(w, wv)
697-
wsum = sum(wv)
698702

699703
for i = 1:k
700704
u = rand(rng) * wsum
@@ -734,6 +738,7 @@ function efraimidis_a_wsample_norep!(rng::AbstractRNG, a::AbstractArray,
734738
throw(ArgumentError("output array x must not share memory with weights array wv"))
735739
1 == firstindex(a) == firstindex(wv) == firstindex(x) ||
736740
throw(ArgumentError("non 1-based arrays are not supported"))
741+
isfinite(sum(wv)) || throw(ArgumentError("only finite weights are supported"))
737742
n = length(a)
738743
length(wv) == n || throw(DimensionMismatch("a and wv must be of same length (got $n and $(length(wv)))."))
739744
k = length(x)
@@ -775,6 +780,7 @@ function efraimidis_ares_wsample_norep!(rng::AbstractRNG, a::AbstractArray,
775780
throw(ArgumentError("output array x must not share memory with weights array wv"))
776781
1 == firstindex(a) == firstindex(wv) == firstindex(x) ||
777782
throw(ArgumentError("non 1-based arrays are not supported"))
783+
isfinite(sum(wv)) || throw(ArgumentError("only finite weights are supported"))
778784
n = length(a)
779785
length(wv) == n || throw(DimensionMismatch("a and wv must be of same length (got $n and $(length(wv)))."))
780786
k = length(x)
@@ -848,6 +854,7 @@ function efraimidis_aexpj_wsample_norep!(rng::AbstractRNG, a::AbstractArray,
848854
throw(ArgumentError("output array x must not share memory with weights array wv"))
849855
1 == firstindex(a) == firstindex(wv) == firstindex(x) ||
850856
throw(ArgumentError("non 1-based arrays are not supported"))
857+
isfinite(sum(wv)) || throw(ArgumentError("only finite weights are supported"))
851858
n = length(a)
852859
length(wv) == n || throw(DimensionMismatch("a and wv must be of same length (got $n and $(length(wv)))."))
853860
k = length(x)

src/scalarstats.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@ end
163163
# Weighted mode of arbitrary vectors of values
164164
function mode(a::AbstractVector, wv::AbstractWeights{T}) where T <: Real
165165
isempty(a) && throw(ArgumentError("mode is not defined for empty collections"))
166+
isfinite(sum(wv)) || throw(ArgumentError("only finite weights are supported"))
166167
length(a) == length(wv) ||
167168
throw(ArgumentError("data and weight vectors must be the same size, got $(length(a)) and $(length(wv))"))
168169

@@ -184,6 +185,7 @@ end
184185

185186
function modes(a::AbstractVector, wv::AbstractWeights{T}) where T <: Real
186187
isempty(a) && throw(ArgumentError("mode is not defined for empty collections"))
188+
isfinite(sum(wv)) || throw(ArgumentError("only finite weights are supported"))
187189
length(a) == length(wv) ||
188190
throw(ArgumentError("data and weight vectors must be the same size, got $(length(a)) and $(length(wv))"))
189191

src/weights.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -716,6 +716,7 @@ function quantile(v::AbstractVector{<:Real}{V}, w::AbstractWeights{W}, p::Abstra
716716
# checks
717717
isempty(v) && throw(ArgumentError("quantile of an empty array is undefined"))
718718
isempty(p) && throw(ArgumentError("empty quantile array"))
719+
isfinite(sum(w)) || throw(ArgumentError("only finite weights are supported"))
719720
all(x -> 0 <= x <= 1, p) || throw(ArgumentError("input probability out of [0,1] range"))
720721

721722
w.sum == 0 && throw(ArgumentError("weight vector cannot sum to zero"))

test/weights.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,15 @@
11
using StatsBase
22
using LinearAlgebra, Random, SparseArrays, Test
33

4+
5+
# minimal custom weights type for tests below
6+
struct MyWeights <: AbstractWeights{Float64, Float64, Vector{Float64}}
7+
values::Vector{Float64}
8+
sum::Float64
9+
end
10+
MyWeights(values) = MyWeights(values, sum(values))
11+
12+
413
@testset "StatsBase.Weights" begin
514
weight_funcs = (weights, aweights, fweights, pweights)
615

@@ -610,4 +619,11 @@ end
610619
end
611620
end
612621

622+
@testset "custom weight types" begin
623+
@test mean([1, 2, 3], MyWeights([1, 4, 10])) ≈ 2.6
624+
@test mean([1, 2, 3], MyWeights([NaN, 4, 10])) |> isnan
625+
@test mode([1, 2, 3], MyWeights([1, 4, 10])) == 3
626+
@test_throws ArgumentError mode([1, 2, 3], MyWeights([NaN, 4, 10]))
627+
end
628+
613629
end # @testset StatsBase.Weights

0 commit comments

Comments
 (0)