Skip to content

Commit b9d86c3

Browse files
committed
address #16
1 parent 149afd3 commit b9d86c3

File tree

4 files changed

+39
-0
lines changed

4 files changed

+39
-0
lines changed

src/arrays.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,12 @@ function Base.Broadcast.broadcasted(::typeof(mode),
243243
mode_flat = map(1:length(u)) do i
244244
max_prob = maximum(dic[ref][i] for ref in keys(dic))
245245
m = zero(R)
246+
247+
# `maximum` of any iterable containing `NaN` would return `NaN`
248+
# For this case the index `m` won't be updated in the loop as relations
249+
# involving NaN as one of it's argument always returns false
250+
# (e.g `==(NaN, NaN)` returns false)
251+
throw_nan_error_if_needed(max_prob)
246252
for ref in keys(dic)
247253
if dic[ref][i] == max_prob
248254
m = ref

src/methods.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,12 @@ function Dist.mode(d::UnivariateFinite)
196196
p = values(dic)
197197
max_prob = maximum(p)
198198
m = first(first(dic)) # mode, just some ref for now
199+
200+
# `maximum` of any iterable containing `NaN` would return `NaN`
201+
# For this case the index `m` won't be updated in the loop below as relations
202+
# involving NaN as one of it's arguments always returns false
203+
# (e.g `==(NaN, NaN)` returns false)
204+
throw_nan_error_if_needed(max_prob)
199205
for (x, prob) in dic
200206
if prob == max_prob
201207
m = x
@@ -205,6 +211,18 @@ function Dist.mode(d::UnivariateFinite)
205211
return d.decoder(m)
206212
end
207213

214+
function throw_nan_error_if_needed(x)
215+
if isnan(x)
216+
throw(
217+
DomainError(
218+
NaN,
219+
"`mode` is invalid for `UnivariateFininite` distribution "*
220+
"with `pdf` containing `NaN`s"
221+
)
222+
)
223+
end
224+
end
225+
208226
# mode(v::Vector{UnivariateFinite}) = mode.(v)
209227
# mode(u::UnivariateFiniteVector{2}) =
210228
# [u.support[ifelse(s > 0.5, 2, 1)] for s in u.scores]

test/arrays.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,17 @@ end
198198
u = UnivariateFinite(P, pool=missing)
199199
expected = mode.([u...])
200200
@test all(mode.(u) .== expected)
201+
202+
# `mode` broadcasting of `Univariate` objects containing `NaN` in probs.
203+
unf_arr = UnivariateFinite(
204+
[
205+
0.1 0.2 NaN 0.1 NaN;
206+
0.2 0.1 0.1 0.4 0.2;
207+
0.3 NaN 0.2 NaN 0.3
208+
],
209+
pool=missing
210+
)
211+
@test_throws DomainError mode.(unf_arr)
201212
end
202213

203214
@testset "cat for UnivariateFiniteArray" begin

test/methods.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,10 @@ end
175175
p[42] = 0.5
176176
d = UnivariateFinite(v, p)
177177
@test mode(d) == 42
178+
179+
# `mode` of `Univariate` objects containing `NaN` in probs.
180+
unf = UnivariateFinite([0.1, 0.2, NaN, 0.1, NaN], pool=missing)
181+
@test_throws DomainError mode(unf)
178182
end
179183

180184
@testset "UnivariateFinite methods" begin

0 commit comments

Comments
 (0)