1 | 1 |
2 | 2 | module Variational |
3 | 3 |
4 | | -using DynamicPPL |
| 4 | +using AdvancedVI: |
| 5 | + AdvancedVI, KLMinRepGradDescent, KLMinRepGradProxDescent, KLMinScoreGradDescent |
5 | 6 | using ADTypes |
| 7 | +using Bijectors: Bijectors |
6 | 8 | using Distributions |
| 9 | +using DynamicPPL |
7 | 10 | using LinearAlgebra |
8 | 11 | using LogDensityProblems |
9 | 12 | using Random |
| 13 | +using ..Turing: DEFAULT_ADTYPE, PROGRESS |
10 | 14 |
11 | | -import ..Turing: DEFAULT_ADTYPE, PROGRESS |
12 | | - |
13 | | -import AdvancedVI |
14 | | -import Bijectors |
15 | | - |
16 | | -export vi, q_locationscale, q_meanfield_gaussian, q_fullrank_gaussian |
17 | | - |
18 | | -include("deprecated.jl") |
| 15 | +export vi, |
| 16 | + q_locationscale, |
| 17 | + q_meanfield_gaussian, |
| 18 | + q_fullrank_gaussian, |
| 19 | + KLMinRepGradProxDescent, |
| 20 | + KLMinRepGradDescent, |
| 21 | + KLMinScoreGradDescent |
19 | 22 |
20 | 23 | """ |
21 | 24 | q_initialize_scale( |
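
For illustration (this example is not part of the change itself): the reorganized imports and exports above make the `KLMin*` algorithm types available to downstream code. A minimal sketch, assuming Turing re-exports this module's names as it does for `vi`, and assuming the constructor pattern used for the default later in this diff (`KLMinRepGradProxDescent(DEFAULT_ADTYPE; n_samples=10)`); `AutoForwardDiff` is just one example of an `ADTypes` backend:

```julia
# Sketch only: constructor options are assumed from the default used in `vi`.
using Turing
using ADTypes: AutoForwardDiff

# Construct the proximal reparameterization-gradient descent algorithm
# with an explicit AD backend and Monte Carlo sample count.
alg = KLMinRepGradProxDescent(AutoForwardDiff(); n_samples=10)
```
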
@@ -248,76 +251,61 @@ end |
248 | 251 | """ |
249 | 252 | vi( |
250 | 253 | [rng::Random.AbstractRNG,] |
251 | | - model::DynamicPPL.Model; |
| 254 | + model::DynamicPPL.Model, |
252 | 255 | q, |
253 | | - n_iterations::Int; |
254 | | - objective::AdvancedVI.AbstractVariationalObjective = AdvancedVI.RepGradELBO( |
255 | | - 10; entropy = AdvancedVI.ClosedFormEntropyZeroGradient() |
256 | | - ), |
| 256 | + max_iter::Int; |
| 257 | + adtype::ADTypes.AbstractADType = Turing.DEFAULT_ADTYPE,
| 258 | + algorithm::AdvancedVI.AbstractVariationalAlgorithm = KLMinRepGradProxDescent(adtype; n_samples=10),
257 | 258 | show_progress::Bool = Turing.PROGRESS[], |
258 | | - optimizer::Optimisers.AbstractRule = AdvancedVI.DoWG(), |
259 | | - averager::AdvancedVI.AbstractAverager = AdvancedVI.PolynomialAveraging(), |
260 | | - operator::AdvancedVI.AbstractOperator = AdvancedVI.ProximalLocationScaleEntropy(), |
261 | | - adtype::ADTypes.AbstractADType = Turing.DEFAULT_ADTYPE, |
262 | 259 | kwargs... |
263 | 260 | ) |
264 | 261 |
265 | | -Approximating the target `model` via variational inference by optimizing `objective` with the initialization `q`. |
| 262 | +Approximate the target `model` with the variational inference algorithm `algorithm`, starting from the initial variational approximation `q`.
266 | 263 | This is a thin wrapper around `AdvancedVI.optimize`. |
| 264 | +The default `algorithm` assumes `q` uses `AdvancedVI.MvLocationScale`, which can be constructed by invoking `q_fullrank_gaussian` or `q_meanfield_gaussian`. |
| 265 | +For other variational families, refer to the `AdvancedVI` documentation to choose a suitable algorithm and options.
267 | 266 |
268 | 267 | # Arguments |
269 | 268 | - `model`: The target `DynamicPPL.Model`. |
270 | 269 | - `q`: The initial variational approximation. |
271 | | -- `n_iterations`: Number of optimization steps. |
| 270 | +- `max_iter`: Maximum number of optimization steps.
272 | 271 |
273 | 272 | # Keyword Arguments |
274 | | -- `objective`: Variational objective to be optimized. |
| 273 | +- `algorithm`: Variational inference algorithm. |
275 | 274 | - `show_progress`: Whether to show the progress bar. |
276 | | -- `optimizer`: Optimization algorithm. |
277 | | -- `averager`: Parameter averaging strategy. |
278 | | -- `operator`: Operator applied after each optimization step. |
279 | | -- `adtype`: Automatic differentiation backend. |
| 275 | +- `adtype`: Automatic differentiation backend applied to the log-density. The default `algorithm` also uses this backend for differentiating the variational objective.
280 | 276 |
281 | 277 | See the docs of `AdvancedVI.optimize` for additional keyword arguments. |
282 | 278 |
283 | 279 | # Returns |
284 | | -- `q`: Variational distribution formed by the last iterate of the optimization run. |
285 | | -- `q_avg`: Variational distribution formed by the averaged iterates according to `averager`. |
286 | | -- `state`: Collection of states used for optimization. This can be used to resume from a past call to `vi`. |
287 | | -- `info`: Information generated during the optimization run. |
| 280 | +- `q`: Output variational distribution of `algorithm`. |
| 281 | +- `state`: Collection of states used by `algorithm`. This can be used to resume from a past call to `vi`. |
| 282 | +- `info`: Information generated while executing `algorithm`. |
288 | 283 | """ |
289 | 284 | function vi( |
290 | 285 | rng::Random.AbstractRNG, |
291 | 286 | model::DynamicPPL.Model, |
292 | 287 | q, |
293 | | - n_iterations::Int; |
294 | | - objective=AdvancedVI.RepGradELBO( |
295 | | - 10; entropy=AdvancedVI.ClosedFormEntropyZeroGradient() |
296 | | - ), |
297 | | - show_progress::Bool=PROGRESS[], |
298 | | - optimizer=AdvancedVI.DoWG(), |
299 | | - averager=AdvancedVI.PolynomialAveraging(), |
300 | | - operator=AdvancedVI.ProximalLocationScaleEntropy(), |
| 288 | + max_iter::Int, |
| 289 | + args...; |
301 | 290 | adtype::ADTypes.AbstractADType=DEFAULT_ADTYPE, |
| 291 | + algorithm=KLMinRepGradProxDescent(adtype; n_samples=10), |
| 292 | + show_progress::Bool=PROGRESS[], |
302 | 293 | kwargs..., |
303 | 294 | ) |
304 | 295 | return AdvancedVI.optimize( |
305 | 296 | rng, |
306 | | - LogDensityFunction(model), |
307 | | - objective, |
| 297 | + algorithm, |
| 298 | + max_iter, |
| 299 | + LogDensityFunction(model; adtype), |
308 | 300 | q, |
309 | | - n_iterations; |
| 301 | + args...; |
310 | 302 | show_progress=show_progress, |
311 | | - adtype, |
312 | | - optimizer, |
313 | | - averager, |
314 | | - operator, |
315 | 303 | kwargs..., |
316 | 304 | ) |
317 | 305 | end |
318 | 306 |
319 | | -function vi(model::DynamicPPL.Model, q, n_iterations::Int; kwargs...) |
320 | | - return vi(Random.default_rng(), model, q, n_iterations; kwargs...) |
| 307 | +function vi(model::DynamicPPL.Model, q, max_iter::Int; kwargs...) |
| 308 | + return vi(Random.default_rng(), model, q, max_iter; kwargs...) |
321 | 309 | end |
322 | 310 |
323 | 311 | end |
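
For illustration (again, not part of the diff): a minimal end-to-end sketch of the updated `vi` API. The toy model and step count are illustrative; the three return values follow the updated docstring above, and `q_meanfield_gaussian` is one of the existing family constructors exported by this module:

```julia
using Turing

# Toy conjugate model, for illustration only.
@model function demo(x)
    m ~ Normal(0, 1)
    x ~ Normal(m, 1)
end

model = demo(1.5)

# Mean-field Gaussian initialization; a location-scale family, so it is
# compatible with the default KLMinRepGradProxDescent algorithm.
q0 = q_meanfield_gaussian(model)

# Run VI for at most 1_000 steps; returns the fitted approximation,
# the algorithm state (usable to resume), and per-iteration info.
q, state, info = vi(model, q0, 1_000)
```
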