Skip to content

Commit 5a8698e

Browse files
committed
fix Categorical
1 parent f74f95d commit 5a8698e

File tree

1 file changed

+181
-0
lines changed

1 file changed

+181
-0
lines changed

src/multivariate.jl

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,184 @@
1+
## Categorical ##
2+
3+
"""
    TuringDiscreteNonParametric(vs, ps; check_args=true)

Discrete univariate distribution that stores a vector of values (`support`) together
with a vector of matching probabilities (`p`).

With `check_args=true` the inner constructor verifies that `vs` and `ps` have the same
length, that `ps` satisfies `isprobvec`, and that `vs` has no duplicates; both vectors
are then stored reordered by `sortperm(vs)`. With `check_args=false` the arguments are
stored exactly as given, without validation or reordering.
"""
struct TuringDiscreteNonParametric{
    T<:Real,P<:Real,Ts<:AbstractVector{T},Ps<:AbstractVector{P}
} <: DiscreteUnivariateDistribution
    support::Ts
    p::Ps

    function TuringDiscreteNonParametric{T,P,Ts,Ps}(
        vs, ps; check_args=true
    ) where {T<:Real,P<:Real,Ts<:AbstractVector{T},Ps<:AbstractVector{P}}
        # Fast path: the caller vouches for the arguments.
        if !check_args
            return new{T,P,Ts,Ps}(vs, ps)
        end
        Distributions.@check_args(TuringDiscreteNonParametric, length(vs) == length(ps))
        Distributions.@check_args(TuringDiscreteNonParametric, isprobvec(ps))
        Distributions.@check_args(TuringDiscreteNonParametric, allunique(vs))
        # Store the support in ascending order; the lookups below use searchsorted*.
        perm = sortperm(vs)
        return new{T,P,Ts,Ps}(vs[perm], ps[perm])
    end
end
23+
# Outer constructor: derive all four type parameters from the argument types and
# forward to the inner constructor.
function TuringDiscreteNonParametric(
    vs::Ts, ps::Ps; check_args=true
) where {T<:Real,P<:Real,Ts<:AbstractVector{T},Ps<:AbstractVector{P}}
    return TuringDiscreteNonParametric{T,P,Ts,Ps}(vs, ps; check_args=check_args)
end
31+
# Outer constructor for probability vectors given as a 1-d `SubArray` (a view).
# The view is materialised with `collect`, and the distribution is parameterised on the
# type of the collected vector so the declared `Ps` matches the value actually stored.
# `Ps <: SubArray{P,1}` binds the static parameter `P` (an unconstrained `SubArray` bound
# leaves `P` undefined) and makes this method strictly more specific than the generic
# `AbstractVector{P}` constructor.
function TuringDiscreteNonParametric(
    vs::Ts, ps::Ps; check_args=true
) where {T<:Real,P<:Real,Ts<:AbstractVector{T},Ps<:SubArray{P,1}}
    _ps = collect(ps)
    return TuringDiscreteNonParametric{T,P,Ts,typeof(_ps)}(vs, _ps; check_args=check_args)
end
41+
# Outer constructor for tracked probability vectors whose underlying data is a `SubArray`.
# The vector is re-indexed with `:` and the distribution is parameterised on the type of
# that result rather than on the type of the tracked view.
function TuringDiscreteNonParametric(
    vs::Ts, ps::Ps; check_args=true
) where {T<:Real,P<:Real,Ts<:AbstractVector{T},Ps<:TrackedVector{P,<:SubArray}}
    reindexed = ps[:]
    return TuringDiscreteNonParametric{T,P,Ts,typeof(reindexed)}(
        vs, reindexed; check_args=check_args
    )
end
51+
52+
# The element type of the distribution is the element type `T` of its support.
function Base.eltype(::Type{<:TuringDiscreteNonParametric{T}}) where {T}
    return T
end
53+
54+
# Accessors
55+
# `params` returns the stored fields as a `(support, probabilities)` tuple.
function Distributions.params(d::TuringDiscreteNonParametric)
    return (d.support, d.p)
end

# `support` returns the stored vector of values (not a copy).
function Distributions.support(d::TuringDiscreteNonParametric)
    return d.support
end

# `probs` returns the stored probability vector (not a copy).
function Distributions.probs(d::TuringDiscreteNonParametric)
    return d.p
end
60+
61+
# Two distributions of the same concrete type are approximately equal when both their
# supports and their probability vectors are approximately equal. Each comparison first
# tries the whole-array `≈` and falls back to an element-wise check.
# (The infix `≈` between the whole-array operands was missing, which left the
# expression unparseable.)
function Base.isapprox(c1::D, c2::D) where {D<:TuringDiscreteNonParametric}
    return (support(c1) ≈ support(c2) || all(support(c1) .≈ support(c2))) &&
           (probs(c1) ≈ probs(c2) || all(probs(c1) .≈ probs(c2)))
end
64+
65+
# Draw one value from `d`: sample `u = rand(rng, P)` and walk the probability vector,
# accumulating mass until the running total reaches `u` or the entries are exhausted.
function Distributions.rand(
    rng::AbstractRNG, d::TuringDiscreteNonParametric{T,P}
) where {T,P}
    vals = support(d)
    weights = probs(d)
    total = length(weights)
    u = rand(rng, P)
    acc = zero(P)
    idx = 0
    while acc < u && idx < total
        idx += 1
        acc += weights[idx]
    end
    # `idx` is still 0 when the loop never ran (u <= 0); clamp to the first value.
    return vals[max(idx, 1)]
end
77+
78+
# Convenience method: draw from `d` using `GLOBAL_RNG`.
function Distributions.rand(d::TuringDiscreteNonParametric)
    return rand(GLOBAL_RNG, d)
end
79+
80+
# Build a `DiscreteNonParametricSampler` from the support and probabilities of `d`.
function Distributions.sampler(d::TuringDiscreteNonParametric)
    return DiscreteNonParametricSampler(support(d), probs(d))
end
82+
83+
# The evaluation samples are the support itself; the `Float64` argument is ignored.
function Distributions.get_evalsamples(d::TuringDiscreteNonParametric, ::Float64)
    return support(d)
end
84+
85+
# Without a query point, `pdf` returns a copy of the full probability vector.
function Distributions.pdf(d::TuringDiscreteNonParametric)
    return copy(probs(d))
end
86+
87+
# Helper functions for pdf and cdf required to fix ambiguous method
88+
# error involving [pc]df(::DiscreteUnivariateDistribution, ::Int)
89+
# Probability mass at `x`: look `x` up in the support with `searchsorted`, and return the
# matching probability, or `zero(P)` when `x` is not found.
function _pdf(d::TuringDiscreteNonParametric{T,P}, x::T) where {T,P}
    hits = searchsorted(support(d), x)
    isempty(hits) && return zero(P)
    return probs(d)[first(hits)]
end
97+
# Public `pdf` methods: convert the query point to the support element type `T` and
# delegate to `_pdf`. The `Int` and `Real` methods are kept as separate definitions.
function Distributions.pdf(d::TuringDiscreteNonParametric{T}, x::Int) where {T}
    return _pdf(d, convert(T, x))
end

function Distributions.pdf(d::TuringDiscreteNonParametric{T}, x::Real) where {T}
    return _pdf(d, convert(T, x))
end
99+
100+
# Cumulative probability P(X <= x): the sum of the probabilities of all support points
# that are <= `x`, located with `searchsortedlast`.
# The result is always of the probability type `P`; the shortcut for `x` above the
# largest support point returns `one(P)` rather than a `Float64` literal so the return
# type does not depend on the value of `x`.
function _cdf(d::TuringDiscreteNonParametric{T,P}, x::T) where {T,P}
    x > maximum(d) && return one(P)
    ps = probs(d)
    stop_idx = searchsortedlast(support(d), x)
    s = zero(P)
    for i in 1:stop_idx
        s += ps[i]
    end
    return s
end
110+
# Public `cdf` methods: convert the query point to the support element type `T` and
# delegate to `_cdf`. The `Integer` and `Real` methods are kept as separate definitions.
function Distributions.cdf(d::TuringDiscreteNonParametric{T}, x::Integer) where {T}
    return _cdf(d, convert(T, x))
end

function Distributions.cdf(d::TuringDiscreteNonParametric{T}, x::Real) where {T}
    return _cdf(d, convert(T, x))
end
112+
113+
# Complementary cumulative probability P(X > x): the sum of the probabilities of all
# support points strictly after the last one that is <= `x`.
# The result is always of the probability type `P`; the shortcut for `x` below the
# smallest support point returns `one(P)` rather than a `Float64` literal so the return
# type does not depend on the value of `x`.
function _ccdf(d::TuringDiscreteNonParametric{T,P}, x::T) where {T,P}
    x < minimum(d) && return one(P)
    ps = probs(d)
    stop_idx = searchsortedlast(support(d), x)
    s = zero(P)
    for i in (stop_idx + 1):length(ps)
        s += ps[i]
    end
    return s
end
123+
# Public `ccdf` methods: convert the query point to the support element type `T` and
# delegate to `_ccdf`. The `Integer` and `Real` methods are kept as separate definitions.
function Distributions.ccdf(d::TuringDiscreteNonParametric{T}, x::Integer) where {T}
    return _ccdf(d, convert(T, x))
end

function Distributions.ccdf(d::TuringDiscreteNonParametric{T}, x::Real) where {T}
    return _ccdf(d, convert(T, x))
end
125+
126+
# Quantile of `d` at probability `q`: the first support point whose cumulative
# probability reaches `q`.
#
# Throws `DomainError` when `q` is outside [0, 1]. (`DomainError` needs the offending
# value as an argument; the zero-argument call raised a `MethodError` instead.)
function Distributions.quantile(d::TuringDiscreteNonParametric, q::Real)
    0 <= q <= 1 || throw(DomainError(q, "quantile probability must be in [0, 1]"))
    x = support(d)
    p = probs(d)
    k = length(x)
    i = 1
    cp = p[1]
    # The `i < k` guard keeps the index in range when the accumulated probabilities
    # never reach `q` (e.g. through rounding); the last support point is returned then.
    while cp < q && i < k
        i += 1
        @inbounds cp += p[i]
    end
    return x[i]
end
139+
140+
# Smallest / largest value: the first / last entry of the stored support vector.
function Base.minimum(d::TuringDiscreteNonParametric)
    return first(support(d))
end

function Base.maximum(d::TuringDiscreteNonParametric)
    return last(support(d))
end
142+
# `x` is in the support when `searchsorted` finds at least one matching entry.
function Distributions.insupport(d::TuringDiscreteNonParametric, x::Real)
    matches = searchsorted(support(d), x)
    return length(matches) > 0
end
144+
145+
# Mean: dot product of the probability vector with the support vector.
function Distributions.mean(d::TuringDiscreteNonParametric)
    return dot(probs(d), support(d))
end
146+
147+
# Variance: the probability-weighted sum of squared deviations from the mean.
# The accumulator starts as `zero(m)` so it has the type of the mean from the first
# iteration on; starting from `zero(T)` (the support element type, e.g. `Int`) made the
# variable change type inside the loop.
function Distributions.var(d::TuringDiscreteNonParametric{T}) where {T}
    m = mean(d)
    x = support(d)
    p = probs(d)
    k = length(x)
    σ² = zero(m)
    for i in 1:k
        @inbounds σ² += abs2(x[i] - m) * p[i]
    end
    return σ²
end
158+
159+
# Mode: the support value at the index returned by `argmax` over the probabilities.
function Distributions.mode(d::TuringDiscreteNonParametric)
    best = argmax(probs(d))
    return support(d)[best]
end
160+
# All modes of `d`: every support value whose probability equals the largest probability
# seen, collected in a single pass over the support.
function Distributions.modes(d::TuringDiscreteNonParametric{T,P}) where {T,P}
    vals = support(d)
    weights = probs(d)
    count = length(vals)
    found = T[]
    best = zero(P)
    @inbounds for idx in 1:count
        w = weights[idx]
        v = vals[idx]
        if w > best
            # Strictly larger probability: restart the collection with this value.
            best = w
            found = [v]
        elseif w == best
            # Tie with the current maximum: keep this value as well.
            push!(found, v)
        end
    end
    return found
end
178+
179+
# `Categorical` for a tracked probability vector: build a `TuringDiscreteNonParametric`
# over the categories `1:length(p)`.
function Distributions.Categorical(p::TrackedVector; check_args=true)
    categories = 1:length(p)
    return TuringDiscreteNonParametric(categories, p; check_args=check_args)
end
1182
## MvNormal ##
2183

3184
"""

0 commit comments

Comments
 (0)