9 changes: 9 additions & 0 deletions src/common.jl
@@ -99,6 +99,15 @@ Base.size(s::Sampleable)
Base.size(s::Sampleable{Univariate}) = ()
Base.size(s::Sampleable{Multivariate}) = (length(s),)

"""
axes(s::Sampleable{<:ArrayLikeVariate})

Return a tuple of valid indices for a sample from `s`. This can be used, e.g., to construct
an uninitialized sample array via `similar(Array{eltype(s)}, axes(s))`. It can be overloaded
for a given distribution to return, e.g., axes for which `similar` produces a custom array type.
"""
Base.axes(s::Sampleable{<:ArrayLikeVariate})

"""
eltype(::Type{Sampleable})

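As a side note on the new `axes` docstring: a minimal sketch (not part of this diff) of the allocation pattern it recommends. The concrete axes are made up, and OffsetArrays is only loaded to show that non-1-based axes yield a non-`Array` sample buffer.

```julia
using OffsetArrays

# `similar(Array{T}, ax)` honours whatever axes a sampleable reports, so a
# distribution that overloads `Base.axes` to return offset ranges gets an
# offset-indexed sample buffer from the generic `rand` fallbacks.
ax = (0:2,)                        # hypothetical result of `axes(s)`
x = similar(Array{Float64}, ax)    # an OffsetVector with indices 0:2
axes(x, 1) == 0:2                  # true
```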
16 changes: 10 additions & 6 deletions src/genericrand.jl
@@ -26,7 +26,9 @@ rand(rng::AbstractRNG, s::Sampleable, dim1::Int, moredims::Int...) =

# default fallback (redefined for univariate distributions)
function rand(rng::AbstractRNG, s::Sampleable{<:ArrayLikeVariate})
return rand!(rng, s, Array{eltype(s)}(undef, size(s)))
out = similar(Array{eltype(s)}, axes(s))
rand!(rng, s, out)
return out
end

# multiple samples
@@ -37,15 +39,17 @@ end
function rand(
rng::AbstractRNG, s::Sampleable{<:ArrayLikeVariate}, dims::Dims,
)
sz = size(s)
sax = axes(s)
ax = map(Base.OneTo, dims)
out = [Array{eltype(s)}(undef, sz) for _ in Iterators.product(ax...)]
out = [similar(Array{eltype(s)}, sax) for _ in Iterators.product(ax...)]
return rand!(rng, sampler(s), out, false)
end

# these are workarounds for sampleables that incorrectly base `eltype` on the parameters
function rand(rng::AbstractRNG, s::Sampleable{<:ArrayLikeVariate,Continuous})
return rand!(rng, sampler(s), Array{float(eltype(s))}(undef, size(s)))
out = similar(Array{float(eltype(s))}, axes(s))
rand!(rng, sampler(s), out)
return out
end
function rand(rng::AbstractRNG, s::Sampleable{Univariate,Continuous}, dims::Dims)
out = Array{float(eltype(s))}(undef, dims)
@@ -54,9 +58,9 @@ end
function rand(
rng::AbstractRNG, s::Sampleable{<:ArrayLikeVariate,Continuous}, dims::Dims,
)
sz = size(s)
sax = axes(s)
ax = map(Base.OneTo, dims)
out = [Array{float(eltype(s))}(undef, sz) for _ in Iterators.product(ax...)]
out = [similar(Array{float(eltype(s))}, sax) for _ in Iterators.product(ax...)]
return rand!(rng, sampler(s), out, false)
end

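To see what the reworked fallbacks enable, here is a hypothetical sampleable with non-1-based sample axes. `ShiftedIID` is invented for this sketch and hooks into the semi-internal `_rand!` extension point, so treat it as an illustration of the intended behaviour rather than a supported recipe.

```julia
using Distributions, Random, OffsetArrays

# IID draws from a univariate distribution, reported with custom (shifted) axes.
struct ShiftedIID{D<:UnivariateDistribution} <: Sampleable{Multivariate,Continuous}
    dist::D
    inds::UnitRange{Int}
end

Base.length(s::ShiftedIID) = length(s.inds)
Base.axes(s::ShiftedIID) = (s.inds,)

# Fill a preallocated sample; the generic `rand` above now allocates that sample
# via `similar(Array{Float64}, axes(s))` instead of `Array{Float64}(undef, size(s))`.
function Distributions._rand!(rng::Random.AbstractRNG, s::ShiftedIID, x::AbstractVector{<:Real})
    for i in eachindex(x)
        x[i] = rand(rng, s.dist)
    end
    return x
end

s = ShiftedIID(Normal(), 0:4)
x = rand(s)
axes(x, 1) == 0:4   # true: the sample is an OffsetVector
```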
1 change: 1 addition & 0 deletions src/matrix/inversewishart.jl
@@ -75,6 +75,7 @@ insupport(::Type{InverseWishart}, X::Matrix) = isposdef(X)
insupport(d::InverseWishart, X::Matrix) = size(X) == size(d) && isposdef(X)

size(d::InverseWishart) = size(d.Ψ)
Base.axes(d::InverseWishart) = axes(d.Ψ)
rank(d::InverseWishart) = rank(d.Ψ)

params(d::InverseWishart) = (d.df, d.Ψ)
2 changes: 2 additions & 0 deletions src/matrix/matrixfdist.jl
@@ -80,6 +80,8 @@ end

size(d::MatrixFDist) = size(d.W)

Base.axes(d::MatrixFDist) = axes(d.W)

rank(d::MatrixFDist) = size(d, 1)

insupport(d::MatrixFDist, Σ::AbstractMatrix) = isreal(Σ) && size(Σ) == size(d) && isposdef(Σ)
2 changes: 2 additions & 0 deletions src/matrix/matrixnormal.jl
@@ -82,6 +82,8 @@ end

size(d::MatrixNormal) = size(d.M)

Base.axes(d::MatrixNormal) = axes(d.M)

rank(d::MatrixNormal) = minimum( size(d) )

insupport(d::MatrixNormal, X::AbstractMatrix) = isreal(X) && size(X) == size(d)
2 changes: 2 additions & 0 deletions src/matrix/matrixtdist.jl
@@ -100,6 +100,8 @@ end

size(d::MatrixTDist) = size(d.M)

Base.axes(d::MatrixTDist) = axes(d.M)

rank(d::MatrixTDist) = minimum( size(d) )

insupport(d::MatrixTDist, X::Matrix) = isreal(X) && size(X) == size(d)
2 changes: 2 additions & 0 deletions src/matrix/wishart.jl
@@ -102,6 +102,8 @@ end

size(d::Wishart) = size(d.S)

Base.axes(d::Wishart) = axes(d.S)

rank(d::Wishart) = d.rank
params(d::Wishart) = (d.df, d.S)
@inline partype(d::Wishart{T}) where {T<:Real} = T
54 changes: 35 additions & 19 deletions src/mixtures/mixturemodel.jl
@@ -159,6 +159,7 @@ end
The length of each sample (only for `Multivariate`).
"""
length(d::MultivariateMixture) = length(d.components[1])
Base.axes(d::MultivariateMixture) = axes(d.components[1])
size(d::MatrixvariateMixture) = size(d.components[1])

ncomponents(d::MixtureModel) = length(d.components)
@@ -179,13 +180,11 @@ function mean(d::UnivariateMixture)
end

function mean(d::MultivariateMixture)
K = ncomponents(d)
p = probs(d)
m = zeros(length(d))
for i = 1:K
pi = p[i]
m = similar(p, float(eltype(d)), axes(d))
fill!(m, 0.0)
for (c, pi) in zip(components(d), p)
if pi > 0.0
c = component(d, i)
axpy!(pi, mean(c), m)
end
end
@@ -222,33 +221,50 @@ function var(d::UnivariateMixture)
end

function var(d::MultivariateMixture)
return diag(cov(d))
p = probs(d)
ax = axes(d)
m = similar(p, ax)
md = similar(p, ax)
v = similar(p, ax)
fill!(m, 0.0)
fill!(v, 0.0)

for (c, pi) in zip(components(d), p)
if pi > 0.0
axpy!(pi, mean(c), m)
axpy!(pi, var(c), v)
end
end
for (c, pi) in zip(components(d), p)
if pi > 0.0
md .= mean(c) .- m
@. v += pi * abs2(md)
end
end
return v
end

function cov(d::MultivariateMixture)
K = ncomponents(d)
p = probs(d)
m = zeros(length(d))
md = zeros(length(d))
V = zeros(length(d),length(d))

for i = 1:K
pi = p[i]
ax = axes(d)
m = similar(p, ax)
md = similar(p, ax)
V = similar(p, (ax[1], ax[1]))
fill!(m, 0.0)
fill!(V, 0.0)

for (c, pi) in zip(components(d), p)
if pi > 0.0
c = component(d, i)
axpy!(pi, mean(c), m)
axpy!(pi, cov(c), V)
end
end
for i = 1:K
pi = p[i]
for (c, pi) in zip(components(d), p)
if pi > 0.0
c = component(d, i)
md .= mean(c) .- m
BLAS.syr!('U', Float64(pi), md, V)
@. V += pi * md * md'
end
end
LinearAlgebra.copytri!(V, 'U')
return V
end

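The rewritten `var` and `cov` implement the law of total (co)variance, `Cov(X) = sum_i w_i * Cov_i + sum_i w_i * (mu_i - mu) * (mu_i - mu)'`, just allocated via `similar` so the component axes are respected. A quick numerical sanity check with made-up parameters:

```julia
using Distributions, LinearAlgebra

d = MixtureModel([MvNormal([0.0, 0.0], [1.0 0.0; 0.0 1.0]),
                  MvNormal([1.0, -1.0], [2.0 0.5; 0.5 2.0])], [0.3, 0.7])
w  = probs(d)
μs = mean.(components(d))
μ  = sum(w .* μs)
# law of total covariance, computed directly
V = sum(w[i] * (cov(component(d, i)) + (μs[i] - μ) * (μs[i] - μ)') for i in 1:2)
V ≈ cov(d)   # true
```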
21 changes: 8 additions & 13 deletions src/multivariate/dirichlet.jl
@@ -44,12 +44,16 @@ function Dirichlet(d::Integer, alpha::Real; check_args::Bool=true)
return Dirichlet{typeof(alpha)}(Fill(alpha, d); check_args=false)
end

Base.axes(d::Dirichlet) = axes(d.alpha)

struct DirichletCanon{T<:Real,Ts<:AbstractVector{T}}
alpha::Ts
end

length(d::DirichletCanon) = length(d.alpha)

Base.axes(d::DirichletCanon) = axes(d.alpha)

Base.eltype(::Type{<:Dirichlet{T}}) where {T} = T

#### Conversions
@@ -85,23 +89,14 @@ end

function cov(d::Dirichlet)
α = d.alpha
ax = axes(α, 1)
α0 = d.alpha0
c = inv(α0^2 * (α0 + 1))

T = typeof(zero(eltype(α))^2 * c)
k = length(α)
C = Matrix{T}(undef, k, k)
for j = 1:k
αj = α[j]
αjc = αj * c
for i in 1:(j-1)
C[i,j] = C[j,i]
end
C[j,j] = (α0 - αj) * αjc
for i in (j+1):k
C[i,j] = - α[i] * αjc
end
end
C = similar(α, T, (ax, ax))
@. C = -α * α' * c
C[diagind(C)] .+= (α0 * c) .* α

return C
end
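The new `cov(d::Dirichlet)` builds the whole matrix with a broadcasted outer product and then corrects the diagonal through `diagind`. A standalone check of that pattern with made-up parameters:

```julia
using LinearAlgebra

α  = [1.0, 2.0, 3.0]
α0 = sum(α)
c  = inv(α0^2 * (α0 + 1))
C  = similar(α, (axes(α, 1), axes(α, 1)))
@. C = -α * α' * c                       # all entries set to -αᵢ αⱼ c
C[diagind(C)] .+= (α0 * c) .* α          # diagonal becomes (α0 - αᵢ) αᵢ c
C ≈ [(i == j ? (α0 - α[i]) * α[i] : -α[i] * α[j]) * c for i in 1:3, j in 1:3]   # true
```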
7 changes: 3 additions & 4 deletions src/multivariate/dirichletmultinomial.jl
@@ -51,6 +51,7 @@ Base.show(io::IO, d::DirichletMultinomial) = show(io, d, (:n, :α,))

# Parameters
ncategories(d::DirichletMultinomial) = length(d.α)
Base.axes(d::DirichletMultinomial) = axes(d.α)
length(d::DirichletMultinomial) = ncategories(d)
ntrials(d::DirichletMultinomial) = d.n
params(d::DirichletMultinomial) = (d.n, d.α)
@@ -59,11 +60,9 @@ params(d::DirichletMultinomial) = (d.n, d.α)
# Statistics
mean(d::DirichletMultinomial) = d.α .* (d.n / d.α0)
function var(d::DirichletMultinomial{T}) where T <: Real
v = fill(d.n * (d.n + d.α0) / (1 + d.α0), length(d))
v0 = d.n * (d.n + d.α0) / (1 + d.α0)
p = d.α / d.α0
for i in eachindex(v)
v[i] *= p[i] * (1 - p[i])
end
v = @. v0 * p * (1 - p)
v
end
function cov(d::DirichletMultinomial{<:Real})
5 changes: 3 additions & 2 deletions src/multivariate/jointorderstatistics.jl
@@ -71,6 +71,7 @@ function _are_ranks_valid(ranks::AbstractRange, n)
end

length(d::JointOrderStatistics) = length(d.ranks)
Base.axes(d::JointOrderStatistics) = axes(d.ranks)
function insupport(d::JointOrderStatistics, x::AbstractVector)
length(d) == length(x) || return false
xi, state = iterate(x) # at least one element!
@@ -83,8 +84,8 @@ function insupport(d::JointOrderStatistics, x::AbstractVector)
end
return true
end
minimum(d::JointOrderStatistics) = Fill(minimum(d.dist), length(d))
maximum(d::JointOrderStatistics) = Fill(maximum(d.dist), length(d))
minimum(d::JointOrderStatistics) = Fill(minimum(d.dist), axes(d))
maximum(d::JointOrderStatistics) = Fill(maximum(d.dist), axes(d))

params(d::JointOrderStatistics) = tuple(params(d.dist)..., d.n, d.ranks)
partype(d::JointOrderStatistics) = partype(d.dist)
30 changes: 7 additions & 23 deletions src/multivariate/multinomial.jl
@@ -43,6 +43,7 @@ end

ncategories(d::Multinomial) = length(d.p)
length(d::Multinomial) = ncategories(d)
Base.axes(d::Multinomial) = axes(d.p)
probs(d::Multinomial) = d.p
ntrials(d::Multinomial) = d.n

@@ -63,37 +64,20 @@ mean(d::Multinomial) = d.n .* d.p

function var(d::Multinomial{T}) where T<:Real
p = probs(d)
k = length(p)
n = ntrials(d)

v = Vector{T}(undef, k)
for i = 1:k
p_i = p[i]
v[i] = n * p_i * (1 - p_i)
end
v = @. n * p * (1 - p)
v
end

function cov(d::Multinomial{T}) where T<:Real
p = probs(d)
k = length(p)
ax = axes(p, 1)
n = ntrials(d)

C = Matrix{T}(undef, k, k)
for j = 1:k
pj = p[j]
for i = 1:j-1
C[i,j] = - n * p[i] * pj
end

C[j,j] = n * pj * (1-pj)
end

for j = 1:k-1
for i = j+1:k
C[i,j] = C[j,i]
end
end
C = similar(p, T, (ax, ax))
@. C = -n * p * p'
C[diagind(C)] .+= n .* p
C
end

@@ -168,7 +152,7 @@ end
_rand!(rng::AbstractRNG, d::Multinomial, x::AbstractVector{<:Real}) =
multinom_rand!(rng, ntrials(d), probs(d), x)

sampler(d::Multinomial) = MultinomialSampler(ntrials(d), probs(d))
sampler(d::Multinomial) = MultinomialSampler(ntrials(d), convert(Vector, probs(d)))


## Fit model
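Same pattern for the multinomial covariance; a quick check (illustrative parameters) against the textbook form `n * (Diagonal(p) - p * p')`:

```julia
using Distributions, LinearAlgebra

d = Multinomial(10, [0.2, 0.3, 0.5])
n, p = ntrials(d), probs(d)
cov(d) ≈ n .* (Diagonal(p) - p * p')   # true
```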
1 change: 1 addition & 0 deletions src/multivariate/mvlognormal.jl
@@ -189,6 +189,7 @@ function convert(::Type{MvLogNormal{T}}, pars...) where T<:Real
end

length(d::MvLogNormal) = length(d.normal)
Base.axes(d::MvLogNormal) = axes(d.normal)
params(d::MvLogNormal) = params(d.normal)
@inline partype(d::MvLogNormal{T}) where {T<:Real} = T

7 changes: 4 additions & 3 deletions src/multivariate/mvnormal.jl
@@ -80,8 +80,8 @@ abstract type AbstractMvNormal <: ContinuousMultivariateDistribution end
insupport(d::AbstractMvNormal, x::AbstractVector) =
length(d) == length(x) && all(isfinite, x)

minimum(d::AbstractMvNormal) = fill(eltype(d)(-Inf), length(d))
maximum(d::AbstractMvNormal) = fill(eltype(d)(Inf), length(d))
minimum(d::AbstractMvNormal) = fill(eltype(d)(-Inf), axes(d))
maximum(d::AbstractMvNormal) = fill(eltype(d)(Inf), axes(d))
mode(d::AbstractMvNormal) = mean(d)
modes(d::AbstractMvNormal) = [mean(d)]

@@ -211,7 +211,7 @@ end

Construct a multivariate normal distribution with zero mean and covariance matrix `Σ`.
"""
MvNormal(Σ::AbstractMatrix{<:Real}) = MvNormal(Zeros{eltype(Σ)}(size(Σ, 1)), Σ)
MvNormal(Σ::AbstractMatrix{<:Real}) = MvNormal(Zeros{eltype(Σ)}(axes(Σ, 1)), Σ)

# deprecated constructors with standard deviations
Base.@deprecate MvNormal(μ::AbstractVector{<:Real}, σ::AbstractVector{<:Real}) MvNormal(μ, LinearAlgebra.Diagonal(map(abs2, σ)))
@@ -247,6 +247,7 @@ Base.show(io::IO, d::MvNormal) =
### Basic statistics

length(d::MvNormal) = length(d.μ)
Base.axes(d::MvNormal) = axes(d.μ)
mean(d::MvNormal) = d.μ
params(d::MvNormal) = (d.μ, d.Σ)
@inline partype(d::MvNormal{T}) where {T<:Real} = T
1 change: 1 addition & 0 deletions src/multivariate/mvnormalcanon.jl
@@ -151,6 +151,7 @@ canonform(d::MvNormal{T,C,Zeros{T}}) where {C, T<:Real} = MvNormalCanon(inv(d.Σ
### Basic statistics

length(d::MvNormalCanon) = length(d.μ)
Base.axes(d::MvNormalCanon) = axes(d.h)
mean(d::MvNormalCanon) = convert(Vector{eltype(d.μ)}, d.μ)
params(d::MvNormalCanon) = (d.μ, d.h, d.J)
@inline partype(d::MvNormalCanon{T}) where {T<:Real} = T
1 change: 1 addition & 0 deletions src/multivariate/mvtdist.jl
@@ -87,6 +87,7 @@ mvtdist(df::Real, Σ::Matrix{<:Real}) = MvTDist(df, Σ)
# Basic statistics

length(d::GenericMvTDist) = d.dim
Base.axes(d::GenericMvTDist) = axes(d.μ)

mean(d::GenericMvTDist) = d.df>1 ? d.μ : NaN
mode(d::GenericMvTDist) = d.μ
1 change: 1 addition & 0 deletions src/multivariate/product.jl
@@ -35,6 +35,7 @@ function Product(v::V) where {S<:ValueSupport,T<:UnivariateDistribution{S},V<:Ab
end

length(d::Product) = length(d.v)
Base.axes(d::Product) = axes(d.v)
function Base.eltype(::Type{<:Product{S,T}}) where {S<:ValueSupport,
T<:UnivariateDistribution{S}}
return eltype(T)