Skip to content

Commit 793f90c

Browse files
authored
Merge branch 'dev' into indexing
2 parents 35382f4 + e8e4aa6 commit 793f90c

File tree

10 files changed

+155
-62
lines changed

10 files changed

+155
-62
lines changed

.github/workflows/TagBot.yml

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,22 @@ on:
44
types:
55
- created
66
workflow_dispatch:
7+
inputs:
8+
lookback:
9+
default: 3
10+
permissions:
11+
actions: read
12+
checks: read
13+
contents: write
14+
deployments: read
15+
issues: read
16+
discussions: read
17+
packages: read
18+
pages: read
19+
pull-requests: read
20+
repository-projects: read
21+
security-events: read
22+
statuses: read
723
jobs:
824
TagBot:
925
if: github.event_name == 'workflow_dispatch' || github.actor == 'JuliaTagBot'

.github/workflows/ci.yml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,17 @@ jobs:
1717
fail-fast: false
1818
matrix:
1919
version:
20-
- '1.3'
2120
- '1.6'
2221
- '1' # automatically expands to the latest stable 1.x release of Julia.
22+
- 'nightly'
2323
os:
2424
- ubuntu-latest
2525
arch:
2626
- x64
27+
include:
28+
- os: windows-latest
29+
version: '1'
30+
arch: x86
2731
steps:
2832
- uses: actions/checkout@v2
2933
- uses: julia-actions/setup-julia@v1

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "CategoricalDistributions"
22
uuid = "af321ab8-2d2e-40a6-b165-3d674595d28e"
33
authors = ["Anthony D. Blaom <[email protected]>"]
4-
version = "0.1.7"
4+
version = "0.1.10"
55

66
[deps]
77
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
@@ -19,7 +19,7 @@ Missings = "0.4, 1"
1919
OrderedCollections = "1.1"
2020
ScientificTypes = "3.0"
2121
UnicodePlots = "2, 3"
22-
julia = "1.3"
22+
julia = "1.6"
2323

2424
[extras]
2525
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

src/arithmetic.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# ## ARITHMETIC
22

33
const ERR_DIFFERENT_SAMPLE_SPACES = ArgumentError(
4-
"Adding two `UnivariateFinite` objects whose "*
4+
"Adding two `UnivariateFinite` objects whose " *
55
"sample spaces have different labellings is not allowed. ")
66

77
import Base: +, *, /, -

src/arrays.jl

Lines changed: 35 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ function Base.getindex(u::UniFinArr{<:Any,<:Any,R, P}, i...) where {R, P}
2626
ref_probs = Vector{P}(undef, n_refs)
2727
unf_constructor = UnivariateFinite
2828
end
29-
29+
3030
# Fill in the first elements
3131
# Both `refs` and `ref_probs` are both of type `Vector` and hence support
3232
# linear indexing with index starting at `1`
@@ -49,6 +49,11 @@ function Base.getindex(u::UniFinArr{<:Any,<:Any,R, P}, i...) where {R, P}
4949
return unf_constructor(u.scitype, u.decoder, prob_given_ref)
5050
end
5151

52+
function Base.getindex(u::UniFinArr, idx::CartesianIndex)
53+
checkbounds(u, idx)
54+
return u[Tuple(idx)...]
55+
end
56+
5257
function Base.setindex!(u::UniFinArr{S,V,R,P,N},
5358
v::UnivariateFinite{S,V,R,P},
5459
i::Integer...) where {S,V,R,P,N}
@@ -61,9 +66,9 @@ end
6166
# TODO: return an exception without throwing it:
6267

6368
_err_incompatible_levels() = throw(DomainError(
64-
"Cannot concatenate `UnivariateFiniteArray`s with "*
65-
"different categorical levels (classes), "*
66-
"or whose levels, when ordered, are not "*
69+
"Cannot concatenate `UnivariateFiniteArray`s with " *
70+
"different categorical levels (classes), " *
71+
"or whose levels, when ordered, are not " *
6772
"consistently ordered. "))
6873

6974
# terminology:
@@ -87,14 +92,12 @@ function Base.cat(us::UniFinArr{S,V,R,P,N}...;
8792
for i in 2:length(us)
8893
isordered(us[i]) == ordered || _err_incompatible_levels()
8994
if ordered
90-
classes(us[i]) ==
91-
_classes|| _err_incompatible_levels()
95+
classes(us[i]) == _classes || _err_incompatible_levels()
9296
else
93-
Set(classes(us[i])) ==
94-
Set(_classes) || _err_incompatible_levels()
97+
Set(classes(us[i])) == Set(_classes) || _err_incompatible_levels()
9598
end
96-
support_with_duplicates =
97-
vcat(support_with_duplicates, Dist.support(us[i]))
99+
support_with_duplicates = vcat(support_with_duplicates,
100+
Dist.support(us[i]))
98101
end
99102
_support = unique(support_with_duplicates) # no-longer categorical!
100103

@@ -125,14 +128,12 @@ for func in [:pdf, :logpdf]
125128
eval(quote
126129
function Distributions.$func(
127130
u::AbstractArray{UnivariateFinite{S,V,R,P},N},
128-
C::AbstractVector{<:Union{
129-
V,
130-
CategoricalValue{V,R}}}) where {S,V,R,P,N}
131+
C::AbstractVector) where {S,V,R,P,N}
131132

132-
#ret = Array{P,N+1}(undef, size(u)..., length(C))
133133
ret = zeros(P, size(u)..., length(C))
134-
for i in eachindex(C)
135-
ret[fill(:,N)...,i] .= broadcast($func, u, C[i])
134+
# note that we do not require C to use 1-base indexing
135+
for (i, c) in enumerate(C)
136+
ret[fill(:,N)..., i] .= broadcast($func, u, c)
136137
end
137138
return ret
138139
end
@@ -152,7 +153,7 @@ end
152153
# returns `x[i]` for `Array` inputs `x`
153154
# For non-Array inputs returns `zero(dtype)`
154155
#This avoids using an if statement
155-
_getindex(x::Array,i, dtype)=x[i]
156+
_getindex(x::Array, i, dtype)=x[i]
156157
_getindex(::Nothing, i, dtype) = zero(dtype)
157158

158159
# pdf.(u, cv)
@@ -161,19 +162,23 @@ function Base.Broadcast.broadcasted(
161162
u::UniFinArr{S,V,R,P,N},
162163
cv::CategoricalValue) where {S,V,R,P,N}
163164

164-
cv in classes(u) || throw(err_missing_class(cv))
165+
# we assume that we compare categorical values by their unwrapped value
166+
# and pick the index of this value from classes(u)
167+
cv_loc = findfirst(==(cv), classes(u))
168+
cv_loc == 0 && throw(err_missing_class(cv))
165169

166170
f() = zeros(P, size(u)) #default caller function
167171

168172
return Base.Broadcast.Broadcasted(
169173
identity,
170-
(get(f, u.prob_given_ref, int(cv)),)
174+
(get(f, u.prob_given_ref, cv_loc),)
171175
)
172176
end
177+
173178
Base.Broadcast.broadcasted(
174179
::typeof(pdf),
175180
u::UniFinArr{S,V,R,P,N},
176-
::Missing) where {S,V,R,P,N} = Missings.missings(P, length(u))
181+
::Missing) where {S,V,R,P,N} = Missings.missings(P, size(u))
177182

178183
# pdf.(u, v)
179184
function Base.Broadcast.broadcasted(
@@ -186,17 +191,15 @@ function Base.Broadcast.broadcasted(
186191
length(u) == length(v) ||throw(DimensionMismatch(
187192
"Arrays could not be broadcast to a common size; "*
188193
"got a dimension with lengths $(length(u)) and $(length(v))"))
189-
for cv in v
190-
ismissing(cv) || cv in classes(u) || throw(err_missing_class(cv))
191-
end
192194

193-
# will use linear indexing:
194-
v_flat = ((v[i], i) for i in 1:length(v))
195+
v_loc_flat = [(ismissing(x) ? missing : findfirst(==(x), classes(u)), i)
196+
for (i, x) in enumerate(v)]
197+
any(isequal(0), v_loc_flat) && throw(err_missing_class(cv))
195198

196-
getter((cv, i), dtype) =
197-
_getindex(get(u.prob_given_ref, int(cv), nothing), i, dtype)
199+
getter((cv_loc, i), dtype) =
200+
_getindex(get(u.prob_given_ref, cv_loc, nothing), i, dtype)
198201
getter(::Tuple{Missing,Any}, dtype) = missing
199-
ret_flat = getter.(v_flat, P)
202+
ret_flat = getter.(v_loc_flat, P)
200203
return reshape(ret_flat, size(u))
201204
end
202205

@@ -269,10 +272,10 @@ function Base.Broadcast.broadcasted(::typeof(mode),
269272
mode_flat = map(1:length(u)) do i
270273
max_prob = maximum(dic[ref][i] for ref in keys(dic))
271274
m = zero(R)
272-
273-
# `maximum` of any iterable containing `NaN` would return `NaN`
275+
276+
# `maximum` of any iterable containing `NaN` would return `NaN`
274277
# For this case the index `m` won't be updated in the loop as relations
275-
# involving NaN as one of it's argument always returns false
278+
# involving NaN as one of it's argument always returns false
276279
# (e.g `==(NaN, NaN)` returns false)
277280
throw_nan_error_if_needed(max_prob)
278281
for ref in keys(dic)
@@ -295,9 +298,7 @@ const ERR_EMPTY_UNIVARIATE_FINITE = ArgumentError(
295298
"No `UnivariateFinite` object found from which to extract classes. ")
296299

297300
function classes(yhat::AbstractArray{<:Union{Missing,UnivariateFinite}})
298-
i = findfirst(x->!ismissing(x), yhat)
301+
i = findfirst(!ismissing, yhat)
299302
i === nothing && throw(ERR_EMPTY_UNIVARIATE_FINITE)
300303
return classes(yhat[i])
301304
end
302-
303-

src/methods.jl

Lines changed: 5 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -164,40 +164,24 @@ One can also do weighted fits:
164164
165165
See also `classes`, `support`.
166166
"""
167-
function Dist.pdf(
168-
d::UnivariateFinite{S,V,R,P},
169-
cv::CategoricalValue,
170-
) where {S,V,R,P}
171-
return get(d.prob_given_ref, int(cv), zero(P))
172-
end
173-
Dist.pdf(d::UnivariateFinite{S,V}, c::V) where {S,V} = _pdf(d, c)
174-
Dist.pdf(::UnivariateFinite{S,V}, ::Missing) where {S,V} = missing
175-
176-
# Avoid method ambiguity errors with Dist >= 0.24
177-
Dist.pdf(d::UnivariateFinite{S,V}, c::V) where {S,V<:Real} = _pdf(d, c)
167+
Dist.pdf(::UnivariateFinite, ::Missing) = missing
178168

179-
function _pdf(d::UnivariateFinite, c)
169+
function Dist.pdf(d::UnivariateFinite{S,V,R,P}, c) where {S,V,R,P}
180170
_classes = classes(d)
181171
c in _classes || throw(DomainError("Value $c not in pool. "))
182172
pool = CategoricalArrays.pool(_classes)
183-
class = pool[get(pool, c)]
184-
return pdf(d, class)
173+
return get(d.prob_given_ref, get(pool, c), zero(P))
185174
end
186175

187-
Dist.logpdf(d::UnivariateFinite, cv::CategoricalValue) = log(pdf(d,cv))
188-
Dist.logpdf(d::UnivariateFinite{S,V}, c::V) where {S,V} = log(pdf(d, c))
189-
Dist.logpdf(::UnivariateFinite{S,V}, ::Missing) where {S,V} = missing
190-
191-
# Avoid method ambiguity errors with Dist >= 0.24
192-
Dist.logpdf(d::UnivariateFinite{S,V}, c::V) where {S,V<:Real} = log(pdf(d, c))
176+
Dist.logpdf(d::UnivariateFinite, c) = log(pdf(d, c))
193177

194178
function Dist.mode(d::UnivariateFinite)
195179
dic = d.prob_given_ref
196180
p = values(dic)
197181
max_prob = maximum(p)
198182
m = first(first(dic)) # mode, just some ref for now
199183

200-
# `maximum` of any iterable containing `NaN` would return `NaN`
184+
# `maximum` of any iterable containing `NaN` would return `NaN`
201185
# For this case the index `m` won't be updated in the loop below as relations
202186
# involving NaN as one of it's arguments always returns false
203187
# (e.g `==(NaN, NaN)` returns false)

src/types.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -456,9 +456,11 @@ function _UnivariateFinite(support,
456456
issubset(support, _classes) ||
457457
error("Specified support, $support, not contained in "*
458458
"specified pool, $(levels(classes)). ")
459-
_support = filter(_classes) do c
460-
c in support
461-
end
459+
idxs = getindex.(
460+
Ref(CategoricalArrays.DataAPI.invrefpool(_classes)),
461+
support
462+
)
463+
_support = _classes[idxs]
462464
end
463465

464466
# calls core method:

test/arrays.jl

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,4 +310,69 @@ end
310310

311311
end
312312

313+
function (x::T, y::T) where {T<:UnivariateFinite}
314+
return x.decoder == y.decoder &&
315+
x.prob_given_ref == y.prob_given_ref &&
316+
x.scitype == y.scitype
317+
end
318+
319+
@testset "CartesianIndex" begin
320+
v = categorical(["a", "b"], ordered=true)
321+
m = UnivariateFinite(v, rand(rng, 5, 2), augment=true)
322+
@test m[1, 1] m[CartesianIndex(1, 1)] m[CartesianIndex(1, 1, 1)]
323+
@test_throws BoundsError m[CartesianIndex(1)]
324+
@test all(zip(Matrix(m), copy(m), m)) do (x, y, z)
325+
return x y z
326+
end
327+
@test Matrix(m) isa Matrix
328+
# TODO: probably it would be better for copy to keep it
329+
# UnivariateFiniteArray but it would be breaking
330+
@test copy(m) isa Matrix
331+
@test similar(m) isa Matrix
332+
end
333+
334+
@testset "broadcasted pdf" begin
335+
v = categorical(["a", "b"], ordered=true)
336+
v2 = categorical(["a", "b"], ordered=true, levels=["b", "a"])
337+
x = UnivariateFinite(v, rand(rng, 5), augment=true)
338+
@test pdf.(x, v[1]) == pdf.(x, v2[1]) == pdf.(x, "a")
339+
@test pdf.(x, v[2]) == pdf.(x, v2[2]) == pdf.(x, "b")
340+
341+
x = UnivariateFinite(v, rand(rng, 5, 2), augment=true)
342+
@test size(pdf.(x, missing)) == (5, 2)
343+
344+
v3 = categorical(["a" "b"], ordered=true)
345+
v4 = categorical(["a" "b"], ordered=true, levels=["b", "a"])
346+
# note that v5 and v6 have the same shape and contents as v3 and v4
347+
# just they are Matrix{Any} not CategoricalMatrix
348+
v5 = Any[v3[1] v3[2]]
349+
v6 = Any[v4[1] v4[2]]
350+
x = UnivariateFinite(v, hcat([0.1, 0.2]), augment=true)
351+
352+
# these tests show that now we have corrected refpools
353+
# but still there is an inconsistency in behavior
354+
@test pdf.(x, v) == hcat([0.9, 0.2])
355+
@test pdf.(x, v2) == hcat([0.9, 0.2])
356+
@test pdf.(x, v3) == hcat([0.9, 0.2])
357+
@test pdf.(x, v4) == hcat([0.9, 0.2])
358+
@test pdf.(x, v5) == [0.9 0.1; 0.8 0.2]
359+
@test pdf.(x, v6) == [0.9 0.1; 0.8 0.2]
360+
end
361+
362+
@testset "pdf with various types" begin
363+
v = categorical(["a", "b"], ordered=true)
364+
a = view("a", 1:1) # quite common case when splitting strings
365+
b = view("b", 1:1)
366+
x = UnivariateFinite(v, [0.1, 0.2, 0.3], augment=true)
367+
@test pdf.(x, a) == pdf.(x, "a") == pdf.(x, v[1])
368+
@test logpdf.(x, a) == logpdf.(x, "a") == logpdf.(x, v[1])
369+
@test pdf(x, [a, b]) == pdf(x, ["a", "b"]) == pdf(x, v)
370+
@test logpdf(x, [a, b]) == logpdf(x, ["a", "b"]) == logpdf(x, v)
371+
372+
x = UnivariateFinite(v, 0.1, augment=true)
373+
@test pdf.(x, a) == pdf.(x, "a") == pdf.(x, v[1]) == 0.9
374+
@test logpdf.(x, a) == logpdf.(x, "a") == logpdf.(x, v[1]) == log(0.9)
375+
@test pdf(x, a) == pdf(x, "a") == pdf(x, v[1]) == 0.9
376+
@test logpdf(x, a) == logpdf(x, "a") == logpdf(x, v[1]) == log(0.9)
377+
end
313378
true

test/runtests.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@ rng = StableRNGs.StableRNG(123)
99

1010
import CategoricalDistributions: classes, decoder, int
1111

12+
ambiguities_vec = Test.detect_ambiguities(CategoricalDistributions,
13+
recursive=true)
14+
if !isempty(ambiguities_vec)
15+
@warn "$(length(ambiguities_vec)) method ambiguities detected"
16+
end
17+
1218
@testset "utilities" begin
1319
@test include("utilities.jl")
1420
end

test/types.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ using StableRNGs
77
using FillArrays
88
using ScientificTypes
99
import Random
10+
import CategoricalDistributions: classes
1011

1112
# coverage of constructor testing is expanded in the other test files
1213

@@ -36,6 +37,20 @@ end
3637

3738
UnivariateFinite(supp, probs, pool=missing, augment=true);
3839

40+
# construction from pool and support does not
41+
# consist of categorical elements (See issue #34)
42+
v = categorical(["x", "x", "y", "z", "y", "z", "p"])
43+
probs1 = [0.1, 0.2, 0.7]
44+
probs2 = [0.1 0.2 0.7; 0.5 0.2 0.3; 0.8 0.1 0.1]
45+
unf1 = UnivariateFinite(["y", "x", "z"], probs1, pool=v)
46+
unf2 = UnivariateFinite(["y", "x", "z"], probs2, pool=v)
47+
@test CategoricalArrays.pool(classes(unf1)) == CategoricalArrays.pool(v)
48+
@test CategoricalArrays.pool(classes(unf2)) == CategoricalArrays.pool(v)
49+
@test pdf.(unf1, ["y", "x", "z"]) == probs1
50+
@test pdf.(unf2, "y") == probs2[:, 1]
51+
@test pdf.(unf2, "x") == probs2[:, 2]
52+
@test pdf.(unf2, "z") == probs2[:, 3]
53+
3954
# dimension mismatches:
4055
badprobs = rand(rng, 40, 3)
4156
@test_throws(CategoricalDistributions.err_dim(supp, badprobs),

0 commit comments

Comments
 (0)