Skip to content

Commit f4b5068

Browse files
committed
simplify Categorical fix
1 parent 91536ac commit f4b5068

File tree

1 file changed

+5
-175
lines changed

1 file changed

+5
-175
lines changed

src/univariate.jl

Lines changed: 5 additions & 175 deletions
Original file line numberDiff line numberDiff line change
@@ -341,180 +341,10 @@ end
341341

342342
## Categorical ##
343343

344-
struct TuringDiscreteNonParametric{T<:Real,P<:Real,Ts<:AbstractVector{T},Ps<:AbstractVector{P}} <: DiscreteUnivariateDistribution
345-
support::Ts
346-
p::Ps
347-
348-
function TuringDiscreteNonParametric{T, P, Ts, Ps}(vs, ps; check_args=true) where {
349-
T <: Real,
350-
P <: Real,
351-
Ts <: AbstractVector{T},
352-
Ps <: AbstractVector{P},
353-
}
354-
check_args || return new{T, P, Ts, Ps}(vs, ps)
355-
Distributions.@check_args(TuringDiscreteNonParametric, length(vs) == length(ps))
356-
Distributions.@check_args(TuringDiscreteNonParametric, isprobvec(ps))
357-
Distributions.@check_args(TuringDiscreteNonParametric, allunique(vs))
358-
sort_order = sortperm(vs)
359-
vs = vs[sort_order]
360-
ps = ps[sort_order]
361-
new{T, P, Ts, Ps}(vs, ps)
362-
end
363-
end
364-
function TuringDiscreteNonParametric(vs::Ts, ps::Ps; check_args=true) where {
365-
T <: Real,
366-
P <: Real,
367-
Ts <: AbstractVector{T},
368-
Ps <: AbstractVector{P},
369-
}
370-
return TuringDiscreteNonParametric{T, P, Ts, Ps}(vs, ps; check_args = check_args)
371-
end
372-
function TuringDiscreteNonParametric(vs::Ts, ps::Ps; check_args=true) where {
373-
T <: Real,
374-
P <: Real,
375-
Ts <: AbstractVector{T},
376-
Ps <: SubArray,
377-
}
378-
_ps = collect(ps)
379-
_Ps = typeof(ps)
380-
return TuringDiscreteNonParametric{T, P, Ts, _Ps}(vs, _ps, check_args = check_args)
381-
end
382-
function TuringDiscreteNonParametric(vs::Ts, ps::Ps; check_args=true) where {
383-
T <: Real,
384-
P <: Real,
385-
Ts <: AbstractVector{T},
386-
Ps <: TrackedVector{P, <:SubArray},
387-
}
388-
_ps = ps[:]
389-
_Ps = typeof(_ps)
390-
return TuringDiscreteNonParametric{T, P, Ts, _Ps}(vs, _ps, check_args = check_args)
391-
end
392-
393-
Base.eltype(::Type{<:TuringDiscreteNonParametric{T}}) where T = T
394-
395-
# Accessors
396-
Distributions.support(d::TuringDiscreteNonParametric) = d.support
397-
398-
Distributions.probs(d::TuringDiscreteNonParametric) = d.p
399-
400-
Base.isapprox(c1::D, c2::D) where D <: TuringDiscreteNonParametric =
401-
(support(c1) support(c2) || all(support(c1) .≈ support(c2))) &&
402-
(probs(c1) probs(c2) || all(probs(c1) .≈ probs(c2)))
403-
404-
function Distributions.rand(rng::AbstractRNG, d::TuringDiscreteNonParametric{T,P}) where {T,P}
405-
x = support(d)
406-
p = probs(d)
407-
n = length(p)
408-
draw = rand(rng, P)
409-
cp = zero(P)
410-
i = 0
411-
while cp < draw && i < n
412-
cp += p[i +=1]
413-
end
414-
x[max(i,1)]
415-
end
416-
417-
Distributions.rand(d::TuringDiscreteNonParametric) = rand(GLOBAL_RNG, d)
418-
419-
Distributions.sampler(d::TuringDiscreteNonParametric) =
420-
DiscreteNonParametricSampler(support(d), probs(d))
421-
422-
Distributions.get_evalsamples(d::TuringDiscreteNonParametric, ::Float64) = support(d)
423-
424-
Distributions.pdf(d::TuringDiscreteNonParametric) = copy(probs(d))
425-
426-
# Helper functions for pdf and cdf required to fix ambiguous method
427-
# error involving [pc]df(::DisceteUnivariateDistribution, ::Int)
428-
function _pdf(d::TuringDiscreteNonParametric{T,P}, x::T) where {T,P}
429-
idx_range = searchsorted(support(d), x)
430-
if length(idx_range) > 0
431-
return probs(d)[first(idx_range)]
432-
else
433-
return zero(P)
434-
end
344+
function Base.convert(
345+
::Type{Distributions.DiscreteNonParametric{T,P,Ts,Ps}},
346+
d::Distributions.DiscreteNonParametric{T,P,Ts,Ps},
347+
) where {T<:Real,P<:Real,Ts<:AbstractVector{T},Ps<:AbstractVector{P}}
348+
DiscreteNonParametric{T,P,Ts,Ps}(support(d), probs(d), check_args=false)
435349
end
436-
Distributions.pdf(d::TuringDiscreteNonParametric{T}, x::Int) where T = _pdf(d, convert(T, x))
437-
Distributions.pdf(d::TuringDiscreteNonParametric{T}, x::Real) where T = _pdf(d, convert(T, x))
438350

439-
function _cdf(d::TuringDiscreteNonParametric{T,P}, x::T) where {T,P}
440-
x > maximum(d) && return 1.0
441-
s = zero(P)
442-
ps = probs(d)
443-
stop_idx = searchsortedlast(support(d), x)
444-
for i in 1:stop_idx
445-
s += ps[i]
446-
end
447-
return s
448-
end
449-
Distributions.cdf(d::TuringDiscreteNonParametric{T}, x::Integer) where T = _cdf(d, convert(T, x))
450-
Distributions.cdf(d::TuringDiscreteNonParametric{T}, x::Real) where T = _cdf(d, convert(T, x))
451-
452-
function _ccdf(d::TuringDiscreteNonParametric{T,P}, x::T) where {T,P}
453-
x < minimum(d) && return 1.0
454-
s = zero(P)
455-
ps = probs(d)
456-
stop_idx = searchsortedlast(support(d), x)
457-
for i in (stop_idx+1):length(ps)
458-
s += ps[i]
459-
end
460-
return s
461-
end
462-
Distributions.ccdf(d::TuringDiscreteNonParametric{T}, x::Integer) where T = _ccdf(d, convert(T, x))
463-
Distributions.ccdf(d::TuringDiscreteNonParametric{T}, x::Real) where T = _ccdf(d, convert(T, x))
464-
465-
function Distributions.quantile(d::TuringDiscreteNonParametric, q::Real)
466-
0 <= q <= 1 || throw(DomainError())
467-
x = support(d)
468-
p = probs(d)
469-
k = length(x)
470-
i = 1
471-
cp = p[1]
472-
while cp < q && i < k #Note: is i < k necessary?
473-
i += 1
474-
@inbounds cp += p[i]
475-
end
476-
x[i]
477-
end
478-
479-
Base.minimum(d::TuringDiscreteNonParametric) = first(support(d))
480-
Base.maximum(d::TuringDiscreteNonParametric) = last(support(d))
481-
Distributions.insupport(d::TuringDiscreteNonParametric, x::Real) =
482-
length(searchsorted(support(d), x)) > 0
483-
484-
Distributions.mean(d::TuringDiscreteNonParametric) = dot(probs(d), support(d))
485-
486-
function Distributions.var(d::TuringDiscreteNonParametric{T}) where T
487-
m = mean(d)
488-
x = support(d)
489-
p = probs(d)
490-
k = length(x)
491-
σ² = zero(T)
492-
for i in 1:k
493-
@inbounds σ² += abs2(x[i] - m) * p[i]
494-
end
495-
σ²
496-
end
497-
498-
Distributions.mode(d::TuringDiscreteNonParametric) = support(d)[argmax(probs(d))]
499-
function Distributions.modes(d::TuringDiscreteNonParametric{T,P}) where {T,P}
500-
x = support(d)
501-
p = probs(d)
502-
k = length(x)
503-
mds = T[]
504-
max_p = zero(P)
505-
@inbounds for i in 1:k
506-
pi = p[i]
507-
xi = x[i]
508-
if pi > max_p
509-
max_p = pi
510-
mds = [xi]
511-
elseif pi == max_p
512-
push!(mds, xi)
513-
end
514-
end
515-
mds
516-
end
517-
518-
function Distributions.Categorical(p::TrackedVector; check_args = true)
519-
return TuringDiscreteNonParametric(1:length(p), p, check_args = check_args)
520-
end

0 commit comments

Comments
 (0)