Commit 6518b82

add exports for new algorithms, modify `vi` to operate in unconstrained space
1 parent 4c02f7b

File tree

src/Turing.jl
src/variational/VariationalInference.jl
test/variational/vi.jl

3 files changed: +63 additions, -15 deletions

src/Turing.jl

Lines changed: 3 additions & 0 deletions

@@ -122,6 +122,9 @@ export
     KLMinRepGradProxDescent,
     KLMinRepGradDescent,
     KLMinScoreGradDescent,
+    KLMinNaturalGradDescent,
+    KLMinSqrtNaturalGradDescent,
+    KLMinWassFwdBwd,
     # ADTypes
     AutoForwardDiff,
     AutoReverseDiff,
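For context, the constructors for the three newly exported algorithms, using the same keyword arguments as the tests added in this commit (the `stepsize` values are illustrative, not recommendations):

```julia
using Turing

# Constructors mirror those exercised in test/variational/vi.jl below.
nat      = KLMinNaturalGradDescent(stepsize=1e-3, n_samples=10)
sqrt_nat = KLMinSqrtNaturalGradDescent(stepsize=1e-3, n_samples=10)
wass     = KLMinWassFwdBwd(stepsize=1e-3, n_samples=10)
```

Any of these can be passed to `vi` via its `algorithm` keyword.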

src/variational/VariationalInference.jl

Lines changed: 48 additions & 15 deletions

@@ -2,13 +2,20 @@
 module Variational
 
 using AdvancedVI:
-    AdvancedVI, KLMinRepGradDescent, KLMinRepGradProxDescent, KLMinScoreGradDescent
+    AdvancedVI,
+    KLMinRepGradDescent,
+    KLMinRepGradProxDescent,
+    KLMinScoreGradDescent,
+    KLMinWassFwdBwd,
+    KLMinNaturalGradDescent,
+    KLMinSqrtNaturalGradDescent
+
 using ADTypes
 using Bijectors: Bijectors
 using Distributions
-using DynamicPPL
+using DynamicPPL: DynamicPPL
 using LinearAlgebra
-using LogDensityProblems
+using LogDensityProblems: LogDensityProblems
 using Random
 using ..Turing: DEFAULT_ADTYPE, PROGRESS

@@ -18,7 +25,17 @@ export vi,
     q_fullrank_gaussian,
     KLMinRepGradProxDescent,
     KLMinRepGradDescent,
-    KLMinScoreGradDescent
+    KLMinScoreGradDescent,
+    KLMinWassFwdBwd,
+    KLMinNaturalGradDescent,
+    KLMinSqrtNaturalGradDescent
+
+requires_unconstrained_space(::AdvancedVI.AbstractVariationalAlgorithm) = false
+requires_unconstrained_space(::AdvancedVI.KLMinRepGradProxDescent) = true
+requires_unconstrained_space(::AdvancedVI.KLMinRepGradDescent) = true
+requires_unconstrained_space(::AdvancedVI.KLMinWassFwdBwd) = true
+requires_unconstrained_space(::AdvancedVI.KLMinNaturalGradDescent) = true
+requires_unconstrained_space(::AdvancedVI.KLMinSqrtNaturalGradDescent) = true
 
 """
     q_initialize_scale(
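The `requires_unconstrained_space` trait introduced above defaults to `false` and is overridden per algorithm. Presumably a downstream algorithm could opt into the same unconstrained code path; a hypothetical sketch, where `MyAlgorithm` is invented purely for illustration:

```julia
using Turing, AdvancedVI

# Hypothetical algorithm type; not part of this commit.
struct MyAlgorithm <: AdvancedVI.AbstractVariationalAlgorithm end

# Opt into the unconstrained-space handling in `vi`:
Turing.Variational.requires_unconstrained_space(::MyAlgorithm) = true
```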
@@ -65,7 +82,7 @@ function q_initialize_scale(
     num_max_trials::Int=10,
     reduce_factor::Real=one(eltype(scale)) / 2,
 )
-    prob = LogDensityFunction(model)
+    prob = DynamicPPL.LogDensityFunction(model)
     ℓπ = Base.Fix1(LogDensityProblems.logdensity, prob)
     varinfo = DynamicPPL.VarInfo(model)
 
@@ -264,8 +281,12 @@ end
 
 Approximate the target `model` via the variational inference algorithm `algorithm` by starting from the initial variational approximation `q`.
 This is a thin wrapper around `AdvancedVI.optimize`.
+
+If the chosen variational inference algorithm operates in an unconstrained space, then the provided initial variational approximation `q` must be a `Bijectors.TransformedDistribution` of an unconstrained distribution.
+For example, the initializations supplied by `q_meanfield_gaussian`, `q_fullrank_gaussian`, and `q_locationscale` satisfy this.
+
 The default `algorithm`, `KLMinRepGradProxDescent` ([relevant docs](https://turinglang.org/AdvancedVI.jl/dev/klminrepgradproxdescent/)), assumes `q` uses `AdvancedVI.MvLocationScale`, which can be constructed by invoking `q_fullrank_gaussian` or `q_meanfield_gaussian`.
-For other variational families, refer to `AdvancedVI` to determine the best algorithm and options.
+For other variational families, refer to the documentation of `AdvancedVI` to determine the best algorithm and other options.
 
 # Arguments
 - `model`: The target `DynamicPPL.Model`.
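To make the requirement above concrete, a hedged end-to-end sketch: `q_meanfield_gaussian` returns a `Bijectors.TransformedDistribution` over an unconstrained base, so it can be passed to any algorithm that requires unconstrained space. The model, data, and iteration count here are illustrative:

```julia
using Turing

@model function demo(x)
    σ ~ truncated(Normal(); lower=0)  # constrained (positive) parameter
    μ ~ Normal(0, 1)
    x .~ Normal(μ, σ)
end

model = demo(randn(20) .+ 1)

# `q_meanfield_gaussian` wraps an unconstrained Gaussian in the model's
# inverse bijector, satisfying the requirement stated in the docstring.
q0 = q_meanfield_gaussian(model)
q, info, state = vi(model, q0, 1000; algorithm=KLMinWassFwdBwd(stepsize=1e-3, n_samples=10))
```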
@@ -294,19 +315,31 @@ function vi(
     algorithm::AdvancedVI.AbstractVariationalAlgorithm=KLMinRepGradProxDescent(
         adtype; n_samples=10
     ),
+    unconstrained::Bool=requires_unconstrained_space(algorithm),
     show_progress::Bool=PROGRESS[],
     kwargs...,
 )
-    return AdvancedVI.optimize(
-        rng,
-        algorithm,
-        max_iter,
-        LogDensityFunction(model; adtype),
-        q,
-        args...;
-        show_progress=show_progress,
-        kwargs...,
+    prob, q, trans = if unconstrained
+        @assert q isa Bijectors.TransformedDistribution "The algorithm $(algorithm) operates in an unconstrained space. Therefore, the initial variational approximation is expected to be a Bijectors.TransformedDistribution of an unconstrained distribution."
+        vi = DynamicPPL.ldf_default_varinfo(model, DynamicPPL.getlogjoint_internal)
+        vi = DynamicPPL.set_transformed!!(vi, true)
+        prob = DynamicPPL.LogDensityFunction(
+            model, DynamicPPL.getlogjoint_internal, vi; adtype
+        )
+        prob, q.dist, q.transform
+    else
+        prob = DynamicPPL.LogDensityFunction(model; adtype)
+        prob, q, nothing
+    end
+    q, info, state = AdvancedVI.optimize(
+        rng, algorithm, max_iter, prob, q, args...; show_progress=show_progress, kwargs...
     )
+    q = if unconstrained
+        Bijectors.TransformedDistribution(q, trans)
+    else
+        q
+    end
+    q, info, state
 end
 
 function vi(model::DynamicPPL.Model, q, max_iter::Int; kwargs...)
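The branch above splits the initial `Bijectors.TransformedDistribution` into its `dist` and `transform` fields, optimizes the unconstrained base, then re-wraps the result. A standalone sketch of that unwrap/re-wrap pattern, with a toy positive-support example standing in for a model:

```julia
using Bijectors, Distributions

# Unconstrained Gaussian mapped onto the positive reals, mirroring what
# q_meanfield_gaussian produces for a model with constrained parameters.
b = Bijectors.bijector(LogNormal())               # constrained -> unconstrained
q = Bijectors.transformed(Normal(0.0, 1.0), Bijectors.inverse(b))

base, trans = q.dist, q.transform                 # unwrap: optimizer works on ℝ
# ... optimize `base` in unconstrained space ...
q_new = Bijectors.TransformedDistribution(base, trans)  # re-wrap for sampling
rand(q_new) > 0                                   # samples lie on the support
```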

test/variational/vi.jl

Lines changed: 12 additions & 0 deletions

@@ -53,6 +53,12 @@ using Turing.Variational
     @testset "custom algorithm $name" for (name, algorithm) in [
         ("KLMinRepGradProxDescent", KLMinRepGradProxDescent(adtype; n_samples=10)),
         ("KLMinRepGradDescent", KLMinRepGradDescent(adtype; operator, n_samples=10)),
+        ("KLMinNaturalGradDescent", KLMinNaturalGradDescent(stepsize=1e-3, n_samples=10)),
+        (
+            "KLMinSqrtNaturalGradDescent",
+            KLMinSqrtNaturalGradDescent(stepsize=1e-3, n_samples=10),
+        ),
+        ("KLMinWassFwdBwd", KLMinWassFwdBwd(stepsize=1e-3, n_samples=10)),
     ]
         T = 1000
         q, _, _ = vi(

@@ -70,6 +76,12 @@ using Turing.Variational
     @testset "inference $name" for (name, algorithm) in [
         ("KLMinRepGradProxDescent", KLMinRepGradProxDescent(adtype; n_samples=10)),
         ("KLMinRepGradDescent", KLMinRepGradDescent(adtype; operator, n_samples=10)),
+        ("KLMinNaturalGradDescent", KLMinNaturalGradDescent(stepsize=1e-3, n_samples=10)),
+        (
+            "KLMinSqrtNaturalGradDescent",
+            KLMinSqrtNaturalGradDescent(stepsize=1e-3, n_samples=10),
+        ),
+        ("KLMinWassFwdBwd", KLMinWassFwdBwd(stepsize=1e-3, n_samples=10)),
     ]
         rng = StableRNG(0x517e1d9bf89bf94f)
 