Skip to content

Commit be4bb23

Browse files
authored
Merge pull request #14 from JuliaAI/dev
For a 0.1.2 release
2 parents 74f7466 + 149afd3 commit be4bb23

File tree

12 files changed

+253
-131
lines changed

12 files changed

+253
-131
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ jobs:
1717
fail-fast: false
1818
matrix:
1919
version:
20-
- '1.0'
20+
- '1.3'
2121
- '1' # automatically expands to the latest stable 1.x release of Julia.
2222
os:
2323
- ubuntu-latest

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "CategoricalDistributions"
22
uuid = "af321ab8-2d2e-40a6-b165-3d674595d28e"
33
authors = ["Anthony D. Blaom <[email protected]>"]
4-
version = "0.1.1"
4+
version = "0.1.2"
55

66
[deps]
77
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
@@ -19,7 +19,7 @@ Missings = "0.4, 1"
1919
OrderedCollections = "1.1"
2020
ScientificTypesBase = "2"
2121
UnicodePlots = "2"
22-
julia = "1.0"
22+
julia = "1.3"
2323

2424
[extras]
2525
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

README.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ levels(d2)
7373
julia> pdf(d2, "maybe")
7474
0.0
7575

76-
julia> pdf(d2, "okay")https://github.com/JuliaAI/CategoricalDistributions.jl#measures-over-finite-labeled-sets
76+
julia> pdf(d2, "okay")
7777
ERROR: DomainError with Value okay not in pool. :
7878
```
7979

@@ -122,10 +122,10 @@ julia> pdf(v, L)
122122
## Measures over finite labeled sets
123123

124124
There is, in fact, no enforcement that probabilities in a
125-
`UnivariateFinite` distribution sum to one, only that they be
126-
non-negative. Thus `UnivariateFinite` objects can be more properly
127-
understood as an implementation of arbitrary non-negative measures
128-
over finite labeled sets.
125+
`UnivariateFinite` distribution sum to one, only that they be belong
126+
to a type `T` for which `zero(T)` is defined. In particular
127+
`UnivariateFinite` objects implement arbitrary non-negative, signed,
128+
or complex measures over a finite labeled set.
129129

130130

131131
## What does this package provide?

src/CategoricalDistributions.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,15 @@ using Random
99
using UnicodePlots
1010

1111
const Dist = Distributions
12+
const MAX_NUM_LEVELS_TO_SHOW_BARS = 12
1213

1314
import Distributions: pdf, logpdf, support, mode
1415

1516
include("utilities.jl")
1617
include("types.jl")
1718
include("methods.jl")
1819
include("arrays.jl")
20+
include("arithmetic.jl")
1921

2022
export UnivariateFinite, UnivariateFiniteArray
2123

src/arithmetic.jl

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# ## ARITHMETIC
2+
3+
const ERR_DIFFERENT_SAMPLE_SPACES = ArgumentError(
4+
"Adding two `UnivariateFinite` objects whose "*
5+
"sample spaces have different labellings is not allowed. ")
6+
7+
import Base: +, *, /, -
8+
9+
pdf_matrix(d::UnivariateFinite, L) = pdf.(d, L)
10+
pdf_matrix(d::AbstractArray{<:UnivariateFinite}, L) = pdf(d, L)
11+
12+
function +(d1::U, d2::U) where U <: SingletonOrArray
13+
L = classes(d1)
14+
L == classes(d2) || throw(ERR_DIFFERENT_SAMPLE_SPACES)
15+
return UnivariateFinite(L, pdf_matrix(d1, L) + pdf_matrix(d2, L))
16+
end
17+
18+
function _minus(d, T)
19+
S = d.scitype
20+
decoder = d.decoder
21+
prob_given_ref = copy(d.prob_given_ref)
22+
for ref in keys(prob_given_ref)
23+
prob_given_ref[ref] = -prob_given_ref[ref]
24+
end
25+
return T(S, decoder, prob_given_ref)
26+
end
27+
-(d::UnivariateFinite) = _minus(d, UnivariateFinite)
28+
-(d::UnivariateFiniteArray) = _minus(d, UnivariateFiniteArray)
29+
30+
function -(d1::U, d2::U) where U <: SingletonOrArray
31+
L = classes(d1)
32+
L == classes(d2) || throw(ERR_DIFFERENT_SAMPLE_SPACES)
33+
return UnivariateFinite(L, pdf_matrix(d1, L) - pdf_matrix(d2, L))
34+
end
35+
36+
# It seems that the restriction `x::Number` below (applying only to the
37+
# array case) is unavoidable because of a method ambiguity with
38+
# `Base.*(::AbstractArray, ::Number)`.
39+
40+
function _times(d, x, T)
41+
S = d.scitype
42+
decoder = d.decoder
43+
prob_given_ref = copy(d.prob_given_ref)
44+
for ref in keys(prob_given_ref)
45+
prob_given_ref[ref] *= x
46+
end
47+
return T(d.scitype, decoder, prob_given_ref)
48+
end
49+
*(d::UnivariateFinite, x) = _times(d, x, UnivariateFinite)
50+
*(d::UnivariateFiniteArray, x::Number) = _times(d, x, UnivariateFiniteArray)
51+
52+
*(x, d::UnivariateFinite) = d*x
53+
*(x::Number, d::UnivariateFiniteArray) = d*x
54+
/(d::UnivariateFinite, x) = d*inv(x)
55+
/(d::UnivariateFiniteArray, x::Number) = d*inv(x)

src/arrays.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,7 @@ function Base.Broadcast.broadcasted(::typeof(mode),
254254
return reshape(mode_flat, size(u))
255255
end
256256

257+
257258
## EXTENSION OF CLASSES TO ARRAYS OF UNIVARIATE FINITE
258259

259260
# We already have `classes(::UnivariateFininiteArray)
@@ -266,3 +267,5 @@ function classes(yhat::AbstractArray{<:Union{Missing,UnivariateFinite}})
266267
i === nothing && throw(ERR_EMPTY_UNIVARIATE_FINITE)
267268
return classes(yhat[i])
268269
end
270+
271+

src/methods.jl

Lines changed: 18 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -79,11 +79,22 @@ function Base.show(stream::IO, d::UnivariateFinite)
7979
print(stream, "UnivariateFinite{$(d.scitype)}($arg_str)")
8080
end
8181

82+
Base.show(io::IO, mime::MIME"text/plain",
83+
d::UnivariateFinite) = show(io, d)
84+
85+
# in common case of `Real` probabilities we can do a pretty bar plot:
8286
function Base.show(io::IO, mime::MIME"text/plain",
83-
d::UnivariateFinite{S}) where S
87+
d::UnivariateFinite{<:Finite{K},V,R,P}) where {K,V,R,P<:Real}
88+
show_bars = false
89+
if K <= MAX_NUM_LEVELS_TO_SHOW_BARS &&
90+
all(>=(0), values(d.prob_given_ref))
91+
show_bars = true
92+
end
93+
show_bars || return show(io, d)
8494
s = support(d)
8595
x = string.(CategoricalArrays.DataAPI.unwrap.(s))
8696
y = pdf.(d, s)
97+
S = d.scitype
8798
plt = barplot(x, y, title="UnivariateFinite{$S}")
8899
show(io, mime, plt)
89100
end
@@ -125,58 +136,6 @@ end
125136

126137
# TODO: It would be useful to define == as well.
127138

128-
# TODO: Now that `UnivariateFinite` is any finite measure, we can
129-
# replace the following nonsense with an overloading of `+`. I think
130-
# it is only used in MLJEnsembles.jl - but need to check. I believe
131-
# this is a private method we can easily remove
132-
133-
function average(dvec::AbstractVector{UnivariateFinite{S,V,R,P}};
134-
weights=nothing) where {S,V,R,P}
135-
136-
n = length(dvec)
137-
138-
Dist.@check_args(UnivariateFinite, weights == nothing || n==length(weights))
139-
140-
# check all distributions have consistent pool:
141-
first_index = first(dvec).decoder.classes
142-
for d in dvec
143-
d.decoder.classes == first_index ||
144-
error("Averaging UnivariateFinite distributions with incompatible"*
145-
" pools. ")
146-
end
147-
148-
# get all refs:
149-
refs = reduce(union, [keys(d.prob_given_ref) for d in dvec]) |> collect
150-
151-
# initialize the prob dictionary for the distribution sum:
152-
prob_given_ref = LittleDict{R,P}([refs...], zeros(P, length(refs)))
153-
154-
# make vector of all the distributions dicts padded to have same common keys:
155-
prob_given_ref_vec = map(dvec) do d
156-
merge(prob_given_ref, d.prob_given_ref)
157-
end
158-
159-
# sum up:
160-
if weights == nothing
161-
scale = 1/n
162-
for x in refs
163-
for k in 1:n
164-
prob_given_ref[x] += scale*prob_given_ref_vec[k][x]
165-
end
166-
end
167-
else
168-
scale = 1/sum(weights)
169-
for x in refs
170-
for k in 1:n
171-
prob_given_ref[x] +=
172-
weights[k]*prob_given_ref_vec[k][x]*scale
173-
end
174-
end
175-
end
176-
d1 = first(dvec)
177-
return UnivariateFinite(sample_scitype(d1), d1.decoder, prob_given_ref)
178-
end
179-
180139
"""
181140
Dist.pdf(d::UnivariateFinite, x)
182141
@@ -371,3 +330,9 @@ function Dist.fit(d::Type{<:UnivariateFinite},
371330
end
372331

373332

333+
# # BROADCASTING OVER SINGLE UNIVARIATE FINITE
334+
335+
# This mirrors behaviour assigned Distributions.Distribution objects,
336+
# which allows `pdf.(d::UnivariateFinite, support(d))` to work.
337+
338+
Broadcast.broadcastable(d::UnivariateFinite) = Ref(d)

0 commit comments

Comments
 (0)