Skip to content

Commit 6d09642

Browse files
authored
Merge pull request #20 from JuliaAI/dev
For a 0.1.3 release
2 parents be4bb23 + c240303 commit 6d09642

File tree

8 files changed

+51
-7
lines changed

8 files changed

+51
-7
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "CategoricalDistributions"
22
uuid = "af321ab8-2d2e-40a6-b165-3d674595d28e"
33
authors = ["Anthony D. Blaom <[email protected]>"]
4-
version = "0.1.2"
4+
version = "0.1.3"
55

66
[deps]
77
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"

src/arrays.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,12 @@ function Base.Broadcast.broadcasted(::typeof(mode),
243243
mode_flat = map(1:length(u)) do i
244244
max_prob = maximum(dic[ref][i] for ref in keys(dic))
245245
m = zero(R)
246+
247+
# `maximum` of any iterable containing `NaN` would return `NaN`
248+
# For this case the index `m` won't be updated in the loop as relations
249+
# involving NaN as one of it's argument always returns false
250+
# (e.g `==(NaN, NaN)` returns false)
251+
throw_nan_error_if_needed(max_prob)
246252
for ref in keys(dic)
247253
if dic[ref][i] == max_prob
248254
m = ref

src/methods.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,12 @@ function Dist.mode(d::UnivariateFinite)
196196
p = values(dic)
197197
max_prob = maximum(p)
198198
m = first(first(dic)) # mode, just some ref for now
199+
200+
# `maximum` of any iterable containing `NaN` would return `NaN`
201+
# For this case the index `m` won't be updated in the loop below as relations
202+
# involving NaN as one of it's arguments always returns false
203+
# (e.g `==(NaN, NaN)` returns false)
204+
throw_nan_error_if_needed(max_prob)
199205
for (x, prob) in dic
200206
if prob == max_prob
201207
m = x
@@ -205,6 +211,18 @@ function Dist.mode(d::UnivariateFinite)
205211
return d.decoder(m)
206212
end
207213

214+
function throw_nan_error_if_needed(x)
215+
if isnan(x)
216+
throw(
217+
DomainError(
218+
NaN,
219+
"`mode` is invalid for `UnivariateFininite` distribution "*
220+
"with `pdf` containing `NaN`s"
221+
)
222+
)
223+
end
224+
end
225+
208226
# mode(v::Vector{UnivariateFinite}) = mode.(v)
209227
# mode(u::UnivariateFiniteVector{2}) =
210228
# [u.support[ifelse(s > 0.5, 2, 1)] for s in u.scores]

src/utilities.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ always true.
2121
2222
Not to be confused with `levels(x.pool)`. See the example below.
2323
24-
Also, overloaded for `x` a `CategoricalArray` or `CategoricalPool`.
24+
Also, overloaded for `x` a `CategoricalArray`, `CategoricalPool`, and
25+
for views of `CategoricalArray`.
2526
2627
**Private method.*
2728
@@ -57,7 +58,7 @@ Also, overloaded for `x` a `CategoricalArray` or `CategoricalPool`.
5758
classes(p::CategoricalPool) = [p[i] for i in 1:length(p)]
5859
classes(x::CategoricalValue) = classes(CategoricalArrays.pool(x))
5960
classes(v::CategoricalArray) = classes(CategoricalArrays.pool(v))
60-
61+
classes(v::SubArray{<:Any, <:Any, <:CategoricalArray}) = classes(parent(v))
6162

6263
# # CATEGORICAL VALUES TO INTEGERS
6364

test/arrays.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,17 @@ end
198198
u = UnivariateFinite(P, pool=missing)
199199
expected = mode.([u...])
200200
@test all(mode.(u) .== expected)
201+
202+
# `mode` broadcasting of `Univariate` objects containing `NaN` in probs.
203+
unf_arr = UnivariateFinite(
204+
[
205+
0.1 0.2 NaN 0.1 NaN;
206+
0.2 0.1 0.1 0.4 0.2;
207+
0.3 NaN 0.2 NaN 0.3
208+
],
209+
pool=missing
210+
)
211+
@test_throws DomainError mode.(unf_arr)
201212
end
202213

203214
@testset "cat for UnivariateFiniteArray" begin

test/methods.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,10 @@ end
175175
p[42] = 0.5
176176
d = UnivariateFinite(v, p)
177177
@test mode(d) == 42
178+
179+
# `mode` of `Univariate` objects containing `NaN` in probs.
180+
unf = UnivariateFinite([0.1, 0.2, NaN, 0.1, NaN], pool=missing)
181+
@test_throws DomainError mode(unf)
178182
end
179183

180184
@testset "UnivariateFinite methods" begin

test/types.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ end
3636
UnivariateFinite(supp, probs, pool=missing, augment=true);
3737

3838
# dimension mismatches:
39-
badprobs = rand(40, 3)
39+
badprobs = rand(rng, 40, 3)
4040
@test_throws(CategoricalDistributions.err_dim(supp, badprobs),
4141
UnivariateFinite(supp, badprobs, pool=missing))
4242

@@ -60,7 +60,7 @@ end
6060
probs = probs ./ sum(probs)
6161
u = UnivariateFinite(probs, pool=missing);
6262
@test u isa UnivariateFinite
63-
probs = rand(10, 2)
63+
probs = rand(rng, 10, 2)
6464
probs = probs ./ sum(probs, dims=2)
6565
u = UnivariateFinite(probs, pool=missing);
6666
@test u.scitype == Multiclass{2}
@@ -73,9 +73,11 @@ end
7373

7474
v = categorical(1:3)
7575
@test_logs((:warn, r"Ignoring"),
76-
UnivariateFinite(v[1:2], rand(3), augment=true, pool=missing));
76+
UnivariateFinite(v[1:2], rand(rng, 3),
77+
augment=true, pool=missing));
7778
@test_logs((:warn, r"Ignoring"),
78-
UnivariateFinite(v[1:2], rand(3), augment=true, ordered=true));
79+
UnivariateFinite(v[1:2], rand(rng, 3),
80+
augment=true, ordered=true));
7981

8082
# using `UnivariateFiniteArray` as a constructor just falls back
8183
# to `UnivariateFinite` constructor:

test/utilities.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ import CategoricalDistributions: classes, transform, decoder, int
1717
levels!(v, reverse(levels(v)))
1818
@test classes(v[1]) == levels(v)
1919
@test classes(v) == levels(v)
20+
vsub = view(v, 1:2)
21+
@test classes(vsub) == classes(v)
2022
end
2123

2224
@testset "int, classes, decoder" begin

0 commit comments

Comments
 (0)