Skip to content

Commit 608e0d9

Browse files
committed
fix show method, doc-string and implement arithmetic
1 parent 1c3f73d commit 608e0d9

File tree

4 files changed

+105
-6
lines changed

4 files changed

+105
-6
lines changed

src/CategoricalDistributions.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ using Random
99
using UnicodePlots
1010

1111
const Dist = Distributions
12+
const MAX_NUM_LEVELS_TO_SHOW_BARS = 12
1213

1314
import Distributions: pdf, logpdf, support, mode
1415

src/methods.jl

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,11 +79,22 @@ function Base.show(stream::IO, d::UnivariateFinite)
7979
print(stream, "UnivariateFinite{$(d.scitype)}($arg_str)")
8080
end
8181

82+
Base.show(io::IO, mime::MIME"text/plain",
83+
d::UnivariateFinite) = show(io, d)
84+
85+
# in common case of `Real` probabilities we can do a pretty bar plot:
8286
function Base.show(io::IO, mime::MIME"text/plain",
83-
d::UnivariateFinite{S}) where S
87+
d::UnivariateFinite{<:Finite{K},V,R,P}) where {K,V,R,P<:Real}
88+
show_bars = false
89+
if K <= MAX_NUM_LEVELS_TO_SHOW_BARS &&
90+
all(>=(0), values(d.prob_given_ref))
91+
show_bars = true
92+
end
93+
show_bars || return show(io, d)
8494
s = support(d)
8595
x = string.(CategoricalArrays.DataAPI.unwrap.(s))
8696
y = pdf.(d, s)
97+
S = d.scitype
8798
plt = barplot(x, y, title="UnivariateFinite{$S}")
8899
show(io, mime, plt)
89100
end
@@ -371,3 +382,59 @@ function Dist.fit(d::Type{<:UnivariateFinite},
371382
end
372383

373384

385+
# ## ARITHMETIC
386+
387+
const ERR_DIFFERENT_SAMPLE_SPACES = ArgumentError(
388+
"Adding two `UnivariateFinite` objects whose "*
389+
"sample spaces have different labellings is not allowed. ")
390+
391+
import Base: +, *, /
392+
393+
function +(d1::U, d2::U) where U <: UnivariateFinite
394+
classes(d1) == classes(d2) || throw(ERR_DIFFERENT_SAMPLE_SPACES)
395+
S = d1.scitype
396+
decoder = d1.decoder
397+
prob_given_ref = copy(d1.prob_given_ref)
398+
for ref in keys(prob_given_ref)
399+
prob_given_ref[ref] += d2.prob_given_ref[ref]
400+
end
401+
return UnivariateFinite(S, decoder, prob_given_ref)
402+
end
403+
404+
function -(d::UnivariateFinite)
405+
S = d.scitype
406+
decoder = d.decoder
407+
prob_given_ref = copy(d.prob_given_ref)
408+
for ref in keys(prob_given_ref)
409+
prob_given_ref[ref] = -prob_given_ref[ref]
410+
end
411+
return UnivariateFinite(S, decoder, prob_given_ref)
412+
end
413+
414+
function -(d1::U, d2::U) where U <: UnivariateFinite
415+
classes(d1) == classes(d2) || throw(ERR_DIFFERENT_SAMPLE_SPACES)
416+
S = d1.scitype
417+
decoder = d1.decoder
418+
prob_given_ref = copy(d1.prob_given_ref)
419+
for ref in keys(prob_given_ref)
420+
prob_given_ref[ref] -= d2.prob_given_ref[ref]
421+
end
422+
return UnivariateFinite(S, decoder, prob_given_ref)
423+
end
424+
425+
# TODO: remove type restrction on `x` in the following methods if
426+
# https://github.com/JuliaStats/Distributions.jl/issues/1438 is
427+
# resolved. Currently we'd have a method ambiguity
428+
429+
function *(d::UnivariateFinite, x::Real)
430+
S = d.scitype
431+
decoder = d.decoder
432+
prob_given_ref = copy(d.prob_given_ref)
433+
for ref in keys(prob_given_ref)
434+
prob_given_ref[ref] *= x
435+
end
436+
return UnivariateFinite(d.scitype, decoder, prob_given_ref)
437+
end
438+
*(x::Real, d::UnivariateFinite) = d*x
439+
440+
/(d::UnivariateFinite, x::Real) = d*inv(x)

src/types.jl

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,15 @@ choosing `probs` to be an array of one higher dimension than the array
1414
generated.
1515
1616
Here the word "probabilities" is an abuse of terminology as there is
17-
no requirement that probabilities actually sum to one, only that they
18-
be non-negative. So `UnivariateFinite` objects actually implement
19-
arbitrary non-negative measures over finite sets of labelled points. A
17+
no requirement that the that probabilities actually sum to one. Indeed
18+
there is no restriction on the probablities at all. In particular,
19+
`UnivariateFinite` objects implement arbitrary non-negative, signed,
20+
or complex measures over finite sets of labelled points. A
2021
`UnivariateDistribution` will be a bona fide probability measure when
21-
constructed using the `augment=true` option (see below) or when
22-
`fit` to data.
22+
constructed using the `augment=true` option (see below) or when `fit`
23+
to data. And the probabilities of a `UnivariateFinite` object `d` must
24+
be non-negative, with a non-zero sum, for `rand(d)` to be defined and
25+
interpretable.
2326
2427
Unless `pool` is specified, `support` should have type
2528
`AbstractVector{<:CategoricalValue}` and all elements are assumed to
@@ -144,12 +147,15 @@ const _UnivariateFinite_{S} =
144147

145148
# Note that the keys of `prob_given_ref` need not exhaust all the
146149
# refs of all classes but will be ordered (LittleDicts preserve order)
150+
DOC_CONSTRUCTOR
147151
struct UnivariateFinite{S,V,R,P} <: _UnivariateFinite_{S}
148152
scitype::Type{S}
149153
decoder::CategoricalDecoder{V,R}
150154
prob_given_ref::LittleDict{R,P,Vector{R}, Vector{P}}
151155
end
152156

157+
@doc DOC_CONSTRUCTOR UnivariateFinite
158+
153159
"""
154160
UnivariateFiniteArray
155161

test/methods.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,31 @@ end
277277
# @test v ≈ v_close
278278
end
279279

280+
@tesset "arithmetic" begin
281+
L = ["yes", "no"]
282+
d1 = UnivariateFinite(L, rand(rng, 2), pool=missing)
283+
d2 = UnivariateFinite(L, rand(rng, 2), pool=missing)
284+
285+
# addition and subtraction:
286+
for op in [:+, :-]
287+
quote
288+
s = $op(d1, d2 )
289+
@test $op(pdf(d1, L), pdf(d2, L)) pdf(s, L)
290+
end |> eval
291+
end
292+
293+
# negative:
294+
d_neg = -d1
295+
@test pdf(d_neg, L) == -pdf(d1, L)
296+
297+
# multiplication by scalar:
298+
d3 = d1/42
299+
@test pdf(d3, L) pdf(d1, L)/42
300+
301+
# division by scalar:
302+
d3 = d1/42
303+
@test pdf(d3, L) pdf(d1, L)/42
304+
end
280305

281306
end # module
282307

0 commit comments

Comments
 (0)