Skip to content

Commit ea69430

Browse files
committed
update vi interface to match [email protected]
1 parent cabe73f commit ea69430

File tree

7 files changed

+106
-192
lines changed

7 files changed

+106
-192
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ Accessors = "0.1"
5555
AdvancedHMC = "0.3.0, 0.4.0, 0.5.2, 0.6, 0.7, 0.8"
5656
AdvancedMH = "0.8"
5757
AdvancedPS = "0.7"
58-
AdvancedVI = "0.4"
58+
AdvancedVI = "0.5"
5959
BangBang = "0.4.2"
6060
Bijectors = "0.14, 0.15"
6161
Compat = "4.15.0"

src/Turing.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,9 @@ export
117117
q_locationscale,
118118
q_meanfield_gaussian,
119119
q_fullrank_gaussian,
120+
KLMinRepGradProxDescent,
121+
KLMinRepGradDescent,
122+
KLMinScoreGradDescent,
120123
# ADTypes
121124
AutoForwardDiff,
122125
AutoReverseDiff,

src/variational/VariationalInference.jl

Lines changed: 34 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,24 @@
11

22
module Variational
33

4-
using DynamicPPL
4+
using AdvancedVI:
5+
AdvancedVI, KLMinRepGradDescent, KLMinRepGradProxDescent, KLMinScoreGradDescent
56
using ADTypes
7+
using Bijectors: Bijectors
68
using Distributions
9+
using DynamicPPL
710
using LinearAlgebra
811
using LogDensityProblems
912
using Random
13+
using ..Turing: DEFAULT_ADTYPE, PROGRESS
1014

11-
import ..Turing: DEFAULT_ADTYPE, PROGRESS
12-
13-
import AdvancedVI
14-
import Bijectors
15-
16-
export vi, q_locationscale, q_meanfield_gaussian, q_fullrank_gaussian
17-
18-
include("deprecated.jl")
15+
export vi,
16+
q_locationscale,
17+
q_meanfield_gaussian,
18+
q_fullrank_gaussian,
19+
KLMinRepGradProxDescent,
20+
KLMinRepGradDescent,
21+
KLMinScoreGradDescent
1922

2023
"""
2124
q_initialize_scale(
@@ -248,76 +251,61 @@ end
248251
"""
249252
vi(
250253
[rng::Random.AbstractRNG,]
251-
model::DynamicPPL.Model;
254+
model::DynamicPPL.Model,
252255
q,
253-
n_iterations::Int;
254-
objective::AdvancedVI.AbstractVariationalObjective = AdvancedVI.RepGradELBO(
255-
10; entropy = AdvancedVI.ClosedFormEntropyZeroGradient()
256-
),
256+
max_iter::Int;
257+
algorithm::AdvancedVI.AbstractVariationalAlgorithm = KLMinRepGradProxDescent(DEFAULT_ADTYPE; n_samples=10),
257258
show_progress::Bool = Turing.PROGRESS[],
258-
optimizer::Optimisers.AbstractRule = AdvancedVI.DoWG(),
259-
averager::AdvancedVI.AbstractAverager = AdvancedVI.PolynomialAveraging(),
260-
operator::AdvancedVI.AbstractOperator = AdvancedVI.ProximalLocationScaleEntropy(),
261-
adtype::ADTypes.AbstractADType = Turing.DEFAULT_ADTYPE,
262259
kwargs...
263260
)
264261
265-
Approximating the target `model` via variational inference by optimizing `objective` with the initialization `q`.
262+
Approximate the target `model` via the variational inference algorithm `algorithm` by starting from the initial variational approximation `q`.
266263
This is a thin wrapper around `AdvancedVI.optimize`.
264+
The default `algorithm` assumes `q` uses `AdvancedVI.MvLocationScale`, which can be constructed by invoking `q_fullrank_gaussian` or `q_meanfield_gaussian`.
265+
For other variational families, refer to `AdvancedVI` to determine the best algorithm and options.
267266
268267
# Arguments
269268
- `model`: The target `DynamicPPL.Model`.
270269
- `q`: The initial variational approximation.
271-
- `n_iterations`: Number of optimization steps.
270+
- `max_iter`: Maximum number of steps.
272271
273272
# Keyword Arguments
274-
- `objective`: Variational objective to be optimized.
273+
- `algorithm`: Variational inference algorithm.
275274
- `show_progress`: Whether to show the progress bar.
276-
- `optimizer`: Optimization algorithm.
277-
- `averager`: Parameter averaging strategy.
278-
- `operator`: Operator applied after each optimization step.
279-
- `adtype`: Automatic differentiation backend.
275+
- `adtype`: Automatic differentiation backend to be applied to the log-density. The default value for `algorithm` also uses this backend for differentiation the variational objective.
280276
281277
See the docs of `AdvancedVI.optimize` for additional keyword arguments.
282278
283279
# Returns
284-
- `q`: Variational distribution formed by the last iterate of the optimization run.
285-
- `q_avg`: Variational distribution formed by the averaged iterates according to `averager`.
286-
- `state`: Collection of states used for optimization. This can be used to resume from a past call to `vi`.
287-
- `info`: Information generated during the optimization run.
280+
- `q`: Output variational distribution of `algorithm`.
281+
- `state`: Collection of states used by `algorithm`. This can be used to resume from a past call to `vi`.
282+
- `info`: Information generated while executing `algorithm`.
288283
"""
289284
function vi(
290285
rng::Random.AbstractRNG,
291286
model::DynamicPPL.Model,
292287
q,
293-
n_iterations::Int;
294-
objective=AdvancedVI.RepGradELBO(
295-
10; entropy=AdvancedVI.ClosedFormEntropyZeroGradient()
296-
),
297-
show_progress::Bool=PROGRESS[],
298-
optimizer=AdvancedVI.DoWG(),
299-
averager=AdvancedVI.PolynomialAveraging(),
300-
operator=AdvancedVI.ProximalLocationScaleEntropy(),
288+
max_iter::Int,
289+
args...;
301290
adtype::ADTypes.AbstractADType=DEFAULT_ADTYPE,
291+
algorithm=KLMinRepGradProxDescent(adtype; n_samples=10),
292+
show_progress::Bool=PROGRESS[],
302293
kwargs...,
303294
)
304295
return AdvancedVI.optimize(
305296
rng,
306-
LogDensityFunction(model),
307-
objective,
297+
algorithm,
298+
max_iter,
299+
LogDensityFunction(model; adtype),
308300
q,
309-
n_iterations;
301+
args...;
310302
show_progress=show_progress,
311-
adtype,
312-
optimizer,
313-
averager,
314-
operator,
315303
kwargs...,
316304
)
317305
end
318306

319-
function vi(model::DynamicPPL.Model, q, n_iterations::Int; kwargs...)
320-
return vi(Random.default_rng(), model, q, n_iterations; kwargs...)
307+
function vi(model::DynamicPPL.Model, q, max_iter::Int; kwargs...)
308+
return vi(Random.default_rng(), model, q, max_iter; kwargs...)
321309
end
322310

323311
end

src/variational/deprecated.jl

Lines changed: 0 additions & 61 deletions
This file was deleted.

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ AbstractMCMC = "5"
4444
AbstractPPL = "0.11, 0.12, 0.13"
4545
AdvancedMH = "0.6, 0.7, 0.8"
4646
AdvancedPS = "0.7"
47-
AdvancedVI = "0.4"
47+
AdvancedVI = "0.5"
4848
Aqua = "0.8"
4949
BangBang = "0.4"
5050
Bijectors = "0.14, 0.15"

test/runtests.jl

Lines changed: 41 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ include("test_utils/models.jl")
1313
include("test_utils/numerical_tests.jl")
1414
include("test_utils/sampler.jl")
1515

16-
Turing.setprogress!(false)
16+
#Turing.setprogress!(false)
1717
included_paths, excluded_paths = parse_args(ARGS)
1818

1919
# Filter which tests to run and collect timing and allocations information to show in a
@@ -30,55 +30,59 @@ macro timeit_include(path::AbstractString)
3030
end
3131

3232
@testset "Turing" verbose = true begin
33-
@testset "Aqua" begin
34-
@timeit_include("Aqua.jl")
35-
end
33+
# @testset "Aqua" begin
34+
# @timeit_include("Aqua.jl")
35+
# end
3636

37-
@testset "AD" verbose = true begin
38-
@timeit_include("ad.jl")
39-
end
37+
# @testset "AD" verbose = true begin
38+
# @timeit_include("ad.jl")
39+
# end
4040

41-
@testset "essential" verbose = true begin
42-
@timeit_include("essential/container.jl")
43-
end
41+
# @testset "essential" verbose = true begin
42+
# @timeit_include("essential/container.jl")
43+
# end
4444

45-
@testset "samplers (without AD)" verbose = true begin
46-
@timeit_include("mcmc/particle_mcmc.jl")
47-
@timeit_include("mcmc/emcee.jl")
48-
@timeit_include("mcmc/ess.jl")
49-
@timeit_include("mcmc/is.jl")
50-
end
45+
# @testset "samplers (without AD)" verbose = true begin
46+
# @timeit_include("mcmc/particle_mcmc.jl")
47+
# @timeit_include("mcmc/emcee.jl")
48+
# @timeit_include("mcmc/ess.jl")
49+
# @timeit_include("mcmc/is.jl")
50+
# end
5151

5252
@timeit TIMEROUTPUT "inference" begin
53-
@testset "inference with samplers" verbose = true begin
54-
@timeit_include("mcmc/gibbs.jl")
55-
@timeit_include("mcmc/hmc.jl")
56-
@timeit_include("mcmc/Inference.jl")
57-
@timeit_include("mcmc/sghmc.jl")
58-
@timeit_include("mcmc/external_sampler.jl")
59-
@timeit_include("mcmc/mh.jl")
60-
@timeit_include("ext/dynamichmc.jl")
61-
@timeit_include("mcmc/repeat_sampler.jl")
62-
end
53+
# @testset "inference with samplers" verbose = true begin
54+
# @timeit_include("mcmc/gibbs.jl")
55+
# @timeit_include("mcmc/hmc.jl")
56+
# @timeit_include("mcmc/Inference.jl")
57+
# @timeit_include("mcmc/sghmc.jl")
58+
# @timeit_include("mcmc/external_sampler.jl")
59+
# @timeit_include("mcmc/mh.jl")
60+
# @timeit_include("ext/dynamichmc.jl")
61+
# @timeit_include("mcmc/repeat_sampler.jl")
62+
# end
6363

6464
@testset "variational algorithms" begin
6565
@timeit_include("variational/advi.jl")
6666
end
6767

68-
@testset "mode estimation" verbose = true begin
69-
@timeit_include("optimisation/Optimisation.jl")
70-
@timeit_include("ext/OptimInterface.jl")
71-
end
68+
# @testset "mode estimation" verbose = true begin
69+
# @timeit_include("optimisation/Optimisation.jl")
70+
# @timeit_include("ext/OptimInterface.jl")
71+
# end
7272
end
7373

74-
@testset "stdlib" verbose = true begin
75-
@timeit_include("stdlib/distributions.jl")
76-
@timeit_include("stdlib/RandomMeasures.jl")
77-
end
74+
# @testset "variational optimisers" begin
75+
# @timeit_include("variational/optimisers.jl")
76+
# end
7877

79-
@testset "utilities" begin
80-
@timeit_include("mcmc/utilities.jl")
81-
end
78+
# @testset "stdlib" verbose = true begin
79+
# @timeit_include("stdlib/distributions.jl")
80+
# @timeit_include("stdlib/RandomMeasures.jl")
81+
# end
82+
83+
# @testset "utilities" begin
84+
# @timeit_include("mcmc/utilities.jl")
85+
# end
8286
end
8387

8488
show(TIMEROUTPUT; compact=true, sortby=:firstexec)

0 commit comments

Comments
 (0)