Skip to content

Commit 46362b7

Browse files
authored
Merge pull request #53 from JuliaAI/dev
For a 0.1.11 release
2 parents 5f2660c + e3a7f6c commit 46362b7

File tree

13 files changed

+287
-107
lines changed

13 files changed

+287
-107
lines changed

.github/workflows/TagBot.yml

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,22 @@ on:
44
types:
55
- created
66
workflow_dispatch:
7+
inputs:
8+
lookback:
9+
default: 3
10+
permissions:
11+
actions: read
12+
checks: read
13+
contents: write
14+
deployments: read
15+
issues: read
16+
discussions: read
17+
packages: read
18+
pages: read
19+
pull-requests: read
20+
repository-projects: read
21+
security-events: read
22+
statuses: read
723
jobs:
824
TagBot:
925
if: github.event_name == 'workflow_dispatch' || github.actor == 'JuliaTagBot'

.github/workflows/ci.yml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,17 @@ jobs:
1717
fail-fast: false
1818
matrix:
1919
version:
20-
- '1.3'
2120
- '1.6'
2221
- '1' # automatically expands to the latest stable 1.x release of Julia.
22+
- '~1.9.0-0'
2323
os:
2424
- ubuntu-latest
2525
arch:
2626
- x64
27+
include:
28+
- os: windows-latest
29+
version: '1'
30+
arch: x86
2731
steps:
2832
- uses: actions/checkout@v2
2933
- uses: julia-actions/setup-julia@v1

Project.toml

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "CategoricalDistributions"
22
uuid = "af321ab8-2d2e-40a6-b165-3d674595d28e"
33
authors = ["Anthony D. Blaom <[email protected]>"]
4-
version = "0.1.9"
4+
version = "0.1.11"
55

66
[deps]
77
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
@@ -12,20 +12,27 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1212
ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81"
1313
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
1414

15+
[weakdeps]
16+
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
17+
18+
[extensions]
19+
UnivariateFiniteDisplayExt = "UnicodePlots"
20+
1521
[compat]
1622
CategoricalArrays = "0.9, 0.10"
1723
Distributions = "0.25"
1824
Missings = "0.4, 1"
1925
OrderedCollections = "1.1"
2026
ScientificTypes = "3.0"
2127
UnicodePlots = "2, 3"
22-
julia = "1.3"
28+
julia = "1.6"
2329

2430
[extras]
2531
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2632
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
2733
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2834
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
35+
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
2936

3037
[targets]
31-
test = ["Random", "StableRNGs", "Test", "FillArrays"]
38+
test = ["FillArrays", "Random", "StableRNGs", "Test", "UnicodePlots"]

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ this package is the class pool of a `CategoricalArray`:
3333
using CategoricalDistributions
3434
using CategoricalArrays
3535
import Distributions
36+
import UnicodePlots # for optional pretty display
3637
data = ["no", "yes", "no", "maybe", "maybe", "no",
3738
"maybe", "no", "maybe"] |> categorical
3839
julia> d = Distributions.fit(UnivariateFinite, data)

ext/UnivariateFiniteDisplayExt.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# Display code for `UnivariateFinite` distributions that depends on UnicodePlots:
# defines a `text/plain` `show` method that renders the distribution as a bar plot.
module UnivariateFiniteDisplayExt
2+
3+
# Bars are drawn only when the `K` in `Finite{K}` does not exceed this threshold.
const MAX_NUM_LEVELS_TO_SHOW_BARS = 12
4+
5+
using CategoricalDistributions
6+
import CategoricalArrays
7+
import UnicodePlots
8+
import ScientificTypes.Finite
9+
10+
# The following is a specialization of a `show` method already in /src/ for the common
11+
# case of `Real` probabilities.
12+
function Base.show(io::IO, mime::MIME"text/plain",
13+
d::UnivariateFinite{<:Finite{K},V,R,P}) where {K,V,R,P<:Real}
14+
# Bars are shown only if `K` is within the threshold and every value stored in
# `d.prob_given_ref` is `>= 0`.
show_bars = false
15+
if K <= MAX_NUM_LEVELS_TO_SHOW_BARS &&
16+
all(>=(0), values(d.prob_given_ref))
17+
show_bars = true
18+
end
19+
# Otherwise fall back to the two-argument `show(io, d)` method.
show_bars || return show(io, d)
20+
s = support(d)
21+
# Bar labels: the unwrapped support elements, converted to strings.
x = string.(CategoricalArrays.DataAPI.unwrap.(s))
22+
# Bar lengths: the value of `pdf` at each support element.
y = pdf.(d, s)
23+
# The plot title interpolates the distribution's `scitype` field.
S = d.scitype
24+
plt = UnicodePlots.barplot(x, y, title="UnivariateFinite{$S}")
25+
# Delegate the actual rendering to the plot's own `show` method for this MIME type.
show(io, mime, plt)
26+
end
27+
28+
end

src/CategoricalDistributions.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,8 @@ using OrderedCollections
1313
using CategoricalArrays
1414
import Missings
1515
using Random
16-
using UnicodePlots
1716

1817
const Dist = Distributions
19-
const MAX_NUM_LEVELS_TO_SHOW_BARS = 12
2018

2119
import Distributions: pdf, logpdf, support, mode
2220

@@ -35,4 +33,9 @@ export pdf, logpdf, support, mode
3533
# re-export from ScientificTypesBase:
3634
export Multiclass, OrderedFactor
3735

36+
# for julia < 1.9
37+
if !isdefined(Base, :get_extension)
38+
include("../ext/UnivariateFiniteDisplayExt.jl")
39+
end
40+
3841
end

src/arithmetic.jl

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# ## ARITHMETIC
22

33
const ERR_DIFFERENT_SAMPLE_SPACES = ArgumentError(
4-
"Adding two `UnivariateFinite` objects whose "*
4+
"Adding two `UnivariateFinite` objects whose " *
55
"sample spaces have different labellings is not allowed. ")
66

77
import Base: +, *, /, -
@@ -33,10 +33,6 @@ function -(d1::U, d2::U) where U <: SingletonOrArray
3333
return UnivariateFinite(L, pdf_matrix(d1, L) - pdf_matrix(d2, L))
3434
end
3535

36-
# It seems that the restriction `x::Number` below (applying only to the
37-
# array case) is unavoidable because of a method ambiguity with
38-
# `Base.*(::AbstractArray, ::Number)`.
39-
4036
function _times(d, x, T)
4137
S = d.scitype
4238
decoder = d.decoder
@@ -46,10 +42,10 @@ function _times(d, x, T)
4642
end
4743
return T(d.scitype, decoder, prob_given_ref)
4844
end
49-
*(d::UnivariateFinite, x) = _times(d, x, UnivariateFinite)
45+
*(d::UnivariateFinite, x::Number) = _times(d, x, UnivariateFinite)
5046
*(d::UnivariateFiniteArray, x::Number) = _times(d, x, UnivariateFiniteArray)
5147

52-
*(x, d::UnivariateFinite) = d*x
48+
*(x::Number, d::UnivariateFinite) = d*x
5349
*(x::Number, d::UnivariateFiniteArray) = d*x
54-
/(d::UnivariateFinite, x) = d*inv(x)
50+
/(d::UnivariateFinite, x::Number) = d*inv(x)
5551
/(d::UnivariateFiniteArray, x::Number) = d*inv(x)

src/arrays.jl

Lines changed: 87 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -4,23 +4,54 @@ const UniFinArr = UnivariateFiniteArray
44

55
Base.size(u::UniFinArr, args...) =
66
size(first(values(u.prob_given_ref)), args...)
7-
8-
function Base.getindex(u::UniFinArr{<:Any,<:Any,R,P,N},
9-
i::Integer...) where {R,P,N}
10-
prob_given_ref = LittleDict{R,P}()
11-
for ref in keys(u.prob_given_ref)
12-
prob_given_ref[ref] = getindex(u.prob_given_ref[ref], i...)
7+
8+
function Base.getindex(u::UniFinArr{<:Any,<:Any,R, P}, i...) where {R, P}
9+
# It's faster to generate `Array`s of `refs` and indexed `ref_probs`
10+
# and pass them to the `LittleDict` constructor.
11+
# The first element of `u.prob_given_ref` is used to get the dimensions
12+
# for allocating these arrays.
13+
u_dict = u.prob_given_ref
14+
a, rest = Iterators.peel(u_dict)
15+
# `a` is of the form `key => value`.
16+
a_ref, a_prob = first(a), getindex(last(a), i...)
17+
18+
# Preallocate Arrays using the key and value of the first
19+
# element (i.e. `a`) of `u_dict`.
20+
n_refs = length(u_dict)
21+
refs = Vector{R}(undef, n_refs)
22+
if a_prob isa AbstractArray
23+
ref_probs = Vector{Array{P, ndims(a_prob)}}(undef, n_refs)
24+
unf_constructor = UniFinArr
25+
else
26+
ref_probs = Vector{P}(undef, n_refs)
27+
unf_constructor = UnivariateFinite
28+
end
29+
30+
# Fill in the first elements
31+
# `refs` and `ref_probs` are both of type `Vector` and hence support
32+
# linear indexing with index starting at `1`
33+
refs[1] = a_ref
34+
ref_probs[1] = a_prob
35+
36+
# Fill in the rest
37+
iter = 2
38+
for (ref, ref_prob) in rest
39+
refs[iter] = ref
40+
ref_probs[iter] = getindex(ref_prob, i...)
41+
iter += 1
1342
end
14-
return UnivariateFinite(u.scitype, u.decoder, prob_given_ref)
43+
44+
# `keytype(prob_given_ref)` is always the same as `keytype(u_dict)`.
45+
# But `ndims(valtype(prob_given_ref))` might not be the same
46+
# as `ndims(valtype(u_dict))`.
47+
prob_given_ref = LittleDict{R, eltype(ref_probs)}(refs, ref_probs)
48+
49+
return unf_constructor(u.scitype, u.decoder, prob_given_ref)
1550
end
1651

17-
function Base.getindex(u::UniFinArr{<:Any,<:Any,R,P,N},
18-
I...) where {R,P,N}
19-
prob_given_ref = LittleDict{R,Array{P,N}}()
20-
for ref in keys(u.prob_given_ref)
21-
prob_given_ref[ref] = getindex(u.prob_given_ref[ref], I...)
22-
end
23-
return UniFinArr(u.scitype, u.decoder, prob_given_ref)
52+
function Base.getindex(u::UniFinArr, idx::CartesianIndex)
53+
checkbounds(u, idx)
54+
return u[Tuple(idx)...]
2455
end
2556

2657
function Base.setindex!(u::UniFinArr{S,V,R,P,N},
@@ -35,9 +66,9 @@ end
3566
# TODO: return an exception without throwing it:
3667

3768
_err_incompatible_levels() = throw(DomainError(
38-
"Cannot concatenate `UnivariateFiniteArray`s with "*
39-
"different categorical levels (classes), "*
40-
"or whose levels, when ordered, are not "*
69+
"Cannot concatenate `UnivariateFiniteArray`s with " *
70+
"different categorical levels (classes), " *
71+
"or whose levels, when ordered, are not " *
4172
"consistently ordered. "))
4273

4374
# terminology:
@@ -61,14 +92,12 @@ function Base.cat(us::UniFinArr{S,V,R,P,N}...;
6192
for i in 2:length(us)
6293
isordered(us[i]) == ordered || _err_incompatible_levels()
6394
if ordered
64-
classes(us[i]) ==
65-
_classes|| _err_incompatible_levels()
95+
classes(us[i]) == _classes || _err_incompatible_levels()
6696
else
67-
Set(classes(us[i])) ==
68-
Set(_classes) || _err_incompatible_levels()
97+
Set(classes(us[i])) == Set(_classes) || _err_incompatible_levels()
6998
end
70-
support_with_duplicates =
71-
vcat(support_with_duplicates, Dist.support(us[i]))
99+
support_with_duplicates = vcat(support_with_duplicates,
100+
Dist.support(us[i]))
72101
end
73102
_support = unique(support_with_duplicates) # no-longer categorical!
74103

@@ -99,14 +128,12 @@ for func in [:pdf, :logpdf]
99128
eval(quote
100129
function Distributions.$func(
101130
u::AbstractArray{UnivariateFinite{S,V,R,P},N},
102-
C::AbstractVector{<:Union{
103-
V,
104-
CategoricalValue{V,R}}}) where {S,V,R,P,N}
131+
C::AbstractVector) where {S,V,R,P,N}
105132

106-
#ret = Array{P,N+1}(undef, size(u)..., length(C))
107133
ret = zeros(P, size(u)..., length(C))
108-
for i in eachindex(C)
109-
ret[fill(:,N)...,i] .= broadcast($func, u, C[i])
134+
# note that we do not require C to use 1-based indexing
135+
for (i, c) in enumerate(C)
136+
ret[fill(:,N)..., i] .= broadcast($func, u, c)
110137
end
111138
return ret
112139
end
@@ -124,30 +151,35 @@ end
124151

125152
# dummy function
126153
# returns `x[i]` for `Array` inputs `x`
127-
# For non-Array inputs returns `zero(dtype)`
154+
# For non-Array inputs returns the input
128155
# This avoids using an `if` statement
129-
_getindex(x::Array,i, dtype)=x[i]
130-
_getindex(::Nothing, i, dtype) = zero(dtype)
156+
_getindex(x::Array, i) = x[i]
157+
_getindex(x, i) = x
131158

132159
# pdf.(u, cv)
133160
function Base.Broadcast.broadcasted(
134161
::typeof(pdf),
135162
u::UniFinArr{S,V,R,P,N},
136163
cv::CategoricalValue) where {S,V,R,P,N}
137164

138-
cv in classes(u) || throw(err_missing_class(cv))
165+
# we assume that we compare categorical values by their unwrapped value
166+
# and pick the index of this value from classes(u)
167+
_classes = classes(u)
168+
cv_loc = get(CategoricalArrays.pool(_classes), cv, zero(R))
169+
isequal(cv_loc, 0) && throw(err_missing_class(cv))
139170

140171
f() = zeros(P, size(u)) #default caller function
141172

142173
return Base.Broadcast.Broadcasted(
143174
identity,
144-
(get(f, u.prob_given_ref, int(cv)),)
145-
)
175+
(get(f, u.prob_given_ref, cv_loc),)
176+
)
146177
end
178+
147179
Base.Broadcast.broadcasted(
148180
::typeof(pdf),
149181
u::UniFinArr{S,V,R,P,N},
150-
::Missing) where {S,V,R,P,N} = Missings.missings(P, length(u))
182+
::Missing) where {S,V,R,P,N} = Missings.missings(P, size(u))
151183

152184
# pdf.(u, v)
153185
function Base.Broadcast.broadcasted(
@@ -160,17 +192,23 @@ function Base.Broadcast.broadcasted(
160192
length(u) == length(v) ||throw(DimensionMismatch(
161193
"Arrays could not be broadcast to a common size; "*
162194
"got a dimension with lengths $(length(u)) and $(length(v))"))
163-
for cv in v
164-
ismissing(cv) || cv in classes(u) || throw(err_missing_class(cv))
195+
196+
_classes = classes(u)
197+
_classes_pool = CategoricalArrays.pool(_classes)
198+
T = eltype(v) >: Missing ? Missing : Union{}
199+
v_loc_flat = Vector{Tuple{Union{R, T}, Int}}(undef, length(v))
200+
201+
202+
for (i, x) in enumerate(v)
203+
cv_ref = ismissing(x) ? missing : get(_classes_pool, x, zero(R))
204+
isequal(cv_ref, 0) && throw(err_missing_class(x))
205+
v_loc_flat[i] = (cv_ref, i)
165206
end
166207

167-
# will use linear indexing:
168-
v_flat = ((v[i], i) for i in 1:length(v))
169-
170-
getter((cv, i), dtype) =
171-
_getindex(get(u.prob_given_ref, int(cv), nothing), i, dtype)
172-
getter(::Tuple{Missing,Any}, dtype) = missing
173-
ret_flat = getter.(v_flat, P)
208+
getter((cv_ref, i)) =
209+
_getindex(get(u.prob_given_ref, cv_ref, zero(P)), i)
210+
getter(::Tuple{Missing,Any}) = missing
211+
ret_flat = getter.(v_loc_flat)
174212
return reshape(ret_flat, size(u))
175213
end
176214

@@ -243,10 +281,10 @@ function Base.Broadcast.broadcasted(::typeof(mode),
243281
mode_flat = map(1:length(u)) do i
244282
max_prob = maximum(dic[ref][i] for ref in keys(dic))
245283
m = zero(R)
246-
247-
# `maximum` of any iterable containing `NaN` would return `NaN`
284+
285+
# `maximum` of any iterable containing `NaN` would return `NaN`
248286
# For this case the index `m` won't be updated in the loop as relations
249-
# involving NaN as one of it's argument always returns false
287+
# involving NaN as one of their arguments always return false
250288
# (e.g. `==(NaN, NaN)` returns false)
251289
throw_nan_error_if_needed(max_prob)
252290
for ref in keys(dic)
@@ -269,9 +307,7 @@ const ERR_EMPTY_UNIVARIATE_FINITE = ArgumentError(
269307
"No `UnivariateFinite` object found from which to extract classes. ")
270308

271309
function classes(yhat::AbstractArray{<:Union{Missing,UnivariateFinite}})
272-
i = findfirst(x->!ismissing(x), yhat)
310+
i = findfirst(!ismissing, yhat)
273311
i === nothing && throw(ERR_EMPTY_UNIVARIATE_FINITE)
274312
return classes(yhat[i])
275313
end
276-
277-

0 commit comments

Comments
 (0)