Skip to content

Commit 390607a

Browse files
authored
Disallow NaN and Inf values in AbstractWeights (#814)
1 parent 2dc23d7 commit 390607a

File tree

2 files changed

+15
-7
lines changed

2 files changed

+15
-7
lines changed

src/weights.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,12 @@ macro weights(name)
1212
mutable struct $name{S<:Real, T<:Real, V<:AbstractVector{T}} <: AbstractWeights{S, T, V}
1313
values::V
1414
sum::S
15+
function $(esc(name)){S, T, V}(values, sum) where {S<:Real, T<:Real, V<:AbstractVector{T}}
16+
isfinite(sum) || throw(ArgumentError("weights cannot contain Inf or NaN values"))
17+
return new{S, T, V}(values, sum)
18+
end
1519
end
20+
$(esc(name))(values::AbstractVector{T}, sum::S) where {S<:Real, T<:Real} = $(esc(name)){S, T, typeof(values)}(values, sum)
1621
$(esc(name))(values::AbstractVector{<:Real}) = $(esc(name))(values, sum(values))
1722
end
1823
end
@@ -44,8 +49,10 @@ Base.getindex(wv::W, ::Colon) where {W <: AbstractWeights} = W(copy(wv.values),
4449

4550
@propagate_inbounds function Base.setindex!(wv::AbstractWeights, v::Real, i::Int)
4651
s = v - wv[i]
52+
sum = wv.sum + s
53+
isfinite(sum) || throw(ArgumentError("weights cannot contain Inf or NaN values"))
4754
wv.values[i] = v
48-
wv.sum += s
55+
wv.sum = sum
4956
v
5057
end
5158

@@ -707,7 +714,6 @@ function quantile(v::RealVector{V}, w::AbstractWeights{W}, p::RealVector) where
707714
length(v) == length(w) || throw(ArgumentError("data and weight vectors must be the same size," *
708715
"got $(length(v)) and $(length(w))"))
709716
for x in w.values
710-
isnan(x) && throw(ArgumentError("weight vector cannot contain NaN entries"))
711717
x < 0 && throw(ArgumentError("weight vector cannot contain negative entries"))
712718
end
713719

test/weights.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ weight_funcs = (weights, aweights, fweights, pweights)
3939

4040
@test sum(ba, wv) === 4.0
4141
@test sum(sa, wv) === 7.0
42+
43+
@test_throws ArgumentError f([0.1, Inf])
44+
@test_throws ArgumentError f([0.1, NaN])
45+
4246
end
4347

4448
@testset "$f, setindex!" for f in weight_funcs
@@ -63,6 +67,9 @@ end
6367
@test sum(wv) === 11.
6468
@test wv == [3., 5., 3.] # Test state of all values
6569

70+
@test_throws ArgumentError wv[1] = Inf
71+
@test_throws ArgumentError wv[1] = NaN
72+
6673
# Test failed setindex! due to conversion error
6774
w = [1, 2, 3]
6875
wv = f(w)
@@ -90,11 +97,6 @@ end
9097
@test x != y
9198
end
9299

93-
x = f([1, 2, NaN]) # isequal and == treat NaN differently
94-
y = f([1, 2, NaN])
95-
@test isequal(x, y)
96-
@test x != y
97-
98100
x = f([1.0, 2.0, 0.0]) # isequal and == treat ±0.0 differently
99101
y = f([1.0, 2.0, -0.0])
100102
@test !isequal(x, y)

0 commit comments

Comments
 (0)