Skip to content

Commit afcd606

Browse files
authored
Merge pull request #61 from JuliaAI/dev
For a 0.1.11 release
2 parents f691205 + a07635d commit afcd606

File tree

7 files changed

+117
-25
lines changed

7 files changed

+117
-25
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "CategoricalDistributions"
22
uuid = "af321ab8-2d2e-40a6-b165-3d674595d28e"
33
authors = ["Anthony D. Blaom <[email protected]>"]
4-
version = "0.1.10"
4+
version = "0.1.11"
55

66
[deps]
77
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"

README.md

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ Arrays of `UnivariateFinite` distributions are defined using the same
8282
constructor. Broadcasting methods, such as `pdf`, are optimized for
8383
such arrays:
8484

85-
```
85+
```julia
8686
julia> v = UnivariateFinite(["no", "yes"], [0.1, 0.2, 0.3, 0.4], augment=true, pool=data)
8787
4-element UnivariateFiniteArray{Multiclass{3}, String, UInt32, Float64, 1}:
8888
UnivariateFinite{Multiclass{3}}(no=>0.9, yes=>0.1)
@@ -119,7 +119,6 @@ julia> pdf(v, L)
119119
0.0 0.6 0.4
120120
```
121121

122-
123122
## Measures over finite labeled sets
124123

125124
There is, in fact, no enforcement that probabilities in a
@@ -128,7 +127,6 @@ to a type `T` for which `zero(T)` is defined. In particular
128127
`UnivariateFinite` objects implement arbitrary non-negative, signed,
129128
or complex measures over a finite labeled set.
130129

131-
132130
## What does this package provide?
133131

134132
- A new type `UnivariateFinite{S}` for representing probability
@@ -144,7 +142,7 @@ or complex measures over a finite labeled set.
144142
- Implementations of `rand` for generating random samples of a
145143
`UnivariateFinite` distribution.
146144

147-
- Implementations of the `pdf`, `logpdf` and `mode` methods of
145+
- Implementations of the `pdf`, `logpdf`, `mode` and `modes` methods of
148146
Distributions.jl, with efficient broadcasting over the new array
149147
type.
150148

src/CategoricalDistributions.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ using Random
1616

1717
const Dist = Distributions
1818

19-
import Distributions: pdf, logpdf, support, mode
19+
import Distributions: pdf, logpdf, support, mode, modes
2020

2121
include("utilities.jl")
2222
include("types.jl")
@@ -28,7 +28,7 @@ include("arithmetic.jl")
2828
export UnivariateFinite, UnivariateFiniteArray, UnivariateFiniteVector
2929

3030
# re-eport from Distributions:
31-
export pdf, logpdf, support, mode
31+
export pdf, logpdf, support, mode, modes
3232

3333
# re-export from ScientificTypesBase:
3434
export Multiclass, OrderedFactor

src/arrays.jl

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ Base.Broadcast.broadcasted(
271271
c::Missing) where {S,V,R,P,N} = Missings.missings(P, length(u))
272272

273273

274-
## PERFORMANT BROADCASTING OF mode:
274+
## PERFORMANT BROADCASTING OF mode(s):
275275

276276
function Base.Broadcast.broadcasted(::typeof(mode),
277277
u::UniFinArr{S,V,R,P,N}) where {S,V,R,P,N}
@@ -298,6 +298,26 @@ function Base.Broadcast.broadcasted(::typeof(mode),
298298
return reshape(mode_flat, size(u))
299299
end
300300

301+
function Base.Broadcast.broadcasted(::typeof(modes),
302+
u::UniFinArr{S,V,R,P,N}) where {S,V,R,P,N}
303+
dic = u.prob_given_ref
304+
305+
# using linear indexing:
306+
mode_flat = map(1:length(u)) do i
307+
max_prob = maximum(dic[ref][i] for ref in keys(dic))
308+
M = R[]
309+
310+
# see comment for in broadcasted(::mode) above
311+
throw_nan_error_if_needed(max_prob)
312+
for ref in keys(dic)
313+
if dic[ref][i] == max_prob
314+
push!(M, ref)
315+
end
316+
end
317+
return u.decoder(M)
318+
end
319+
return reshape(mode_flat, size(u))
320+
end
301321

302322
## EXTENSION OF CLASSES TO ARRAYS OF UNIVARIATE FINITE
303323

src/methods.jl

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ end
120120
# TODO: It would be useful to define == as well.
121121

122122
"""
123-
Dist.pdf(d::UnivariateFinite, x)
123+
Distributions.pdf(d::UnivariateFinite, x)
124124
125125
Probability of `d` at `x`.
126126
@@ -178,15 +178,31 @@ function Dist.mode(d::UnivariateFinite)
178178
return d.decoder(m)
179179
end
180180

181+
function Dist.modes(d::UnivariateFinite{S,V,R,P}) where {S,V,R,P}
182+
dic = d.prob_given_ref
183+
p = values(dic)
184+
max_prob = maximum(p)
185+
M = R[] # modes
186+
187+
# see comment in `mode` above
188+
throw_nan_error_if_needed(max_prob)
189+
for (x, prob) in dic
190+
if prob == max_prob
191+
push!(M, x)
192+
end
193+
end
194+
return d.decoder(M)
195+
end
196+
197+
const ERR_NAN_FOUND = DomainError(
198+
NaN,
199+
"`mode(s)` is invalid for a `UnivariateFinite` distribution "*
200+
"with `pdf` containing `NaN`s"
201+
)
202+
181203
function throw_nan_error_if_needed(x)
182204
if isnan(x)
183-
throw(
184-
DomainError(
185-
NaN,
186-
"`mode` is invalid for `UnivariateFininite` distribution "*
187-
"with `pdf` containing `NaN`s"
188-
)
189-
)
205+
throw(ERR_NAN_FOUND)
190206
end
191207
end
192208

test/arrays.jl

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import Random
1010
using Missings
1111
using ScientificTypes
1212

13-
import CategoricalDistributions: classes
13+
import CategoricalDistributions: classes, ERR_NAN_FOUND
1414
import CategoricalArrays.unwrap
1515

1616
rng = StableRNG(111)
@@ -198,10 +198,10 @@ end
198198
probs = rand(rng, n)
199199
u = UnivariateFinite(probs, augment = true, pool=missing)
200200
supp = Distributions.support(u)
201-
modes = mode.(u)
202-
@test modes isa CategoricalArray
201+
_modes = mode.(u)
202+
@test _modes isa CategoricalArray
203203
expected = [ifelse(p > 0.5, supp[2], supp[1]) for p in probs]
204-
@test all(modes .== expected)
204+
@test all(_modes .== expected)
205205

206206
# multiclass
207207
rng = StableRNG(554)
@@ -220,7 +220,48 @@ end
220220
],
221221
pool=missing
222222
)
223-
@test_throws DomainError mode.(unf_arr)
223+
@test_throws ERR_NAN_FOUND mode.(unf_arr)
224+
end
225+
226+
@testset "broadcasting modes" begin
227+
# binary
228+
rng = StableRNG(668)
229+
probs = rand(rng, n)
230+
u = UnivariateFinite(probs, augment = true, pool=missing)
231+
supp = Distributions.support(u)
232+
_modes = modes.(u)
233+
@test _modes isa Vector{<:CategoricalArray}
234+
expected = [ifelse(p > 0.5, [supp[2]], [supp[1]]) for p in probs]
235+
@test all(_modes .== expected)
236+
237+
# multiclass, bimodal
238+
rng = StableRNG(554)
239+
P = rand(rng, n, c)
240+
M, M_idx = findmax(P, dims=2)
241+
M_idx = getindex.(M_idx, 2)
242+
for i in axes(P,1)
243+
m = M[i]
244+
j = M_idx[i]
245+
while j == M_idx[i]
246+
j = rand(axes(P,2))
247+
end
248+
P[i,j] = m
249+
end
250+
P ./= sum(P, dims=2)
251+
u = UnivariateFinite(P, pool=missing)
252+
expected = modes.([u...])
253+
@test all(modes.(u) .== expected)
254+
255+
# `mode` broadcasting of `Univariate` objects containing `NaN` in probs.
256+
unf_arr = UnivariateFinite(
257+
[
258+
0.1 0.2 NaN 0.1 NaN;
259+
0.2 0.1 0.1 0.4 0.2;
260+
0.3 NaN 0.2 NaN 0.3
261+
],
262+
pool=missing
263+
)
264+
@test_throws ERR_NAN_FOUND modes.(unf_arr)
224265
end
225266

226267
@testset "cat for UnivariateFiniteArray" begin

test/methods.jl

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import Random
99
rng = StableRNG(123)
1010
using ScientificTypes
1111

12-
import CategoricalDistributions: classes
12+
import CategoricalDistributions: classes, ERR_NAN_FOUND
1313

1414
v = categorical(collect("asqfasqffqsaaaa"), ordered=true)
1515
V = categorical(collect("asqfasqffqsaaaa"))
@@ -19,8 +19,10 @@ A, S, Q, F = V[1], V[2], V[3], V[4]
1919
@testset "set 1" begin
2020

2121
# ordered (OrderedFactor)
22-
dict = Dict(s=>0.1, q=> 0.2, f=> 0.7)
22+
dict = Dict(s=>0.1, q=>0.2, f=>0.7)
2323
d = UnivariateFinite(dict)
24+
dict_bimodal = Dict(a=>0.1, s=>0.1, q=>0.4, f=>0.4)
25+
d_bimodal = UnivariateFinite(dict_bimodal)
2426
@test classes(d) == [a, f, q, s]
2527
@test classes(d) == classes(s)
2628
@test levels(d) == levels(s)
@@ -45,6 +47,7 @@ A, S, Q, F = V[1], V[2], V[3], V[4]
4547
@test logpdf(d, 'f') log(0.7)
4648
@test isinf(logpdf(d, a))
4749
@test mode(d) == f
50+
@test modes(d_bimodal) == [f, q]
4851

4952
@test UnivariateFinite(support(d), [0.7, 0.2, 0.1]) d
5053

@@ -72,7 +75,7 @@ A, S, Q, F = V[1], V[2], V[3], V[4]
7275
@test isapprox(freq[q]/N, ffreq[q]/N)
7376

7477
# unordered (Multiclass):
75-
dict = Dict(S=>0.1, Q=> 0.2, F=> 0.7)
78+
dict = Dict(S=>0.1, Q=>0.2, F=>0.7)
7679
d = UnivariateFinite(dict)
7780
@test classes(d) == [a, f, q, s]
7881
@test classes(d) == classes(s)
@@ -178,7 +181,21 @@ end
178181

179182
# `mode` of `Univariate` objects containing `NaN` in probs.
180183
unf = UnivariateFinite([0.1, 0.2, NaN, 0.1, NaN], pool=missing)
181-
@test_throws DomainError mode(unf)
184+
@test_throws ERR_NAN_FOUND mode(unf)
185+
end
186+
187+
@testset "Univariate modes, bimodal" begin
188+
v = categorical(1:101)
189+
p = rand(rng,101)
190+
p[24] = 2*maximum(p)
191+
p[42] = p[24]
192+
p = p/sum(p)
193+
d = UnivariateFinite(v, p)
194+
@test modes(d) == [24, 42]
195+
196+
# `mode` of `Univariate` objects containing `NaN` in probs.
197+
unf = UnivariateFinite([0.1, 0.2, NaN, 0.1, NaN], pool=missing)
198+
@test_throws ERR_NAN_FOUND modes(unf)
182199
end
183200

184201
@testset "UnivariateFinite methods" begin

0 commit comments

Comments
 (0)