Skip to content

Commit 759bedc

Browse files
authored
Merge pull request #69 from JuliaAI/dev
For a 0.1.13 release
2 parents d642ed0 + 5c19ce3 commit 759bedc

File tree

4 files changed

+98
-34
lines changed

4 files changed

+98
-34
lines changed

.github/codecov.yml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
coverage:
2+
status:
3+
project:
4+
default:
5+
threshold: 0.5%
6+
patch:
7+
default:
8+
target: 80%

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "CategoricalDistributions"
22
uuid = "af321ab8-2d2e-40a6-b165-3d674595d28e"
33
authors = ["Anthony D. Blaom <[email protected]>"]
4-
version = "0.1.12"
4+
version = "0.1.13"
55

66
[deps]
77
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"

src/methods.jl

Lines changed: 77 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# not for export:
2-
const UnivariateFiniteUnion =
3-
Union{UnivariateFinite, UnivariateFiniteArray}
2+
const UnivariateFiniteUnion{S,V,R,P} =
3+
Union{UnivariateFinite{S,V,R,P}, UnivariateFiniteArray{S,V,R,P}}
44

55
"""
66
classes(d::UnivariateFinite)
@@ -42,18 +42,34 @@ end
4242
raw_support(d::UnivariateFiniteUnion) = collect(keys(d.prob_given_ref))
4343

4444
"""
45-
Dist.support(d::UnivariateFinite)
46-
Dist.support(d::UnivariateFiniteArray)
45+
Distributions.support(d::UnivariateFinite)
46+
Distributions.support(d::UnivariateFiniteArray)
4747
4848
Ordered list of classes associated with non-zero probabilities.
4949
5050
v = categorical(["yes", "maybe", "no", "yes"])
5151
d = UnivariateFinite(v[1:2], [0.3, 0.7])
52-
support(d) # CategoricalArray{String,1,UInt32}["maybe", "yes"]
52+
Distributions.support(d) # CategoricalArray{String,1,UInt32}["maybe", "yes"]
5353
5454
"""
55-
Dist.support(d::UnivariateFiniteUnion) =
56-
map(d.decoder, raw_support(d))
55+
Distributions.support(d::UnivariateFiniteUnion) = classes(d)[raw_support(d)]
56+
57+
"""
58+
fast_support(d::UnivariateFinite)
59+
60+
Same as `Distributions.support(d)` except it returns a vector of `CategoricalValue`s,
61+
rather than a `CategoricalVector`. It executes faster, about five times faster for a
62+
three-class `UnivariateFinite` distribution.
63+
"""
64+
function fast_support(d::UnivariateFiniteUnion{S,V,R}) where {S,V,R}
65+
raw_support = keys(d.prob_given_ref)
66+
n = length(raw_support)
67+
ret = Vector{CategoricalValue{V,R}}(undef, n)
68+
for (i, ref) in enumerate(raw_support)
69+
ret[i] = d.decoder(ref)
70+
end
71+
ret
72+
end
5773

5874
# TODO: If I manually give a class zero probability, it will appear in
5975
# support, which is probably confusing. We may need two versions of
@@ -64,8 +80,7 @@ Dist.support(d::UnivariateFiniteUnion) =
6480
# not exported:
6581
sample_scitype(d::UnivariateFiniteUnion) = d.scitype
6682

67-
CategoricalArrays.isordered(d::UnivariateFinite) = isordered(classes(d))
68-
CategoricalArrays.isordered(u::UnivariateFiniteArray) = isordered(classes(u))
83+
CategoricalArrays.isordered(d::UnivariateFiniteUnion) = isordered(classes(d))
6984

7085

7186
## DISPLAY
@@ -96,8 +111,8 @@ probability pairs. Returns `false` otherwise.
96111
97112
"""
98113
function Base.isapprox(d1::UnivariateFinite, d2::UnivariateFinite; kwargs...)
99-
support1 = Dist.support(d1)
100-
support2 = Dist.support(d2)
114+
support1 = fast_support(d1)
115+
support2 = fast_support(d2)
101116
for c in support1
102117
c in support2 || return false
103118
isapprox(pdf(d1, c), pdf(d2, c); kwargs...) ||
@@ -107,8 +122,8 @@ function Base.isapprox(d1::UnivariateFinite, d2::UnivariateFinite; kwargs...)
107122
end
108123
function Base.isapprox(d1::UnivariateFiniteArray,
109124
d2::UnivariateFiniteArray; kwargs...)
110-
support1 = Dist.support(d1)
111-
support2 = Dist.support(d2)
125+
support1 = fast_support(d1)
126+
support2 = fast_support(d2)
112127
for c in support1
113128
c in support2 || return false
114129
isapprox(pdf.(d1, c), pdf.(d2, c); kwargs...) ||
@@ -206,22 +221,18 @@ function throw_nan_error_if_needed(x)
206221
end
207222
end
208223

209-
# mode(v::Vector{UnivariateFinite}) = mode.(v)
210-
# mode(u::UnivariateFiniteVector{2}) =
211-
# [u.support[ifelse(s > 0.5, 2, 1)] for s in u.scores]
212-
# mode(u::UnivariateFiniteVector{C}) where {C} =
213-
# [u.support[findmax(s)[2]] for s in eachrow(u.scores)]
224+
225+
# # HELPERS FOR RAND
214226

215227
"""
216228
_cumulative(d::UnivariateFinite)
217229
218230
**Private method.**
219231
220-
Return the cumulative probability vector `C` for the distribution `d`,
221-
using only classes in the support of `d`, ordered according to the
222-
categorical elements used at instantiation of `d`. Used only to
223-
implement random sampling from `d`. We have `C[1] == 0` and `C[end] ==
224-
1`, assuming the probabilities have been normalized.
232+
Return the cumulative probability vector `C` for the distribution `d`, using only classes
233+
in `Distributions.support(d)`, ordered according to the categorical elements used at
234+
instantiation of `d`. Used only to implement random sampling from `d`. We have `C[1] == 0`
235+
and `C[end] == 1`, assuming the probabilities have been normalized.
225236
226237
"""
227238
function _cumulative(d::UnivariateFinite{S,V,R,P}) where {S,V,R,P<:Real}
@@ -260,16 +271,54 @@ function _rand(rng, p_cumulative, R)
260271
return index
261272
end
262273

263-
Random.eltype(::Type{<:UnivariateFinite{<:Any,V}}) where V = V
274+
275+
# # RAND
276+
277+
Random.eltype(::Type{<:UnivariateFinite{S,V,R}}) where {S,V,R} =
278+
CategoricalArrays.CategoricalValue{V,R}
264279

265280
# The Sampler hook into Random's API is discussed in the Julia documentation, in the
266281
# Standard Library section on Random.
282+
283+
284+
## Single samples
285+
286+
Random.Sampler(::AbstractRNG, d::UnivariateFinite, ::Val{1}) = Random.SamplerTrivial(d)
287+
288+
function Base.rand(
289+
rng::AbstractRNG,
290+
sampler::Random.SamplerTrivial{<:UnivariateFinite{<:Any,<:Any,V,P}},
291+
) where {V, P}
292+
293+
d = sampler[]
294+
u = rand(rng)
295+
296+
total = zero(P)
297+
298+
# For type stability we assign `zero(V)`` as the default ref
299+
# This isn't a problem since we know that `rand` is always defined
300+
# as UnivariateFinite objects have non-negative probabilities,
301+
# summing up to a non-negative value.
302+
rng_key = zero(V)
303+
for (ref, prob) in pairs(d.prob_given_ref)
304+
total += prob
305+
u <= total && begin
306+
rng_key = ref
307+
break
308+
end
309+
end
310+
return d.decoder(rng_key)
311+
end
312+
313+
314+
## Multiple samples
315+
267316
function Random.Sampler(
268317
::AbstractRNG,
269318
d::UnivariateFinite,
270319
::Random.Repetition,
271320
)
272-
data = (_cumulative(d), Dist.support(d))
321+
data = (_cumulative(d), fast_support(d))
273322
Random.SamplerSimple(d, data)
274323
end
275324

@@ -281,6 +330,9 @@ function Base.rand(
281330
return support[_rand(rng, p_cumulative, R)]
282331
end
283332

333+
334+
## FIT
335+
284336
function Dist.fit(d::Type{<:UnivariateFinite},
285337
v::AbstractVector{C}) where C
286338
C <: CategoricalValue ||

test/methods.jl

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ A, S, Q, F = V[1], V[2], V[3], V[4]
2828
@test classes(d) == classes(s)
2929
@test levels(d) == levels(s)
3030
@test support(d) == [f, q, s]
31+
@test support(d) == [CategoricalDistributions.fast_support(d)...]
3132
@test CategoricalDistributions.sample_scitype(d) == OrderedFactor{4}
3233
# levels!(v, reverse(levels(v)))
3334
# @test classes(d) == [s, q, f, a]
@@ -54,7 +55,7 @@ A, S, Q, F = V[1], V[2], V[3], V[4]
5455

5556
N = 50
5657
rng = StableRNG(125)
57-
samples = [rand(rng,d) for i in 1:50];
58+
samples = [rand(rng, d) for i in 1:N];
5859
rng = StableRNG(125)
5960
@test samples == [rand(rng, d) for i in 1:N]
6061

@@ -301,15 +302,18 @@ end
301302
end
302303

303304
@testset "rand signatures" begin
304-
d = UnivariateFinite(
305-
["maybe", "no", "yes"],
306-
[0.5, 0.4, 0.1];
307-
pool=missing,
308-
)
305+
dict = Dict(s=>0.1, q=>0.2, f=>0.7)
306+
d = UnivariateFinite(dict)
309307

310-
# smoke test:
311308
sampler = Random.Sampler(default_rng(), d, Val(1))
312-
rand(default_rng(), sampler)
309+
@test sampler isa Random.SamplerTrivial
310+
sampler = Random.Sampler(default_rng(), d, Val(Inf))
311+
@test sampler isa Random.SamplerSimple
312+
313+
# sampling one at a time, or all at once is the same:
314+
rng0 = StableRNG(123)
315+
samples = [rand(rng0, d) for i in 1:30]
316+
@test samples == [rand(StableRNG(123), d, 30)...]
313317

314318
Random.seed!(123)
315319
samples = [rand(default_rng(), d) for i in 1:30]

0 commit comments

Comments
 (0)