Skip to content

Commit cd5d4cc

Browse files
sethaxen and devmotion authored
Add MvLogitNormal (#1774)
* Create MvLogitNormal * Add MvLogitNormal to docs * Simplify constructors * Fix conversions * Rearrange code * Fix computation of -Inf * Add meanform and canonform * Add back type constructor * Add MvLogitNormal tests * Update and test show method * Fix testset name * Fix for older Julia versions * Restrict testing of `show` method to newer versions * Add kldivergence tests * Improve documentation * Remove constructor with type and AbstractMvNormal params * Update show method * Update docstring * Remove reference to Dirichlet * Apply suggestions from code review Co-authored-by: David Widmann <[email protected]> --------- Co-authored-by: David Widmann <[email protected]>
1 parent b21e515 commit cd5d4cc

File tree

6 files changed

+302
-0
lines changed

6 files changed

+302
-0
lines changed

docs/src/multivariate.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ Multinomial
5555
Distributions.AbstractMvNormal
5656
MvNormal
5757
MvNormalCanon
58+
MvLogitNormal
5859
MvLogNormal
5960
Dirichlet
6061
Product

src/Distributions.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ export
122122
Logistic,
123123
LogNormal,
124124
LogUniform,
125+
MvLogitNormal,
125126
LogitNormal,
126127
MatrixBeta,
127128
MatrixFDist,

src/multivariate/mvlogitnormal.jl

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
"""
2+
MvLogitNormal{<:AbstractMvNormal}
3+
4+
The [multivariate logit-normal distribution](https://en.wikipedia.org/wiki/Logit-normal_distribution#Multivariate_generalization)
5+
is a multivariate generalization of [`LogitNormal`](@ref) capable of handling correlations
6+
between variables.
7+
8+
If ``\\mathbf{y} \\sim \\mathrm{MvNormal}(\\boldsymbol{\\mu}, \\boldsymbol{\\Sigma})`` is a
9+
length ``d-1`` vector, then
10+
```math
11+
\\mathbf{x} = \\operatorname{softmax}\\left(\\begin{bmatrix}\\mathbf{y} \\\\ 0 \\end{bmatrix}\\right) \\sim \\mathrm{MvLogitNormal}(\\boldsymbol{\\mu}, \\boldsymbol{\\Sigma})
12+
```
13+
is a length ``d`` probability vector.
14+
15+
```julia
16+
MvLogitNormal(μ, Σ) # MvLogitNormal with y ~ MvNormal(μ, Σ)
17+
MvLogitNormal(MvNormal(μ, Σ)) # same as above
18+
MvLogitNormal(MvNormalCanon(μ, J)) # MvLogitNormal with y ~ MvNormalCanon(μ, J)
19+
```
20+
21+
# Fields
22+
23+
- `normal::AbstractMvNormal`: contains the ``d-1``-dimensional distribution of ``y``
24+
"""
25+
struct MvLogitNormal{D<:AbstractMvNormal} <: ContinuousMultivariateDistribution
    normal::D

    # The only inner constructor requires the type parameter to be spelled out and to match
    # the type of the wrapped normal distribution exactly.
    function MvLogitNormal{D}(normal::D) where {D<:AbstractMvNormal}
        return new{D}(normal)
    end
end

# Wrap an existing multivariate normal, taking the type parameter from the argument.
function MvLogitNormal(d::AbstractMvNormal)
    return MvLogitNormal{typeof(d)}(d)
end

# Any other arguments are forwarded to `MvNormal`, and the result is wrapped.
function MvLogitNormal(args...)
    return MvLogitNormal(MvNormal(args...))
end
31+
32+
# Print `d` as the distribution name followed by the wrapped normal distribution, with
# every line of the wrapped distribution's own `show` output prefixed by `indent`.
function Base.show(io::IO, d::MvLogitNormal; indent::String=" ")
    println(io, distrname(d), "(")
    # Render the wrapped distribution with the same IO context, then drop surrounding
    # whitespace so that the closing parenthesis below ends up on its own line.
    inner = strip(sprint(show, d.normal; context=IOContext(io)))
    println(io, indent, replace(inner, "\n" => "\n$indent"))
    println(io, ")")
end
41+
42+
# Conversions
43+
44+
# Convert `d` to an `MvLogitNormal` whose wrapped normal distribution has been converted
# to `D`. The method name is qualified with `Base.` (like the identity method below) so
# that it always extends `Base.convert`, independent of how `convert` was imported.
function Base.convert(::Type{MvLogitNormal{D}}, d::MvLogitNormal) where {D}
    return MvLogitNormal(convert(D, d.normal))
end

# Nothing to do when `d` already has exactly the requested type: return it unchanged.
Base.convert(::Type{MvLogitNormal{D}}, d::MvLogitNormal{D}) where {D} = d
48+
49+
# Re-wrap the mean-form version of a canonical-form normal distribution.
function meanform(d::MvLogitNormal{<:MvNormalCanon})
    return MvLogitNormal(meanform(d.normal))
end

# Re-wrap the canonical-form version of a mean-form normal distribution.
function canonform(d::MvLogitNormal{<:MvNormal})
    return MvLogitNormal(canonform(d.normal))
end
51+
52+
# Properties
53+
54+
# A draw has one more entry than the wrapped normal distribution has dimensions.
function length(d::MvLogitNormal)
    return length(d.normal) + 1
end

# Element type, parameters and parameter type are all forwarded from the wrapped normal.
function Base.eltype(::Type{<:MvLogitNormal{D}}) where {D}
    return eltype(D)
end
function Base.eltype(d::MvLogitNormal)
    return eltype(d.normal)
end
function params(d::MvLogitNormal)
    return params(d.normal)
end
@inline function partype(d::MvLogitNormal)
    return partype(d.normal)
end

# The location is the mean of the wrapped normal distribution.
function location(d::MvLogitNormal)
    return mean(d.normal)
end

# Component-wise bounds: a vector of zeros and a vector of ones, each of length `length(d)`.
function minimum(d::MvLogitNormal)
    lower = zero(eltype(d))
    return fill(lower, length(d))
end
function maximum(d::MvLogitNormal)
    upper = oneunit(eltype(d))
    return fill(upper, length(d))
end
63+
64+
# Return `true` when `x` has `length(d)` entries, none of them negative, and the entries
# sum to one up to the default tolerance of `≈`; return `false` otherwise.
function insupport(d::MvLogitNormal, x::AbstractVector{<:Real})
    return length(d) == length(x) && all(≥(0), x) && sum(x) ≈ 1
end
67+
68+
# Evaluation
69+
70+
# Log-density of `d` at `x`: the log-density of the wrapped normal at the inverse-softmax
# image of `x`, minus `sum(log, x)`. Points outside the support get `-Inf`.
function _logpdf(d::MvLogitNormal, x::AbstractVector{<:Real})
    insupport(d, x) && return logpdf(d.normal, _inv_softmax1(x)) - sum(log, x)
    # The log-density evaluated here only supplies the result type for `oftype`; `abs.`
    # keeps the arguments of `log` non-negative while doing so.
    template = logpdf(d.normal, _inv_softmax1(abs.(x)))
    return oftype(template, -Inf)
end
77+
78+
# Gradient of the log-density of `d` with respect to `x`, assembled from the gradient of
# the wrapped normal's log-density at the inverse-softmax image of `x`.
function gradlogpdf(d::MvLogitNormal, x::AbstractVector{<:Real})
    grad_y = gradlogpdf(d.normal, _inv_softmax1(x))
    # Append the entry for the last component, then subtract one and divide by `x`
    # element-wise.
    extended = vcat(grad_y, -sum(grad_y))
    return (extended .- 1) ./ x
end
84+
85+
# Statistics
86+
87+
# The KL divergence between two `MvLogitNormal`s is computed as the KL divergence between
# their wrapped normal distributions.
function kldivergence(p::MvLogitNormal, q::MvLogitNormal)
    return kldivergence(p.normal, q.normal)
end
88+
89+
# Sampling
90+
91+
# Fill `x` with draws from `d`: sample the wrapped normal into all but the last row of
# `x`, then map that buffer through `_softmax1!`, writing the result into `x` itself.
function _rand!(rng::AbstractRNG, d::MvLogitNormal, x::AbstractVecOrMat{<:Real})
    # `_drop1` returns a view, so the normal draws are stored inside `x`.
    normal_draws = _drop1(x)
    rand!(rng, d.normal, normal_draws)
    _softmax1!(x, normal_draws)
    return x
end
97+
98+
# Fitting
99+
100+
# Fit an `MvLogitNormal` wrapping a normal of type `D`: every column of `x` is mapped
# through `_inv_softmax1!`, and `D` is then fitted to the transformed columns. Keyword
# arguments are passed on to the `fit_mle` method for `D`.
function fit_mle(::Type{MvLogitNormal{D}}, x::AbstractMatrix{<:Real}; kwargs...) where {D}
    nrows, ncols = size(x, 1), size(x, 2)
    y = similar(x, nrows - 1, ncols)
    foreach(_inv_softmax1!, eachcol(y), eachcol(x))
    return MvLogitNormal(fit_mle(D, y; kwargs...))
end

# Without a type parameter, the wrapped distribution is fitted as an `MvNormal`.
function fit_mle(::Type{MvLogitNormal}, x::AbstractMatrix{<:Real}; kwargs...)
    return fit_mle(MvLogitNormal{MvNormal}, x; kwargs...)
end
109+
110+
# Utility
111+
112+
# Write into `x` the softmax of `y` extended by a trailing zero; `x` has one more entry
# than `y`.
function _softmax1!(x::AbstractVector, y::AbstractVector)
    # Shift by the largest input (the implicit trailing zero included) before
    # exponentiating, so that no argument of `exp` is positive.
    shift = max(0, maximum(y))
    head = _drop1(x)
    head .= exp.(y .- shift)
    x[end] = exp(-shift)
    LinearAlgebra.normalize!(x, 1)
    return x
end

# Matrix method: apply the vector method to each pair of columns.
function _softmax1!(x::AbstractMatrix, y::AbstractMatrix)
    foreach(_softmax1!, eachcol(x), eachcol(y))
    return x
end
123+
124+
# View of everything but the last entry along the first dimension (no copy is made).
_drop1(x::AbstractVector) = @views x[firstindex(x, 1):(end - 1)]
_drop1(x::AbstractMatrix) = @views x[firstindex(x, 1):(end - 1), :]

# Last entry along the first dimension: a scalar for a vector, and for a matrix a vector
# view with one entry per column.
_last1(x::AbstractVector) = x[end]
_last1(x::AbstractMatrix) = @views x[end, :]

# Shape the result of `_last1` so that it broadcasts against `_drop1` of the same array:
# scalars pass through unchanged, per-column vectors become `1 × n` rows.
_last1_row(v) = v
_last1_row(v::AbstractVector) = reshape(v, 1, :)

# Write into `y` the log of every entry of `x` except the last one along the first
# dimension, minus the log of that last entry. For matrices this is done per column.
function _inv_softmax1!(y::AbstractVecOrMat, x::AbstractVecOrMat)
    x₋ = _drop1(x)
    # Without the reshape, the per-column last entries of a matrix would be broadcast
    # as a column, i.e. along the wrong dimension of `x₋`.
    xd = _last1_row(_last1(x))
    @. y = log(x₋) - log(xd)
    return y
end

# Allocating version of `_inv_softmax1!`; the result has one row/entry fewer than `x`.
function _inv_softmax1(x::AbstractVecOrMat)
    y = similar(_drop1(x))
    _inv_softmax1!(y, x)
    return y
end

src/multivariates.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ for fname in ["dirichlet.jl",
115115
"jointorderstatistics.jl",
116116
"mvnormal.jl",
117117
"mvnormalcanon.jl",
118+
"mvlogitnormal.jl",
118119
"mvlognormal.jl",
119120
"mvtdist.jl",
120121
"product.jl", # deprecated

test/multivariate/mvlogitnormal.jl

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
# Tests on Multivariate Logit-Normal distributions
2+
using Distributions
3+
using ForwardDiff
4+
using LinearAlgebra
5+
using Random
6+
using Test
7+
8+
####### Core testing procedure
9+
10+
# Shared checks for an `MvLogitNormal` `d`. `nsamples` is the number of draws used by the
# sampling, fitting and evaluation testsets.
function test_mvlogitnormal(d::MvLogitNormal; nsamples::Int=10^6)
    @test d.normal isa AbstractMvNormal
    dnorm = d.normal

    @testset "properties" begin
        @test length(d) == length(dnorm) + 1
        @test params(d) == params(dnorm)
        @test partype(d) == partype(dnorm)
        @test eltype(d) == eltype(dnorm)
        @test eltype(typeof(d)) == eltype(typeof(dnorm))
        @test location(d) == mean(dnorm)
        @test minimum(d) == fill(0, length(d))
        @test maximum(d) == fill(1, length(d))
        @test insupport(d, normalize(rand(length(d)), 1))
        @test !insupport(d, normalize(rand(length(d) + 1), 1))
        @test !insupport(d, rand(length(d)))
        # Sums to zero and has a negative last entry, so it must be rejected.
        x = rand(length(d) - 1)
        x = vcat(x, -sum(x))
        @test !insupport(d, x)
    end

    @testset "conversions" begin
        @test convert(typeof(d), d) === d
        # Convert to the other floating-point parameter type.
        T = partype(d) <: Float64 ? Float32 : Float64
        if dnorm isa MvNormal
            @test convert(MvLogitNormal{MvNormal{T}}, d).normal ==
                convert(MvNormal{T}, dnorm)
            @test partype(convert(MvLogitNormal{MvNormal{T}}, d)) <: T
            @test canonform(d) isa MvLogitNormal{<:MvNormalCanon}
            @test canonform(d).normal == canonform(dnorm)
        elseif dnorm isa MvNormalCanon
            @test convert(MvLogitNormal{MvNormalCanon{T}}, d).normal ==
                convert(MvNormalCanon{T}, dnorm)
            @test partype(convert(MvLogitNormal{MvNormalCanon{T}}, d)) <: T
            @test meanform(d) isa MvLogitNormal{<:MvNormal}
            @test meanform(d).normal == meanform(dnorm)
        end
    end

    @testset "sampling" begin
        X = rand(d, nsamples)
        # Map the draws back to the scale of the wrapped normal and compare moments.
        Y = @views log.(X[1:(end - 1), :]) .- log.(X[end, :]')
        Ymean = vec(mean(Y; dims=2))
        Ycov = cov(Y; dims=2)
        for i in 1:(length(d) - 1)
            @test isapprox(
                Ymean[i], mean(dnorm)[i], atol=sqrt(var(dnorm)[i] / nsamples) * 8
            )
        end
        for i in 1:(length(d) - 1), j in 1:(length(d) - 1)
            @test isapprox(
                Ycov[i, j],
                cov(dnorm)[i, j],
                atol=sqrt(prod(var(dnorm)[[i, j]]) / nsamples) * 20,
            )
        end
    end

    @testset "fitting" begin
        X = rand(d, nsamples)
        dfit = fit_mle(MvLogitNormal, X)
        dfit_norm = dfit.normal
        for i in 1:(length(d) - 1)
            @test isapprox(
                mean(dfit_norm)[i], mean(dnorm)[i], atol=sqrt(var(dnorm)[i] / nsamples) * 8
            )
        end
        for i in 1:(length(d) - 1), j in 1:(length(d) - 1)
            @test isapprox(
                cov(dfit_norm)[i, j],
                cov(dnorm)[i, j],
                atol=sqrt(prod(var(dnorm)[[i, j]]) / nsamples) * 20,
            )
        end
        @test fit_mle(MvLogitNormal{IsoNormal}, X) isa MvLogitNormal{<:IsoNormal}
    end

    @testset "evaluation" begin
        X = rand(d, nsamples)
        for i in 1:min(100, nsamples)
            @test @inferred(logpdf(d, X[:, i])) ≈ log(pdf(d, X[:, i]))
            if dnorm isa MvNormal
                # Compare the analytic gradient against automatic differentiation.
                @test @inferred(gradlogpdf(d, X[:, i])) ≈
                    ForwardDiff.gradient(x -> logpdf(d, x), X[:, i])
            end
        end
        @test logpdf(d, X) ≈ log.(pdf(d, X))
        # Points outside the support have zero density.
        @test isequal(logpdf(d, zeros(length(d))), -Inf)
        @test isequal(logpdf(d, ones(length(d))), -Inf)
        @test isequal(pdf(d, zeros(length(d))), 0)
        @test isequal(pdf(d, ones(length(d))), 0)
    end
end
103+
104+
# A two-component `MvLogitNormal` wrapping a one-dimensional normal is compared against the
# univariate `LogitNormal` with the same mean and standard deviation.
@testset "Results MvLogitNormal consistent with univariate LogitNormal" begin
    μ = randn()
    σ = rand()
    # `MvLogitNormal` takes a covariance matrix, `LogitNormal` a standard deviation.
    d = MvLogitNormal([μ], fill(σ^2, 1, 1))
    duni = LogitNormal(μ, σ)
    @test location(d) ≈ [location(duni)]
    x = normalize(rand(2), 1)
    @test logpdf(d, x) ≈ logpdf(duni, x[1])
    @test pdf(d, x) ≈ pdf(duni, x[1])
    # Re-seed before each draw so both draws start from the same RNG state.
    @test (Random.seed!(9274); rand(d)[1]) ≈ (Random.seed!(9274); rand(duni))
end
115+
116+
###### General Testing
117+
118+
@testset "MvLogitNormal tests" begin
    # Argument tuples used to build the wrapped normal distributions below.
    mvnorm_params = [
        (randn(5), I * rand()),
        (randn(4), Diagonal(rand(4))),
        (Diagonal(rand(6)),),
        (randn(5), exp(Symmetric(randn(5, 5)))),
        (exp(Symmetric(randn(5, 5))),),
    ]
    @testset "wraps MvNormal" begin
        @testset "$(typeof(prms))" for prms in mvnorm_params
            d = MvLogitNormal(prms...)
            @test d == MvLogitNormal(MvNormal(prms...))
            test_mvlogitnormal(d; nsamples=10^4)
        end
    end
    @testset "wraps MvNormalCanon" begin
        @testset "$(typeof(prms))" for prms in mvnorm_params
            d = MvLogitNormal(MvNormalCanon(prms...))
            test_mvlogitnormal(d; nsamples=10^4)
        end
    end

    @testset "kldivergence" begin
        d1 = MvLogitNormal(randn(5), exp(Symmetric(randn(5, 5))))
        d2 = MvLogitNormal(randn(5), exp(Symmetric(randn(5, 5))))
        @test kldivergence(d1, d2) ≈ kldivergence(d1.normal, d2.normal)
    end

    # The `show` output is only compared on Julia 1.8 and newer.
    VERSION ≥ v"1.8" && @testset "show" begin
        d = MvLogitNormal([1.0, 2.0, 3.0], Diagonal([4.0, 5.0, 6.0]))
        # NOTE(review): the inner lines are indented by the default `indent` of
        # `show(::IO, ::MvLogitNormal)`, taken here to be one space; whitespace in this
        # listing may have been collapsed, so confirm the width against that method.
        @test sprint(show, d) === """
            MvLogitNormal{DiagNormal}(
             DiagNormal(
             dim: 3
             μ: [1.0, 2.0, 3.0]
             Σ: [4.0 0.0 0.0; 0.0 5.0 0.0; 0.0 0.0 6.0]
             )
            )
            """
    end
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ const tests = [
2626
"univariate/continuous/uniform",
2727
"univariate/continuous/lognormal",
2828
"multivariate/mvnormal",
29+
"multivariate/mvlogitnormal",
2930
"multivariate/mvlognormal",
3031
"types", # extra file compared to /src
3132
"utils",

0 commit comments

Comments
 (0)