Skip to content

Commit e8e4aa6

Browse files
authored
Merge pull request #40 from bkamins/patch-1
Improve implementation of pdf
2 parents 3265ec6 + ceed57b commit e8e4aa6

File tree

8 files changed

+134
-58
lines changed

8 files changed

+134
-58
lines changed

.github/workflows/TagBot.yml

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,22 @@ on:
44
types:
55
- created
66
workflow_dispatch:
7+
inputs:
8+
lookback:
9+
default: 3
10+
permissions:
11+
actions: read
12+
checks: read
13+
contents: write
14+
deployments: read
15+
issues: read
16+
discussions: read
17+
packages: read
18+
pages: read
19+
pull-requests: read
20+
repository-projects: read
21+
security-events: read
22+
statuses: read
723
jobs:
824
TagBot:
925
if: github.event_name == 'workflow_dispatch' || github.actor == 'JuliaTagBot'

.github/workflows/ci.yml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,17 @@ jobs:
1717
fail-fast: false
1818
matrix:
1919
version:
20-
- '1.3'
2120
- '1.6'
2221
- '1' # automatically expands to the latest stable 1.x release of Julia.
22+
- 'nightly'
2323
os:
2424
- ubuntu-latest
2525
arch:
2626
- x64
27+
include:
28+
- os: windows-latest
29+
version: '1'
30+
arch: x86
2731
steps:
2832
- uses: actions/checkout@v2
2933
- uses: julia-actions/setup-julia@v1

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "CategoricalDistributions"
22
uuid = "af321ab8-2d2e-40a6-b165-3d674595d28e"
33
authors = ["Anthony D. Blaom <[email protected]>"]
4-
version = "0.1.9"
4+
version = "0.1.10"
55

66
[deps]
77
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
@@ -19,7 +19,7 @@ Missings = "0.4, 1"
1919
OrderedCollections = "1.1"
2020
ScientificTypes = "3.0"
2121
UnicodePlots = "2, 3"
22-
julia = "1.3"
22+
julia = "1.6"
2323

2424
[extras]
2525
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

src/arithmetic.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# ## ARITHMETIC
22

33
const ERR_DIFFERENT_SAMPLE_SPACES = ArgumentError(
4-
"Adding two `UnivariateFinite` objects whose "*
4+
"Adding two `UnivariateFinite` objects whose " *
55
"sample spaces have different labellings is not allowed. ")
66

77
import Base: +, *, /, -

src/arrays.jl

Lines changed: 34 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,11 @@ function Base.getindex(u::UniFinArr{<:Any,<:Any,R,P,N},
1414
return UnivariateFinite(u.scitype, u.decoder, prob_given_ref)
1515
end
1616

17+
function Base.getindex(u::UniFinArr, idx::CartesianIndex)
18+
checkbounds(u, idx)
19+
return u[Tuple(idx)...]
20+
end
21+
1722
function Base.getindex(u::UniFinArr{<:Any,<:Any,R,P,N},
1823
I...) where {R,P,N}
1924
prob_given_ref = LittleDict{R,Array{P,N}}()
@@ -35,9 +40,9 @@ end
3540
# TODO: return an exception without throwing it:
3641

3742
_err_incompatible_levels() = throw(DomainError(
38-
"Cannot concatenate `UnivariateFiniteArray`s with "*
39-
"different categorical levels (classes), "*
40-
"or whose levels, when ordered, are not "*
43+
"Cannot concatenate `UnivariateFiniteArray`s with " *
44+
"different categorical levels (classes), " *
45+
"or whose levels, when ordered, are not " *
4146
"consistently ordered. "))
4247

4348
# terminology:
@@ -61,14 +66,12 @@ function Base.cat(us::UniFinArr{S,V,R,P,N}...;
6166
for i in 2:length(us)
6267
isordered(us[i]) == ordered || _err_incompatible_levels()
6368
if ordered
64-
classes(us[i]) ==
65-
_classes|| _err_incompatible_levels()
69+
classes(us[i]) == _classes || _err_incompatible_levels()
6670
else
67-
Set(classes(us[i])) ==
68-
Set(_classes) || _err_incompatible_levels()
71+
Set(classes(us[i])) == Set(_classes) || _err_incompatible_levels()
6972
end
70-
support_with_duplicates =
71-
vcat(support_with_duplicates, Dist.support(us[i]))
73+
support_with_duplicates = vcat(support_with_duplicates,
74+
Dist.support(us[i]))
7275
end
7376
_support = unique(support_with_duplicates) # no-longer categorical!
7477

@@ -99,14 +102,12 @@ for func in [:pdf, :logpdf]
99102
eval(quote
100103
function Distributions.$func(
101104
u::AbstractArray{UnivariateFinite{S,V,R,P},N},
102-
C::AbstractVector{<:Union{
103-
V,
104-
CategoricalValue{V,R}}}) where {S,V,R,P,N}
105+
C::AbstractVector) where {S,V,R,P,N}
105106

106-
#ret = Array{P,N+1}(undef, size(u)..., length(C))
107107
ret = zeros(P, size(u)..., length(C))
108-
for i in eachindex(C)
109-
ret[fill(:,N)...,i] .= broadcast($func, u, C[i])
108+
# note that we do not require C to use 1-base indexing
109+
for (i, c) in enumerate(C)
110+
ret[fill(:,N)..., i] .= broadcast($func, u, c)
110111
end
111112
return ret
112113
end
@@ -126,7 +127,7 @@ end
126127
# returns `x[i]` for `Array` inputs `x`
127128
# For non-Array inputs returns `zero(dtype)`
128129
#This avoids using an if statement
129-
_getindex(x::Array,i, dtype)=x[i]
130+
_getindex(x::Array, i, dtype)=x[i]
130131
_getindex(::Nothing, i, dtype) = zero(dtype)
131132

132133
# pdf.(u, cv)
@@ -135,19 +136,23 @@ function Base.Broadcast.broadcasted(
135136
u::UniFinArr{S,V,R,P,N},
136137
cv::CategoricalValue) where {S,V,R,P,N}
137138

138-
cv in classes(u) || throw(err_missing_class(cv))
139+
# we assume that we compare categorical values by their unwrapped value
140+
# and pick the index of this value from classes(u)
141+
cv_loc = findfirst(==(cv), classes(u))
142+
cv_loc == 0 && throw(err_missing_class(cv))
139143

140144
f() = zeros(P, size(u)) #default caller function
141145

142146
return Base.Broadcast.Broadcasted(
143147
identity,
144-
(get(f, u.prob_given_ref, int(cv)),)
148+
(get(f, u.prob_given_ref, cv_loc),)
145149
)
146150
end
151+
147152
Base.Broadcast.broadcasted(
148153
::typeof(pdf),
149154
u::UniFinArr{S,V,R,P,N},
150-
::Missing) where {S,V,R,P,N} = Missings.missings(P, length(u))
155+
::Missing) where {S,V,R,P,N} = Missings.missings(P, size(u))
151156

152157
# pdf.(u, v)
153158
function Base.Broadcast.broadcasted(
@@ -160,17 +165,15 @@ function Base.Broadcast.broadcasted(
160165
length(u) == length(v) ||throw(DimensionMismatch(
161166
"Arrays could not be broadcast to a common size; "*
162167
"got a dimension with lengths $(length(u)) and $(length(v))"))
163-
for cv in v
164-
ismissing(cv) || cv in classes(u) || throw(err_missing_class(cv))
165-
end
166168

167-
# will use linear indexing:
168-
v_flat = ((v[i], i) for i in 1:length(v))
169+
v_loc_flat = [(ismissing(x) ? missing : findfirst(==(x), classes(u)), i)
170+
for (i, x) in enumerate(v)]
171+
any(isequal(0), v_loc_flat) && throw(err_missing_class(cv))
169172

170-
getter((cv, i), dtype) =
171-
_getindex(get(u.prob_given_ref, int(cv), nothing), i, dtype)
173+
getter((cv_loc, i), dtype) =
174+
_getindex(get(u.prob_given_ref, cv_loc, nothing), i, dtype)
172175
getter(::Tuple{Missing,Any}, dtype) = missing
173-
ret_flat = getter.(v_flat, P)
176+
ret_flat = getter.(v_loc_flat, P)
174177
return reshape(ret_flat, size(u))
175178
end
176179

@@ -243,10 +246,10 @@ function Base.Broadcast.broadcasted(::typeof(mode),
243246
mode_flat = map(1:length(u)) do i
244247
max_prob = maximum(dic[ref][i] for ref in keys(dic))
245248
m = zero(R)
246-
247-
# `maximum` of any iterable containing `NaN` would return `NaN`
249+
250+
# `maximum` of any iterable containing `NaN` would return `NaN`
248251
# For this case the index `m` won't be updated in the loop as relations
249-
# involving NaN as one of it's argument always returns false
252+
# involving NaN as one of it's argument always returns false
250253
# (e.g `==(NaN, NaN)` returns false)
251254
throw_nan_error_if_needed(max_prob)
252255
for ref in keys(dic)
@@ -269,9 +272,7 @@ const ERR_EMPTY_UNIVARIATE_FINITE = ArgumentError(
269272
"No `UnivariateFinite` object found from which to extract classes. ")
270273

271274
function classes(yhat::AbstractArray{<:Union{Missing,UnivariateFinite}})
272-
i = findfirst(x->!ismissing(x), yhat)
275+
i = findfirst(!ismissing, yhat)
273276
i === nothing && throw(ERR_EMPTY_UNIVARIATE_FINITE)
274277
return classes(yhat[i])
275278
end
276-
277-

src/methods.jl

Lines changed: 5 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -164,40 +164,24 @@ One can also do weighted fits:
164164
165165
See also `classes`, `support`.
166166
"""
167-
function Dist.pdf(
168-
d::UnivariateFinite{S,V,R,P},
169-
cv::CategoricalValue,
170-
) where {S,V,R,P}
171-
return get(d.prob_given_ref, int(cv), zero(P))
172-
end
173-
Dist.pdf(d::UnivariateFinite{S,V}, c::V) where {S,V} = _pdf(d, c)
174-
Dist.pdf(::UnivariateFinite{S,V}, ::Missing) where {S,V} = missing
175-
176-
# Avoid method ambiguity errors with Dist >= 0.24
177-
Dist.pdf(d::UnivariateFinite{S,V}, c::V) where {S,V<:Real} = _pdf(d, c)
167+
Dist.pdf(::UnivariateFinite, ::Missing) = missing
178168

179-
function _pdf(d::UnivariateFinite, c)
169+
function Dist.pdf(d::UnivariateFinite{S,V,R,P}, c) where {S,V,R,P}
180170
_classes = classes(d)
181171
c in _classes || throw(DomainError("Value $c not in pool. "))
182172
pool = CategoricalArrays.pool(_classes)
183-
class = pool[get(pool, c)]
184-
return pdf(d, class)
173+
return get(d.prob_given_ref, get(pool, c), zero(P))
185174
end
186175

187-
Dist.logpdf(d::UnivariateFinite, cv::CategoricalValue) = log(pdf(d,cv))
188-
Dist.logpdf(d::UnivariateFinite{S,V}, c::V) where {S,V} = log(pdf(d, c))
189-
Dist.logpdf(::UnivariateFinite{S,V}, ::Missing) where {S,V} = missing
190-
191-
# Avoid method ambiguity errors with Dist >= 0.24
192-
Dist.logpdf(d::UnivariateFinite{S,V}, c::V) where {S,V<:Real} = log(pdf(d, c))
176+
Dist.logpdf(d::UnivariateFinite, c) = log(pdf(d, c))
193177

194178
function Dist.mode(d::UnivariateFinite)
195179
dic = d.prob_given_ref
196180
p = values(dic)
197181
max_prob = maximum(p)
198182
m = first(first(dic)) # mode, just some ref for now
199183

200-
# `maximum` of any iterable containing `NaN` would return `NaN`
184+
# `maximum` of any iterable containing `NaN` would return `NaN`
201185
# For this case the index `m` won't be updated in the loop below as relations
202186
# involving NaN as one of it's arguments always returns false
203187
# (e.g `==(NaN, NaN)` returns false)

test/arrays.jl

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,4 +290,69 @@ end
290290

291291
end
292292

293+
function (x::T, y::T) where {T<:UnivariateFinite}
294+
return x.decoder == y.decoder &&
295+
x.prob_given_ref == y.prob_given_ref &&
296+
x.scitype == y.scitype
297+
end
298+
299+
@testset "CartesianIndex" begin
300+
v = categorical(["a", "b"], ordered=true)
301+
m = UnivariateFinite(v, rand(rng, 5, 2), augment=true)
302+
@test m[1, 1] m[CartesianIndex(1, 1)] m[CartesianIndex(1, 1, 1)]
303+
@test_throws BoundsError m[CartesianIndex(1)]
304+
@test all(zip(Matrix(m), copy(m), m)) do (x, y, z)
305+
return x y z
306+
end
307+
@test Matrix(m) isa Matrix
308+
# TODO: probably it would be better for copy to keep it
309+
# UnivariateFiniteArray but it would be breaking
310+
@test copy(m) isa Matrix
311+
@test similar(m) isa Matrix
312+
end
313+
314+
@testset "broadcasted pdf" begin
315+
v = categorical(["a", "b"], ordered=true)
316+
v2 = categorical(["a", "b"], ordered=true, levels=["b", "a"])
317+
x = UnivariateFinite(v, rand(rng, 5), augment=true)
318+
@test pdf.(x, v[1]) == pdf.(x, v2[1]) == pdf.(x, "a")
319+
@test pdf.(x, v[2]) == pdf.(x, v2[2]) == pdf.(x, "b")
320+
321+
x = UnivariateFinite(v, rand(rng, 5, 2), augment=true)
322+
@test size(pdf.(x, missing)) == (5, 2)
323+
324+
v3 = categorical(["a" "b"], ordered=true)
325+
v4 = categorical(["a" "b"], ordered=true, levels=["b", "a"])
326+
# note that v5 and v6 have the same shape and contents as v3 and v4
327+
# just they are Matrix{Any} not CategoricalMatrix
328+
v5 = Any[v3[1] v3[2]]
329+
v6 = Any[v4[1] v4[2]]
330+
x = UnivariateFinite(v, hcat([0.1, 0.2]), augment=true)
331+
332+
# these tests show that now we have corrected refpools
333+
# but still there is an inconsistency in behavior
334+
@test pdf.(x, v) == hcat([0.9, 0.2])
335+
@test pdf.(x, v2) == hcat([0.9, 0.2])
336+
@test pdf.(x, v3) == hcat([0.9, 0.2])
337+
@test pdf.(x, v4) == hcat([0.9, 0.2])
338+
@test pdf.(x, v5) == [0.9 0.1; 0.8 0.2]
339+
@test pdf.(x, v6) == [0.9 0.1; 0.8 0.2]
340+
end
341+
342+
@testset "pdf with various types" begin
343+
v = categorical(["a", "b"], ordered=true)
344+
a = view("a", 1:1) # quite common case when splitting strings
345+
b = view("b", 1:1)
346+
x = UnivariateFinite(v, [0.1, 0.2, 0.3], augment=true)
347+
@test pdf.(x, a) == pdf.(x, "a") == pdf.(x, v[1])
348+
@test logpdf.(x, a) == logpdf.(x, "a") == logpdf.(x, v[1])
349+
@test pdf(x, [a, b]) == pdf(x, ["a", "b"]) == pdf(x, v)
350+
@test logpdf(x, [a, b]) == logpdf(x, ["a", "b"]) == logpdf(x, v)
351+
352+
x = UnivariateFinite(v, 0.1, augment=true)
353+
@test pdf.(x, a) == pdf.(x, "a") == pdf.(x, v[1]) == 0.9
354+
@test logpdf.(x, a) == logpdf.(x, "a") == logpdf.(x, v[1]) == log(0.9)
355+
@test pdf(x, a) == pdf(x, "a") == pdf(x, v[1]) == 0.9
356+
@test logpdf(x, a) == logpdf(x, "a") == logpdf(x, v[1]) == log(0.9)
357+
end
293358
true

test/runtests.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@ rng = StableRNGs.StableRNG(123)
99

1010
import CategoricalDistributions: classes, decoder, int
1111

12+
ambiguities_vec = Test.detect_ambiguities(CategoricalDistributions,
13+
recursive=true)
14+
if !isempty(ambiguities_vec)
15+
@warn "$(length(ambiguities_vec)) method ambiguities detected"
16+
end
17+
1218
@testset "utilities" begin
1319
@test include("utilities.jl")
1420
end

0 commit comments

Comments
 (0)