You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Update variational distribution according to the update rule in the optimizer state `opt_st` and the variational family `family_type`.
45
+
46
+
This is a wrapper around `Optimisers.update!` to provide some indirection.
47
+
For example, depending on the optimizer and the variational family, this may do additional things such as applying projection or proximal mappings.
48
+
As with the default behavior of `Optimisers.update!`, `params` and `opt_st` may be updated by the routine and are no longer valid after calling this function.
49
+
Instead, the return values should be used.
50
+
51
+
# Arguments
52
+
- `family_type::Type`: Type of the variational family `typeof(restructure(params))`.
53
+
- `opt_st`: Optimizer state returned by `Optimisers.setup`.
54
+
- `params`: Current set of parameters to be updated.
55
+
- `restructure`: Callable for restructuring the variational distribution from `params`.
56
+
- `grad`: Gradient to be used by the update rule of `opt_st`.
function Distributions.logpdf(q::MvLocationScale, z::AbstractVector{<:Real})
64
74
@unpack location, scale, dist = q
65
-
sum(Base.Fix1(logpdf, dist), scale \ (z - location)) -first(logdet(scale))
75
+
sum(Base.Fix1(logpdf, dist), scale \ (z - location)) -logdet(scale)
66
76
end
67
77
68
78
function Distributions._logpdf(q::MvLocationScale, z::AbstractVector{<:Real})
69
79
@unpack location, scale, dist = q
70
-
sum(Base.Fix1(logpdf, dist), scale \ (z - location)) -first(logdet(scale))
80
+
sum(Base.Fix1(logpdf, dist), scale \ (z - location)) -logdet(scale)
71
81
end
72
82
73
83
function Distributions.rand(q::MvLocationScale)
@@ -128,14 +138,11 @@ Construct a Gaussian variational approximation with a dense covariance matrix.
128
138
functionFullRankGaussian(
129
139
μ::AbstractVector{T},
130
140
L::LinearAlgebra.AbstractTriangular{T};
131
-
check_args::Bool=true
141
+
scale_eps::T=sqrt(eps(T))
132
142
) where {T <:Real}
133
-
@assertminimum(diag(L)) >eps(eltype(L)) "Scale must be positive definite"
134
-
if check_args && (minimum(diag(L)) <sqrt(eps(eltype(L))))
135
-
@warn"Initial scale is too small (minimum eigenvalue is $(minimum(diag(L)))). This might result in unstable optimization behavior."
136
-
end
143
+
@assertminimum(diag(L)) ≥sqrt(scale_eps) "Initial scale is too small (smallest diagonal value is $(minimum(diag(L)))). This might result in unstable optimization behavior."
137
144
q_base =Normal{T}(zero(T), one(T))
138
-
MvLocationScale(μ, L, q_base)
145
+
MvLocationScale(μ, L, q_base, scale_eps)
139
146
end
140
147
141
148
"""
@@ -153,12 +160,25 @@ Construct a Gaussian variational approximation with a diagonal covariance matrix
153
160
functionMeanFieldGaussian(
154
161
μ::AbstractVector{T},
155
162
L::Diagonal{T};
156
-
check_args::Bool=true
163
+
scale_eps::T=sqrt(eps(T)),
157
164
) where {T <:Real}
158
-
@assertminimum(diag(L)) >eps(eltype(L)) "Scale must be a Cholesky factor"
159
-
if check_args && (minimum(diag(L)) <sqrt(eps(eltype(L))))
160
-
@warn"Initial scale is too small (minimum eigenvalue is $(minimum(diag(L)))). This might result in unstable optimization behavior."
161
-
end
165
+
@assertminimum(diag(L)) ≥sqrt(eps(eltype(L))) "Initial scale is too small (smallest diagonal value is $(minimum(diag(L)))). This might result in unstable optimization behavior."
162
166
q_base =Normal{T}(zero(T), one(T))
163
-
MvLocationScale(μ, L, q_base)
167
+
MvLocationScale(μ, L, q_base, scale_eps)
168
+
end
169
+
170
+
functionupdate_variational_params!(
171
+
::Type{<:MvLocationScale}, opt_st, params, restructure, grad
0 commit comments