| 
 | 1 | + | 
 | 2 | +"""  | 
 | 3 | +    MvLocationScale(location, scale, dist) <: ContinuousMultivariateDistribution  | 
 | 4 | +
  | 
 | 5 | +The location scale variational family broadly represents various variational  | 
 | 6 | +families using `location` and `scale` variational parameters.  | 
 | 7 | +
  | 
 | 8 | +It generally represents any distribution for which the sampling path can be  | 
 | 9 | +represented as follows:  | 
 | 10 | +```julia  | 
 | 11 | +  d = length(location)  | 
 | 12 | +  u = rand(dist, d)  | 
 | 13 | +  z = scale*u + location  | 
 | 14 | +```  | 
 | 15 | +"""  | 
 | 16 | +struct MvLocationScale{  | 
 | 17 | +    S, D <: ContinuousDistribution, L  | 
 | 18 | +} <: ContinuousMultivariateDistribution  | 
 | 19 | +    location::L  | 
 | 20 | +    scale   ::S  | 
 | 21 | +    dist    ::D  | 
 | 22 | +end  | 
 | 23 | + | 
 | 24 | +Functors.@functor MvLocationScale (location, scale)  | 
 | 25 | + | 
 | 26 | +# Specialization of `Optimisers.destructure` for mean-field location-scale families.  | 
 | 27 | +# These are necessary because we only want to extract the diagonal elements of   | 
 | 28 | +# `scale <: Diagonal`, which is not the default behavior. Otherwise, forward-mode AD  | 
 | 29 | +# is very inefficient.  | 
 | 30 | +# begin  | 
 | 31 | +struct RestructureMeanField{S <: Diagonal, D, L}  | 
 | 32 | +    q::MvLocationScale{S, D, L}  | 
 | 33 | +end  | 
 | 34 | + | 
 | 35 | +function (re::RestructureMeanField)(flat::AbstractVector)  | 
 | 36 | +    n_dims   = div(length(flat), 2)  | 
 | 37 | +    location = first(flat, n_dims)  | 
 | 38 | +    scale    = Diagonal(last(flat, n_dims))  | 
 | 39 | +    MvLocationScale(location, scale, re.q.dist)  | 
 | 40 | +end  | 
 | 41 | + | 
 | 42 | +function Optimisers.destructure(  | 
 | 43 | +    q::MvLocationScale{<:Diagonal, D, L}  | 
 | 44 | +) where {D, L}  | 
 | 45 | +    @unpack location, scale, dist = q  | 
 | 46 | +    flat   = vcat(location, diag(scale))  | 
 | 47 | +    flat, RestructureMeanField(q)  | 
 | 48 | +end  | 
 | 49 | +# end  | 
 | 50 | + | 
 | 51 | +Base.length(q::MvLocationScale) = length(q.location)  | 
 | 52 | + | 
 | 53 | +Base.size(q::MvLocationScale) = size(q.location)  | 
 | 54 | + | 
 | 55 | +Base.eltype(::Type{<:MvLocationScale{S, D, L}}) where {S, D, L} = eltype(D)  | 
 | 56 | + | 
 | 57 | +function StatsBase.entropy(q::MvLocationScale)  | 
 | 58 | +    @unpack  location, scale, dist = q  | 
 | 59 | +    n_dims = length(location)  | 
 | 60 | +    n_dims*convert(eltype(location), entropy(dist)) + first(logabsdet(scale))  | 
 | 61 | +end  | 
 | 62 | + | 
 | 63 | +function Distributions.logpdf(q::MvLocationScale, z::AbstractVector{<:Real})  | 
 | 64 | +    @unpack location, scale, dist = q  | 
 | 65 | +    sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - first(logabsdet(scale))  | 
 | 66 | +end  | 
 | 67 | + | 
 | 68 | +function Distributions._logpdf(q::MvLocationScale, z::AbstractVector{<:Real})  | 
 | 69 | +    @unpack location, scale, dist = q  | 
 | 70 | +    sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - first(logabsdet(scale))  | 
 | 71 | +end  | 
 | 72 | + | 
 | 73 | +function Distributions.rand(q::MvLocationScale)  | 
 | 74 | +    @unpack location, scale, dist = q  | 
 | 75 | +    n_dims = length(location)  | 
 | 76 | +    scale*rand(dist, n_dims) + location  | 
 | 77 | +end  | 
 | 78 | + | 
 | 79 | +function Distributions.rand(  | 
 | 80 | +    rng::AbstractRNG, q::MvLocationScale{S, D, L}, num_samples::Int  | 
 | 81 | +)  where {S, D, L}  | 
 | 82 | +    @unpack location, scale, dist = q  | 
 | 83 | +    n_dims = length(location)  | 
 | 84 | +    scale*rand(rng, dist, n_dims, num_samples) .+ location  | 
 | 85 | +end  | 
 | 86 | + | 
 | 87 | +# This specialization improves AD performance of the sampling path  | 
 | 88 | +function Distributions.rand(  | 
 | 89 | +    rng::AbstractRNG, q::MvLocationScale{<:Diagonal, D, L}, num_samples::Int  | 
 | 90 | +) where {L, D}  | 
 | 91 | +    @unpack location, scale, dist = q  | 
 | 92 | +    n_dims     = length(location)  | 
 | 93 | +    scale_diag = diag(scale)  | 
 | 94 | +    scale_diag.*rand(rng, dist, n_dims, num_samples) .+ location  | 
 | 95 | +end  | 
 | 96 | + | 
 | 97 | +function Distributions._rand!(rng::AbstractRNG, q::MvLocationScale, x::AbstractVecOrMat{<:Real})  | 
 | 98 | +    @unpack location, scale, dist = q  | 
 | 99 | +    rand!(rng, dist, x)  | 
 | 100 | +    x[:] = scale*x  | 
 | 101 | +    return x .+= location  | 
 | 102 | +end  | 
 | 103 | + | 
 | 104 | +Distributions.mean(q::MvLocationScale) = q.location  | 
 | 105 | + | 
 | 106 | +function Distributions.var(q::MvLocationScale)    | 
 | 107 | +    C = q.scale  | 
 | 108 | +    Diagonal(C*C')  | 
 | 109 | +end  | 
 | 110 | + | 
 | 111 | +function Distributions.cov(q::MvLocationScale)  | 
 | 112 | +    C = q.scale  | 
 | 113 | +    Hermitian(C*C')  | 
 | 114 | +end  | 
 | 115 | + | 
 | 116 | +"""  | 
 | 117 | +    FullRankGaussian(location, scale; check_args = true)  | 
 | 118 | +
  | 
 | 119 | +Construct a Gaussian variational approximation with a dense covariance matrix.  | 
 | 120 | +
  | 
 | 121 | +# Arguments  | 
 | 122 | +- `location::AbstractVector{T}`: Mean of the Gaussian.  | 
 | 123 | +- `scale::LinearAlgebra.AbstractTriangular{T}`: Cholesky factor of the covariance of the Gaussian.  | 
 | 124 | +
  | 
 | 125 | +# Keyword Arguments  | 
 | 126 | +- `check_args`: Check the conditioning of the initial scale (default: `true`).  | 
 | 127 | +"""  | 
 | 128 | +function FullRankGaussian(  | 
 | 129 | +    μ::AbstractVector{T},  | 
 | 130 | +    L::LinearAlgebra.AbstractTriangular{T};  | 
 | 131 | +    check_args::Bool = true  | 
 | 132 | +) where {T <: Real}  | 
 | 133 | +    @assert minimum(diag(L)) > eps(eltype(L)) "Scale must be positive definite"  | 
 | 134 | +    if check_args && (minimum(diag(L)) < sqrt(eps(eltype(L))))  | 
 | 135 | +        @warn "Initial scale is too small (minimum eigenvalue is $(minimum(diag(L)))). This might result in unstable optimization behavior."  | 
 | 136 | +    end  | 
 | 137 | +    q_base = Normal{T}(zero(T), one(T))  | 
 | 138 | +    MvLocationScale(μ, L, q_base)  | 
 | 139 | +end  | 
 | 140 | + | 
 | 141 | +"""  | 
 | 142 | +    MeanFieldGaussian(location, scale; check_args = true)  | 
 | 143 | +
  | 
 | 144 | +Construct a Gaussian variational approximation with a diagonal covariance matrix.  | 
 | 145 | +
  | 
 | 146 | +# Arguments  | 
 | 147 | +- `location::AbstractVector{T}`: Mean of the Gaussian.  | 
 | 148 | +- `scale::Diagonal{T}`: Diagonal Cholesky factor of the covariance of the Gaussian.  | 
 | 149 | +
  | 
 | 150 | +# Keyword Arguments  | 
 | 151 | +- `check_args`: Check the conditioning of the initial scale (default: `true`).  | 
 | 152 | +"""  | 
 | 153 | +function MeanFieldGaussian(  | 
 | 154 | +    μ::AbstractVector{T},  | 
 | 155 | +    L::Diagonal{T};  | 
 | 156 | +    check_args::Bool = true  | 
 | 157 | +) where {T <: Real}  | 
 | 158 | +    @assert minimum(diag(L)) > eps(eltype(L)) "Scale must be a Cholesky factor"  | 
 | 159 | +    if check_args && (minimum(diag(L)) < sqrt(eps(eltype(L))))  | 
 | 160 | +        @warn "Initial scale is too small (minimum eigenvalue is $(minimum(diag(L)))). This might result in unstable optimization behavior."  | 
 | 161 | +    end  | 
 | 162 | +    q_base = Normal{T}(zero(T), one(T))  | 
 | 163 | +    MvLocationScale(μ, L, q_base)  | 
 | 164 | +end  | 
0 commit comments