Skip to content

Commit f7093e6

Browse files
authored
Merge pull request #12 from JuliaAI/arithmetic
Add arithmetic for `UnivariateFinite` objects.
2 parents 9747a66 + a73e225 commit f7093e6

File tree

10 files changed

+235
-115
lines changed

10 files changed

+235
-115
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "CategoricalDistributions"
22
uuid = "af321ab8-2d2e-40a6-b165-3d674595d28e"
33
authors = ["Anthony D. Blaom <[email protected]>"]
4-
version = "0.1.1"
4+
version = "0.1.2"
55

66
[deps]
77
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"

src/CategoricalDistributions.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,15 @@ using Random
99
using UnicodePlots
1010

1111
const Dist = Distributions
12+
const MAX_NUM_LEVELS_TO_SHOW_BARS = 12
1213

1314
import Distributions: pdf, logpdf, support, mode
1415

1516
include("utilities.jl")
1617
include("types.jl")
1718
include("methods.jl")
1819
include("arrays.jl")
20+
include("arithmetic.jl")
1921

2022
export UnivariateFinite, UnivariateFiniteArray
2123

src/arithmetic.jl

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# ## ARITHMETIC
2+
3+
const ERR_DIFFERENT_SAMPLE_SPACES = ArgumentError(
4+
"Adding two `UnivariateFinite` objects whose "*
5+
"sample spaces have different labellings is not allowed. ")
6+
7+
import Base: +, *, /, -
8+
9+
pdf_matrix(d::UnivariateFinite, L) = pdf.(d, L)
10+
pdf_matrix(d::AbstractArray{<:UnivariateFinite}, L) = pdf(d, L)
11+
12+
function +(d1::U, d2::U) where U <: SingletonOrArray
13+
L = classes(d1)
14+
L == classes(d2) || throw(ERR_DIFFERENT_SAMPLE_SPACES)
15+
return UnivariateFinite(L, pdf_matrix(d1, L) + pdf_matrix(d2, L))
16+
end
17+
18+
function _minus(d, T)
19+
S = d.scitype
20+
decoder = d.decoder
21+
prob_given_ref = copy(d.prob_given_ref)
22+
for ref in keys(prob_given_ref)
23+
prob_given_ref[ref] = -prob_given_ref[ref]
24+
end
25+
return T(S, decoder, prob_given_ref)
26+
end
27+
-(d::UnivariateFinite) = _minus(d, UnivariateFinite)
28+
-(d::UnivariateFiniteArray) = _minus(d, UnivariateFiniteArray)
29+
30+
function -(d1::U, d2::U) where U <: SingletonOrArray
31+
L = classes(d1)
32+
L == classes(d2) || throw(ERR_DIFFERENT_SAMPLE_SPACES)
33+
return UnivariateFinite(L, pdf_matrix(d1, L) - pdf_matrix(d2, L))
34+
end
35+
36+
# It seems that the restriction `x::Number` below (applying only to the
37+
# array case) is unavoidable because of a method ambiguity with
38+
# `Base.*(::AbstractArray, ::Number)`.
39+
40+
function _times(d, x, T)
41+
S = d.scitype
42+
decoder = d.decoder
43+
prob_given_ref = copy(d.prob_given_ref)
44+
for ref in keys(prob_given_ref)
45+
prob_given_ref[ref] *= x
46+
end
47+
return T(d.scitype, decoder, prob_given_ref)
48+
end
49+
*(d::UnivariateFinite, x) = _times(d, x, UnivariateFinite)
50+
*(d::UnivariateFiniteArray, x::Number) = _times(d, x, UnivariateFiniteArray)
51+
52+
*(x, d::UnivariateFinite) = d*x
53+
*(x::Number, d::UnivariateFiniteArray) = d*x
54+
/(d::UnivariateFinite, x) = d*inv(x)
55+
/(d::UnivariateFiniteArray, x::Number) = d*inv(x)

src/arrays.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,7 @@ function Base.Broadcast.broadcasted(::typeof(mode),
254254
return reshape(mode_flat, size(u))
255255
end
256256

257+
257258
## EXTENSION OF CLASSES TO ARRAYS OF UNIVARIATE FINITE
258259

259260
# We already have `classes(::UnivariateFininiteArray)
@@ -266,3 +267,5 @@ function classes(yhat::AbstractArray{<:Union{Missing,UnivariateFinite}})
266267
i === nothing && throw(ERR_EMPTY_UNIVARIATE_FINITE)
267268
return classes(yhat[i])
268269
end
270+
271+

src/methods.jl

Lines changed: 13 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -79,11 +79,22 @@ function Base.show(stream::IO, d::UnivariateFinite)
7979
print(stream, "UnivariateFinite{$(d.scitype)}($arg_str)")
8080
end
8181

82+
Base.show(io::IO, mime::MIME"text/plain",
83+
d::UnivariateFinite) = show(io, d)
84+
85+
# in common case of `Real` probabilities we can do a pretty bar plot:
8286
function Base.show(io::IO, mime::MIME"text/plain",
83-
d::UnivariateFinite{S}) where S
87+
d::UnivariateFinite{<:Finite{K},V,R,P}) where {K,V,R,P<:Real}
88+
show_bars = false
89+
if K <= MAX_NUM_LEVELS_TO_SHOW_BARS &&
90+
all(>=(0), values(d.prob_given_ref))
91+
show_bars = true
92+
end
93+
show_bars || return show(io, d)
8494
s = support(d)
8595
x = string.(CategoricalArrays.DataAPI.unwrap.(s))
8696
y = pdf.(d, s)
97+
S = d.scitype
8798
plt = barplot(x, y, title="UnivariateFinite{$S}")
8899
show(io, mime, plt)
89100
end
@@ -125,58 +136,6 @@ end
125136

126137
# TODO: It would be useful to define == as well.
127138

128-
# TODO: Now that `UnivariateFinite` is any finite measure, we can
129-
# replace the following nonsense with an overloading of `+`. I think
130-
# it is only used in MLJEnsembles.jl - but need to check. I believe
131-
# this is a private method we can easily remove
132-
133-
function average(dvec::AbstractVector{UnivariateFinite{S,V,R,P}};
134-
weights=nothing) where {S,V,R,P}
135-
136-
n = length(dvec)
137-
138-
Dist.@check_args(UnivariateFinite, weights == nothing || n==length(weights))
139-
140-
# check all distributions have consistent pool:
141-
first_index = first(dvec).decoder.classes
142-
for d in dvec
143-
d.decoder.classes == first_index ||
144-
error("Averaging UnivariateFinite distributions with incompatible"*
145-
" pools. ")
146-
end
147-
148-
# get all refs:
149-
refs = reduce(union, [keys(d.prob_given_ref) for d in dvec]) |> collect
150-
151-
# initialize the prob dictionary for the distribution sum:
152-
prob_given_ref = LittleDict{R,P}([refs...], zeros(P, length(refs)))
153-
154-
# make vector of all the distributions dicts padded to have same common keys:
155-
prob_given_ref_vec = map(dvec) do d
156-
merge(prob_given_ref, d.prob_given_ref)
157-
end
158-
159-
# sum up:
160-
if weights == nothing
161-
scale = 1/n
162-
for x in refs
163-
for k in 1:n
164-
prob_given_ref[x] += scale*prob_given_ref_vec[k][x]
165-
end
166-
end
167-
else
168-
scale = 1/sum(weights)
169-
for x in refs
170-
for k in 1:n
171-
prob_given_ref[x] +=
172-
weights[k]*prob_given_ref_vec[k][x]*scale
173-
end
174-
end
175-
end
176-
d1 = first(dvec)
177-
return UnivariateFinite(sample_scitype(d1), d1.decoder, prob_given_ref)
178-
end
179-
180139
"""
181140
Dist.pdf(d::UnivariateFinite, x)
182141
@@ -374,6 +333,6 @@ end
374333
# # BROADCASTING OVER SINGLE UNIVARIATE FINITE
375334

376335
# This mirrors behaviour assigned Distributions.Distribution objects,
377-
# which allows `pdf.(d::UnivariateFinite, support(d))` to work.
336+
# which allows `pdf.(d::UnivariateFinite, support(d))` to work.
378337

379338
Broadcast.broadcastable(d::UnivariateFinite) = Ref(d)

src/types.jl

Lines changed: 68 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,15 @@ choosing `probs` to be an array of one higher dimension than the array
1414
generated.
1515
1616
Here the word "probabilities" is an abuse of terminology as there is
17-
no requirement that probabilities actually sum to one, only that they
18-
be non-negative. So `UnivariateFinite` objects actually implement
19-
arbitrary non-negative measures over finite sets of labelled points. A
20-
`UnivariateDistribution` will be a bona fide probability measure when
21-
constructed using the `augment=true` option (see below) or when
22-
`fit` to data.
17+
no requirement that the that probabilities actually sum to one. The
18+
only requirement is that the probabilities have a common type `T` for
19+
which `zero(T)` is defined. In particular, `UnivariateFinite` objects
20+
implement arbitrary non-negative, signed, or complex measures over
21+
finite sets of labelled points. A `UnivariateDistribution` will be a
22+
bona fide probability measure when constructed using the
23+
`augment=true` option (see below) or when `fit` to data. And the
24+
probabilities of a `UnivariateFinite` object `d` must be non-negative,
25+
with a non-zero sum, for `rand(d)` to be defined and interpretable.
2326
2427
Unless `pool` is specified, `support` should have type
2528
`AbstractVector{<:CategoricalValue}` and all elements are assumed to
@@ -37,28 +40,37 @@ constructor then returns an array of `UnivariateFinite` distributions
3740
of size `(n1, n2, ..., nk)`.
3841
3942
```
40-
using CategoricalArrays
41-
v = categorical([:x, :x, :y, :x, :z])
42-
43-
julia> UnivariateFinite(classes(v), [0.2, 0.3, 0.5])
44-
UnivariateFinite{Multiclass{3}}(x=>0.2, y=>0.3, z=>0.5)
45-
46-
julia> d = UnivariateFinite([v[1], v[end]], [0.1, 0.9])
43+
using CategoricalDistributions, CategoricalArrays, Distributions
44+
samples = categorical(['x', 'x', 'y', 'x', 'z'])
45+
julia> Distributions.fit(UnivariateFinite, samples)
46+
UnivariateFinite{Multiclass{3}}
47+
┌ ┐
48+
x ┤■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■ 0.6
49+
y ┤■■■■■■■■■■■■ 0.2
50+
z ┤■■■■■■■■■■■■ 0.2
51+
└ ┘
52+
53+
julia> d = UnivariateFinite([samples[1], samples[end]], [0.1, 0.9])
4754
UnivariateFinite{Multiclass{3}(x=>0.1, z=>0.9)
55+
UnivariateFinite{Multiclass{3}}
56+
┌ ┐
57+
x ┤■■■■ 0.1
58+
z ┤■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■ 0.9
59+
└ ┘
4860
4961
julia> rand(d, 3)
5062
3-element Array{Any,1}:
51-
CategoricalArrays.CategoricalValue{Symbol,UInt32} :z
52-
CategoricalArrays.CategoricalValue{Symbol,UInt32} :z
53-
CategoricalArrays.CategoricalValue{Symbol,UInt32} :z
63+
CategoricalValue{Symbol,UInt32} 'z'
64+
CategoricalValue{Symbol,UInt32} 'z'
65+
CategoricalValue{Symbol,UInt32} 'z'
5466
55-
julia> levels(d)
67+
julia> levels(samples)
5668
3-element Array{Symbol,1}:
57-
:x
58-
:y
59-
:z
69+
'x'
70+
'y'
71+
'z'
6072
61-
julia> pdf(d, :y)
73+
julia> pdf(d, 'y')
6274
0.0
6375
```
6476
@@ -77,19 +89,27 @@ In the last case, specify `ordered=true` if the pool is to be
7789
considered ordered.
7890
7991
```
80-
julia> UnivariateFinite([:x, :z], [0.1, 0.9], pool=missing, ordered=true)
81-
UnivariateFinite{OrderedFactor{2}}(x=>0.1, z=>0.9)
82-
83-
julia> d = UnivariateFinite([:x, :z], [0.1, 0.9], pool=v) # v defined above
84-
UnivariateFinite(x=>0.1, z=>0.9) (Multiclass{3} samples)
85-
86-
julia> pdf(d, :y) # allowed as `:y in levels(v)`
92+
julia> UnivariateFinite(['x', 'z'], [0.1, 0.9], pool=missing, ordered=true)
93+
UnivariateFinite{OrderedFactor{2}}
94+
┌ ┐
95+
x ┤■■■■ 0.1
96+
z ┤■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■ 0.9
97+
└ ┘
98+
99+
samples = categorical(['x', 'x', 'y', 'x', 'z'])
100+
julia> d = UnivariateFinite(['x', 'z'], [0.1, 0.9], pool=samples)
101+
┌ ┐
102+
x ┤■■■■ 0.1
103+
z ┤■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■ 0.9
104+
└ ┘
105+
106+
julia> pdf(d, 'y') # allowed as `'y' in levels(samples)`
87107
0.0
88108
89-
v = categorical([:x, :x, :y, :x, :z, :w])
109+
v = categorical(['x', 'x', 'y', 'x', 'z', 'w'])
90110
probs = rand(100, 3)
91111
probs = probs ./ sum(probs, dims=2)
92-
julia> UnivariateFinite([:x, :y, :z], probs, pool=v)
112+
julia> d1 = UnivariateFinite(['x', 'y', 'z'], probs, pool=v)
93113
100-element UnivariateFiniteVector{Multiclass{4},Symbol,UInt32,Float64}:
94114
UnivariateFinite{Multiclass{4}}(x=>0.194, y=>0.3, z=>0.505)
95115
UnivariateFinite{Multiclass{4}}(x=>0.727, y=>0.234, z=>0.0391)
@@ -107,6 +127,18 @@ for the classes `c2, c3, ..., cn`. The class `c1` probabilities are
107127
chosen so that each `UnivariateFinite` distribution in the returned
108128
array is a bona fide probability distribution.
109129
130+
```julia
131+
julia> UnivariateFinite([0.1, 0.2, 0.3], augment=true, pool=missing)
132+
3-element UnivariateFiniteArray{Multiclass{2}, String, UInt8, Float64, 1}:
133+
UnivariateFinite{Multiclass{2}}(class_1=>0.9, class_2=>0.1)
134+
UnivariateFinite{Multiclass{2}}(class_1=>0.8, class_2=>0.2)
135+
UnivariateFinite{Multiclass{2}}(class_1=>0.7, class_2=>0.3)
136+
137+
d2 = UnivariateFinite(['x', 'y', 'z'], probs[:, 2:end], augment=true, pool=v)
138+
julia> pdf(d1, levels(v)) ≈ pdf(d2, levels(v))
139+
true
140+
```
141+
110142
---
111143
112144
UnivariateFinite(prob_given_class; pool=nothing, ordered=false)
@@ -142,6 +174,8 @@ struct UnivariateFinite{S,V,R,P}
142174
prob_given_ref::LittleDict{R,P,Vector{R}, Vector{P}}
143175
end
144176

177+
@doc DOC_CONSTRUCTOR UnivariateFinite
178+
145179
"""
146180
UnivariateFiniteArray
147181
@@ -160,6 +194,10 @@ end
160194

161195
const UnivariateFiniteVector{S,V,R,P} = UnivariateFiniteArray{S,V,R,P,1}
162196

197+
# private:
198+
const SingletonOrArray{S,V,R,P} = Union{UnivariateFinite{S,V,R,P},
199+
UnivariateFiniteArray{S,V,R,P}}
200+
163201

164202
# # CHECKS AND ERROR MESSAGES
165203

0 commit comments

Comments
 (0)