Skip to content

Commit 308fd08

Browse files
committed
Implement modes
* Overloads Distributions.modes for UnivariateFinite * Defines efficient broadcasting of modes over UnivariateFiniteArray * Tests * README
1 parent 4be9db4 commit 308fd08

File tree

6 files changed

+105
-13
lines changed

6 files changed

+105
-13
lines changed

README.md

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ Arrays of `UnivariateFinite` distributions are defined using the same
8282
constructor. Broadcasting methods, such as `pdf`, are optimized for
8383
such arrays:
8484

85-
```
85+
```julia
8686
julia> v = UnivariateFinite(["no", "yes"], [0.1, 0.2, 0.3, 0.4], augment=true, pool=data)
8787
4-element UnivariateFiniteArray{Multiclass{3}, String, UInt32, Float64, 1}:
8888
UnivariateFinite{Multiclass{3}}(no=>0.9, yes=>0.1)
@@ -119,7 +119,6 @@ julia> pdf(v, L)
119119
0.0 0.6 0.4
120120
```
121121

122-
123122
## Measures over finite labeled sets
124123

125124
There is, in fact, no enforcement that probabilities in a
@@ -128,7 +127,6 @@ to a type `T` for which `zero(T)` is defined. In particular
128127
`UnivariateFinite` objects implement arbitrary non-negative, signed,
129128
or complex measures over a finite labeled set.
130129

131-
132130
## What does this package provide?
133131

134132
- A new type `UnivariateFinite{S}` for representing probability
@@ -144,7 +142,7 @@ or complex measures over a finite labeled set.
144142
- Implementations of `rand` for generating random samples of a
145143
`UnivariateFinite` distribution.
146144

147-
- Implementations of the `pdf`, `logpdf` and `mode` methods of
145+
- Implementations of the `pdf`, `logpdf`, `mode` and `modes` methods of
148146
Distributions.jl, with efficient broadcasting over the new array
149147
type.
150148

src/CategoricalDistributions.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ using Random
1616

1717
const Dist = Distributions
1818

19-
import Distributions: pdf, logpdf, support, mode
19+
import Distributions: pdf, logpdf, support, mode, modes
2020

2121
include("utilities.jl")
2222
include("types.jl")
@@ -28,7 +28,7 @@ include("arithmetic.jl")
2828
export UnivariateFinite, UnivariateFiniteArray, UnivariateFiniteVector
2929

3030
# re-eport from Distributions:
31-
export pdf, logpdf, support, mode
31+
export pdf, logpdf, support, mode, modes
3232

3333
# re-export from ScientificTypesBase:
3434
export Multiclass, OrderedFactor

src/arrays.jl

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ Base.Broadcast.broadcasted(
271271
c::Missing) where {S,V,R,P,N} = Missings.missings(P, length(u))
272272

273273

274-
## PERFORMANT BROADCASTING OF mode:
274+
## PERFORMANT BROADCASTING OF mode(s):
275275

276276
function Base.Broadcast.broadcasted(::typeof(mode),
277277
u::UniFinArr{S,V,R,P,N}) where {S,V,R,P,N}
@@ -298,6 +298,26 @@ function Base.Broadcast.broadcasted(::typeof(mode),
298298
return reshape(mode_flat, size(u))
299299
end
300300

301+
function Base.Broadcast.broadcasted(::typeof(modes),
302+
u::UniFinArr{S,V,R,P,N}) where {S,V,R,P,N}
303+
dic = u.prob_given_ref
304+
305+
# using linear indexing:
306+
mode_flat = map(1:length(u)) do i
307+
max_prob = maximum(dic[ref][i] for ref in keys(dic))
308+
M = R[]
309+
310+
# see comment for in broadcasted(::mode) above
311+
throw_nan_error_if_needed(max_prob)
312+
for ref in keys(dic)
313+
if dic[ref][i] == max_prob
314+
push!(M, ref)
315+
end
316+
end
317+
return u.decoder(M)
318+
end
319+
return reshape(mode_flat, size(u))
320+
end
301321

302322
## EXTENSION OF CLASSES TO ARRAYS OF UNIVARIATE FINITE
303323

src/methods.jl

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,12 +178,28 @@ function Dist.mode(d::UnivariateFinite)
178178
return d.decoder(m)
179179
end
180180

181+
function Dist.modes(d::UnivariateFinite{S,V,R,P}) where {S,V,R,P}
182+
dic = d.prob_given_ref
183+
p = values(dic)
184+
max_prob = maximum(p)
185+
M = R[] # modes
186+
187+
# see comment in `mode` above
188+
throw_nan_error_if_needed(max_prob)
189+
for (x, prob) in dic
190+
if prob == max_prob
191+
push!(M, x)
192+
end
193+
end
194+
return d.decoder(M)
195+
end
196+
181197
function throw_nan_error_if_needed(x)
182198
if isnan(x)
183199
throw(
184200
DomainError(
185201
NaN,
186-
"`mode` is invalid for `UnivariateFininite` distribution "*
202+
"`mode(s)` is invalid for a `UnivariateFinite` distribution "*
187203
"with `pdf` containing `NaN`s"
188204
)
189205
)

test/arrays.jl

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -198,10 +198,10 @@ end
198198
probs = rand(rng, n)
199199
u = UnivariateFinite(probs, augment = true, pool=missing)
200200
supp = Distributions.support(u)
201-
modes = mode.(u)
202-
@test modes isa CategoricalArray
201+
_modes = mode.(u)
202+
@test _modes isa CategoricalArray
203203
expected = [ifelse(p > 0.5, supp[2], supp[1]) for p in probs]
204-
@test all(modes .== expected)
204+
@test all(_modes .== expected)
205205

206206
# multiclass
207207
rng = StableRNG(554)
@@ -223,6 +223,47 @@ end
223223
@test_throws DomainError mode.(unf_arr)
224224
end
225225

226+
@testset "broadcasting modes" begin
227+
# binary
228+
rng = StableRNG(668)
229+
probs = rand(rng, n)
230+
u = UnivariateFinite(probs, augment = true, pool=missing)
231+
supp = Distributions.support(u)
232+
_modes = modes.(u)
233+
@test _modes isa Vector{<:CategoricalArray}
234+
expected = [ifelse(p > 0.5, [supp[2]], [supp[1]]) for p in probs]
235+
@test all(_modes .== expected)
236+
237+
# multiclass, bimodal
238+
rng = StableRNG(554)
239+
P = rand(rng, n, c)
240+
M, M_idx = findmax(P, dims=2)
241+
M_idx = getindex.(M_idx, 2)
242+
for i in axes(P,1)
243+
m = M[i]
244+
j = M_idx[i]
245+
while j == M_idx[i]
246+
j = rand(axes(P,2))
247+
end
248+
P[i,j] = m
249+
end
250+
P ./= sum(P, dims=2)
251+
u = UnivariateFinite(P, pool=missing)
252+
expected = modes.([u...])
253+
@test all(modes.(u) .== expected)
254+
255+
# `mode` broadcasting of `Univariate` objects containing `NaN` in probs.
256+
unf_arr = UnivariateFinite(
257+
[
258+
0.1 0.2 NaN 0.1 NaN;
259+
0.2 0.1 0.1 0.4 0.2;
260+
0.3 NaN 0.2 NaN 0.3
261+
],
262+
pool=missing
263+
)
264+
@test_throws DomainError modes.(unf_arr)
265+
end
266+
226267
@testset "cat for UnivariateFiniteArray" begin
227268

228269
# ordered:

test/methods.jl

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,10 @@ A, S, Q, F = V[1], V[2], V[3], V[4]
1919
@testset "set 1" begin
2020

2121
# ordered (OrderedFactor)
22-
dict = Dict(s=>0.1, q=> 0.2, f=> 0.7)
22+
dict = Dict(s=>0.1, q=>0.2, f=>0.7)
2323
d = UnivariateFinite(dict)
24+
dict_bimodal = Dict(a=>0.1, s=>0.1, q=>0.4, f=>0.4)
25+
d_bimodal = UnivariateFinite(dict_bimodal)
2426
@test classes(d) == [a, f, q, s]
2527
@test classes(d) == classes(s)
2628
@test levels(d) == levels(s)
@@ -45,6 +47,7 @@ A, S, Q, F = V[1], V[2], V[3], V[4]
4547
@test logpdf(d, 'f') log(0.7)
4648
@test isinf(logpdf(d, a))
4749
@test mode(d) == f
50+
@test modes(d_bimodal) == [f, q]
4851

4952
@test UnivariateFinite(support(d), [0.7, 0.2, 0.1]) d
5053

@@ -72,7 +75,7 @@ A, S, Q, F = V[1], V[2], V[3], V[4]
7275
@test isapprox(freq[q]/N, ffreq[q]/N)
7376

7477
# unordered (Multiclass):
75-
dict = Dict(S=>0.1, Q=> 0.2, F=> 0.7)
78+
dict = Dict(S=>0.1, Q=>0.2, F=>0.7)
7679
d = UnivariateFinite(dict)
7780
@test classes(d) == [a, f, q, s]
7881
@test classes(d) == classes(s)
@@ -181,6 +184,20 @@ end
181184
@test_throws DomainError mode(unf)
182185
end
183186

187+
@testset "Univariate modes, bimodal" begin
188+
v = categorical(1:101)
189+
p = rand(rng,101)
190+
p[24] = 2*maximum(p)
191+
p[42] = p[24]
192+
p = p/sum(p)
193+
d = UnivariateFinite(v, p)
194+
@test modes(d) == [24, 42]
195+
196+
# `mode` of `Univariate` objects containing `NaN` in probs.
197+
unf = UnivariateFinite([0.1, 0.2, NaN, 0.1, NaN], pool=missing)
198+
@test_throws DomainError modes(unf)
199+
end
200+
184201
@testset "UnivariateFinite methods" begin
185202
y = categorical(["yes", "no", "yes", "yes", "maybe"])
186203
yes = y[1]

0 commit comments

Comments
 (0)