Skip to content

Commit 91536ac

Browse files
committed
move Categorical to univariate
1 parent 47d8447 commit 91536ac

File tree

2 files changed

+182
-184
lines changed

2 files changed

+182
-184
lines changed

src/multivariate.jl

Lines changed: 1 addition & 183 deletions
Original file line numberDiff line numberDiff line change
@@ -1,185 +1,3 @@
1-
## Categorical ##
2-
3-
struct TuringDiscreteNonParametric{T<:Real,P<:Real,Ts<:AbstractVector{T},Ps<:AbstractVector{P}} <: DiscreteUnivariateDistribution
4-
support::Ts
5-
p::Ps
6-
7-
function TuringDiscreteNonParametric{T, P, Ts, Ps}(vs, ps; check_args=true) where {
8-
T <: Real,
9-
P <: Real,
10-
Ts <: AbstractVector{T},
11-
Ps <: AbstractVector{P},
12-
}
13-
check_args || return new{T, P, Ts, Ps}(vs, ps)
14-
Distributions.@check_args(TuringDiscreteNonParametric, length(vs) == length(ps))
15-
Distributions.@check_args(TuringDiscreteNonParametric, isprobvec(ps))
16-
Distributions.@check_args(TuringDiscreteNonParametric, allunique(vs))
17-
sort_order = sortperm(vs)
18-
vs = vs[sort_order]
19-
ps = ps[sort_order]
20-
new{T, P, Ts, Ps}(vs, ps)
21-
end
22-
end
23-
function TuringDiscreteNonParametric(vs::Ts, ps::Ps; check_args=true) where {
24-
T <: Real,
25-
P <: Real,
26-
Ts <: AbstractVector{T},
27-
Ps <: AbstractVector{P},
28-
}
29-
return TuringDiscreteNonParametric{T, P, Ts, Ps}(vs, ps; check_args = check_args)
30-
end
31-
function TuringDiscreteNonParametric(vs::Ts, ps::Ps; check_args=true) where {
32-
T <: Real,
33-
P <: Real,
34-
Ts <: AbstractVector{T},
35-
Ps <: SubArray,
36-
}
37-
_ps = collect(ps)
38-
_Ps = typeof(ps)
39-
return TuringDiscreteNonParametric{T, P, Ts, _Ps}(vs, _ps, check_args = check_args)
40-
end
41-
function TuringDiscreteNonParametric(vs::Ts, ps::Ps; check_args=true) where {
42-
T <: Real,
43-
P <: Real,
44-
Ts <: AbstractVector{T},
45-
Ps <: TrackedVector{P, <:SubArray},
46-
}
47-
_ps = ps[:]
48-
_Ps = typeof(_ps)
49-
return TuringDiscreteNonParametric{T, P, Ts, _Ps}(vs, _ps, check_args = check_args)
50-
end
51-
52-
Base.eltype(::Type{<:TuringDiscreteNonParametric{T}}) where T = T
53-
54-
# Accessors
55-
Distributions.params(d::TuringDiscreteNonParametric) = (d.support, d.p)
56-
57-
Distributions.support(d::TuringDiscreteNonParametric) = d.support
58-
59-
Distributions.probs(d::TuringDiscreteNonParametric) = d.p
60-
61-
Base.isapprox(c1::D, c2::D) where D <: TuringDiscreteNonParametric =
62-
(support(c1) support(c2) || all(support(c1) .≈ support(c2))) &&
63-
(probs(c1) probs(c2) || all(probs(c1) .≈ probs(c2)))
64-
65-
function Distributions.rand(rng::AbstractRNG, d::TuringDiscreteNonParametric{T,P}) where {T,P}
66-
x = support(d)
67-
p = probs(d)
68-
n = length(p)
69-
draw = rand(rng, P)
70-
cp = zero(P)
71-
i = 0
72-
while cp < draw && i < n
73-
cp += p[i +=1]
74-
end
75-
x[max(i,1)]
76-
end
77-
78-
Distributions.rand(d::TuringDiscreteNonParametric) = rand(GLOBAL_RNG, d)
79-
80-
Distributions.sampler(d::TuringDiscreteNonParametric) =
81-
DiscreteNonParametricSampler(support(d), probs(d))
82-
83-
Distributions.get_evalsamples(d::TuringDiscreteNonParametric, ::Float64) = support(d)
84-
85-
Distributions.pdf(d::TuringDiscreteNonParametric) = copy(probs(d))
86-
87-
# Helper functions for pdf and cdf required to fix ambiguous method
88-
# error involving [pc]df(::DisceteUnivariateDistribution, ::Int)
89-
function _pdf(d::TuringDiscreteNonParametric{T,P}, x::T) where {T,P}
90-
idx_range = searchsorted(support(d), x)
91-
if length(idx_range) > 0
92-
return probs(d)[first(idx_range)]
93-
else
94-
return zero(P)
95-
end
96-
end
97-
Distributions.pdf(d::TuringDiscreteNonParametric{T}, x::Int) where T = _pdf(d, convert(T, x))
98-
Distributions.pdf(d::TuringDiscreteNonParametric{T}, x::Real) where T = _pdf(d, convert(T, x))
99-
100-
function _cdf(d::TuringDiscreteNonParametric{T,P}, x::T) where {T,P}
101-
x > maximum(d) && return 1.0
102-
s = zero(P)
103-
ps = probs(d)
104-
stop_idx = searchsortedlast(support(d), x)
105-
for i in 1:stop_idx
106-
s += ps[i]
107-
end
108-
return s
109-
end
110-
Distributions.cdf(d::TuringDiscreteNonParametric{T}, x::Integer) where T = _cdf(d, convert(T, x))
111-
Distributions.cdf(d::TuringDiscreteNonParametric{T}, x::Real) where T = _cdf(d, convert(T, x))
112-
113-
function _ccdf(d::TuringDiscreteNonParametric{T,P}, x::T) where {T,P}
114-
x < minimum(d) && return 1.0
115-
s = zero(P)
116-
ps = probs(d)
117-
stop_idx = searchsortedlast(support(d), x)
118-
for i in (stop_idx+1):length(ps)
119-
s += ps[i]
120-
end
121-
return s
122-
end
123-
Distributions.ccdf(d::TuringDiscreteNonParametric{T}, x::Integer) where T = _ccdf(d, convert(T, x))
124-
Distributions.ccdf(d::TuringDiscreteNonParametric{T}, x::Real) where T = _ccdf(d, convert(T, x))
125-
126-
function Distributions.quantile(d::TuringDiscreteNonParametric, q::Real)
127-
0 <= q <= 1 || throw(DomainError())
128-
x = support(d)
129-
p = probs(d)
130-
k = length(x)
131-
i = 1
132-
cp = p[1]
133-
while cp < q && i < k #Note: is i < k necessary?
134-
i += 1
135-
@inbounds cp += p[i]
136-
end
137-
x[i]
138-
end
139-
140-
Base.minimum(d::TuringDiscreteNonParametric) = first(support(d))
141-
Base.maximum(d::TuringDiscreteNonParametric) = last(support(d))
142-
Distributions.insupport(d::TuringDiscreteNonParametric, x::Real) =
143-
length(searchsorted(support(d), x)) > 0
144-
145-
Distributions.mean(d::TuringDiscreteNonParametric) = dot(probs(d), support(d))
146-
147-
function Distributions.var(d::TuringDiscreteNonParametric{T}) where T
148-
m = mean(d)
149-
x = support(d)
150-
p = probs(d)
151-
k = length(x)
152-
σ² = zero(T)
153-
for i in 1:k
154-
@inbounds σ² += abs2(x[i] - m) * p[i]
155-
end
156-
σ²
157-
end
158-
159-
Distributions.mode(d::TuringDiscreteNonParametric) = support(d)[argmax(probs(d))]
160-
function Distributions.modes(d::TuringDiscreteNonParametric{T,P}) where {T,P}
161-
x = support(d)
162-
p = probs(d)
163-
k = length(x)
164-
mds = T[]
165-
max_p = zero(P)
166-
@inbounds for i in 1:k
167-
pi = p[i]
168-
xi = x[i]
169-
if pi > max_p
170-
max_p = pi
171-
mds = [xi]
172-
elseif pi == max_p
173-
push!(mds, xi)
174-
end
175-
end
176-
mds
177-
end
178-
179-
function Distributions.Categorical(p::TrackedVector; check_args = true)
180-
return TuringDiscreteNonParametric(1:length(p), p, check_args = check_args)
181-
end
182-
1831
## Dirichlet ##
1842

1853
struct TuringDirichlet{T, TV <: AbstractVector} <: ContinuousMultivariateDistribution
@@ -241,7 +59,7 @@ function simplex_logpdf(alpha, lmnB, x::AbstractVector)
24159
sum((alpha .- 1) .* log.(x)) - lmnB
24260
end
24361
function simplex_logpdf(alpha, lmnB, x::AbstractMatrix)
244-
@views init = vcat(sum((alpha .- 1) .* log.(x[:,1])))
62+
@views init = vcat(sum((alpha .- 1) .* log.(x[:,1])) - lmnB)
24563
mapreduce(vcat, drop(eachcol(x), 1); init = init) do c
24664
sum((alpha .- 1) .* log.(c)) - lmnB
24765
end

src/univariate.jl

Lines changed: 181 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -337,4 +337,184 @@ function _dft_zygote(x::Vector{T}) where T
337337
end
338338
return copy(y)
339339
end
340-
=#
340+
=#
341+
342+
## Categorical ##
343+
344+
struct TuringDiscreteNonParametric{T<:Real,P<:Real,Ts<:AbstractVector{T},Ps<:AbstractVector{P}} <: DiscreteUnivariateDistribution
345+
support::Ts
346+
p::Ps
347+
348+
function TuringDiscreteNonParametric{T, P, Ts, Ps}(vs, ps; check_args=true) where {
349+
T <: Real,
350+
P <: Real,
351+
Ts <: AbstractVector{T},
352+
Ps <: AbstractVector{P},
353+
}
354+
check_args || return new{T, P, Ts, Ps}(vs, ps)
355+
Distributions.@check_args(TuringDiscreteNonParametric, length(vs) == length(ps))
356+
Distributions.@check_args(TuringDiscreteNonParametric, isprobvec(ps))
357+
Distributions.@check_args(TuringDiscreteNonParametric, allunique(vs))
358+
sort_order = sortperm(vs)
359+
vs = vs[sort_order]
360+
ps = ps[sort_order]
361+
new{T, P, Ts, Ps}(vs, ps)
362+
end
363+
end
364+
function TuringDiscreteNonParametric(vs::Ts, ps::Ps; check_args=true) where {
365+
T <: Real,
366+
P <: Real,
367+
Ts <: AbstractVector{T},
368+
Ps <: AbstractVector{P},
369+
}
370+
return TuringDiscreteNonParametric{T, P, Ts, Ps}(vs, ps; check_args = check_args)
371+
end
372+
function TuringDiscreteNonParametric(vs::Ts, ps::Ps; check_args=true) where {
373+
T <: Real,
374+
P <: Real,
375+
Ts <: AbstractVector{T},
376+
Ps <: SubArray,
377+
}
378+
_ps = collect(ps)
379+
_Ps = typeof(ps)
380+
return TuringDiscreteNonParametric{T, P, Ts, _Ps}(vs, _ps, check_args = check_args)
381+
end
382+
function TuringDiscreteNonParametric(vs::Ts, ps::Ps; check_args=true) where {
383+
T <: Real,
384+
P <: Real,
385+
Ts <: AbstractVector{T},
386+
Ps <: TrackedVector{P, <:SubArray},
387+
}
388+
_ps = ps[:]
389+
_Ps = typeof(_ps)
390+
return TuringDiscreteNonParametric{T, P, Ts, _Ps}(vs, _ps, check_args = check_args)
391+
end
392+
393+
Base.eltype(::Type{<:TuringDiscreteNonParametric{T}}) where T = T
394+
395+
# Accessors
396+
Distributions.support(d::TuringDiscreteNonParametric) = d.support
397+
398+
Distributions.probs(d::TuringDiscreteNonParametric) = d.p
399+
400+
Base.isapprox(c1::D, c2::D) where D <: TuringDiscreteNonParametric =
401+
(support(c1) support(c2) || all(support(c1) .≈ support(c2))) &&
402+
(probs(c1) probs(c2) || all(probs(c1) .≈ probs(c2)))
403+
404+
function Distributions.rand(rng::AbstractRNG, d::TuringDiscreteNonParametric{T,P}) where {T,P}
405+
x = support(d)
406+
p = probs(d)
407+
n = length(p)
408+
draw = rand(rng, P)
409+
cp = zero(P)
410+
i = 0
411+
while cp < draw && i < n
412+
cp += p[i +=1]
413+
end
414+
x[max(i,1)]
415+
end
416+
417+
Distributions.rand(d::TuringDiscreteNonParametric) = rand(GLOBAL_RNG, d)
418+
419+
Distributions.sampler(d::TuringDiscreteNonParametric) =
420+
DiscreteNonParametricSampler(support(d), probs(d))
421+
422+
Distributions.get_evalsamples(d::TuringDiscreteNonParametric, ::Float64) = support(d)
423+
424+
Distributions.pdf(d::TuringDiscreteNonParametric) = copy(probs(d))
425+
426+
# Helper functions for pdf and cdf required to fix ambiguous method
427+
# error involving [pc]df(::DisceteUnivariateDistribution, ::Int)
428+
function _pdf(d::TuringDiscreteNonParametric{T,P}, x::T) where {T,P}
429+
idx_range = searchsorted(support(d), x)
430+
if length(idx_range) > 0
431+
return probs(d)[first(idx_range)]
432+
else
433+
return zero(P)
434+
end
435+
end
436+
Distributions.pdf(d::TuringDiscreteNonParametric{T}, x::Int) where T = _pdf(d, convert(T, x))
437+
Distributions.pdf(d::TuringDiscreteNonParametric{T}, x::Real) where T = _pdf(d, convert(T, x))
438+
439+
function _cdf(d::TuringDiscreteNonParametric{T,P}, x::T) where {T,P}
440+
x > maximum(d) && return 1.0
441+
s = zero(P)
442+
ps = probs(d)
443+
stop_idx = searchsortedlast(support(d), x)
444+
for i in 1:stop_idx
445+
s += ps[i]
446+
end
447+
return s
448+
end
449+
Distributions.cdf(d::TuringDiscreteNonParametric{T}, x::Integer) where T = _cdf(d, convert(T, x))
450+
Distributions.cdf(d::TuringDiscreteNonParametric{T}, x::Real) where T = _cdf(d, convert(T, x))
451+
452+
function _ccdf(d::TuringDiscreteNonParametric{T,P}, x::T) where {T,P}
453+
x < minimum(d) && return 1.0
454+
s = zero(P)
455+
ps = probs(d)
456+
stop_idx = searchsortedlast(support(d), x)
457+
for i in (stop_idx+1):length(ps)
458+
s += ps[i]
459+
end
460+
return s
461+
end
462+
Distributions.ccdf(d::TuringDiscreteNonParametric{T}, x::Integer) where T = _ccdf(d, convert(T, x))
463+
Distributions.ccdf(d::TuringDiscreteNonParametric{T}, x::Real) where T = _ccdf(d, convert(T, x))
464+
465+
function Distributions.quantile(d::TuringDiscreteNonParametric, q::Real)
466+
0 <= q <= 1 || throw(DomainError())
467+
x = support(d)
468+
p = probs(d)
469+
k = length(x)
470+
i = 1
471+
cp = p[1]
472+
while cp < q && i < k #Note: is i < k necessary?
473+
i += 1
474+
@inbounds cp += p[i]
475+
end
476+
x[i]
477+
end
478+
479+
Base.minimum(d::TuringDiscreteNonParametric) = first(support(d))
480+
Base.maximum(d::TuringDiscreteNonParametric) = last(support(d))
481+
Distributions.insupport(d::TuringDiscreteNonParametric, x::Real) =
482+
length(searchsorted(support(d), x)) > 0
483+
484+
Distributions.mean(d::TuringDiscreteNonParametric) = dot(probs(d), support(d))
485+
486+
function Distributions.var(d::TuringDiscreteNonParametric{T}) where T
487+
m = mean(d)
488+
x = support(d)
489+
p = probs(d)
490+
k = length(x)
491+
σ² = zero(T)
492+
for i in 1:k
493+
@inbounds σ² += abs2(x[i] - m) * p[i]
494+
end
495+
σ²
496+
end
497+
498+
Distributions.mode(d::TuringDiscreteNonParametric) = support(d)[argmax(probs(d))]
499+
function Distributions.modes(d::TuringDiscreteNonParametric{T,P}) where {T,P}
500+
x = support(d)
501+
p = probs(d)
502+
k = length(x)
503+
mds = T[]
504+
max_p = zero(P)
505+
@inbounds for i in 1:k
506+
pi = p[i]
507+
xi = x[i]
508+
if pi > max_p
509+
max_p = pi
510+
mds = [xi]
511+
elseif pi == max_p
512+
push!(mds, xi)
513+
end
514+
end
515+
mds
516+
end
517+
518+
function Distributions.Categorical(p::TrackedVector; check_args = true)
519+
return TuringDiscreteNonParametric(1:length(p), p, check_args = check_args)
520+
end

0 commit comments

Comments
 (0)