Skip to content

Commit d870045

Browse files
committed
update docs for vi
1 parent 3e30e04 commit d870045

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

src/variational/VariationalInference.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -254,14 +254,17 @@ end
254254
model::DynamicPPL.Model,
255255
q,
256256
max_iter::Int;
257-
algorithm::AdvancedVI.AbstractVariationalAlgorithm = KLMinRepGradProxDescent(DEFAULT_ADTYPE; n_samples=10),
257+
adtype::ADTypes.AbstractADType=DEFAULT_ADTYPE,
258+
algorithm::AdvancedVI.AbstractVariationalAlgorithm = KLMinRepGradProxDescent(
259+
adtype; n_samples=10
260+
),
258261
show_progress::Bool = Turing.PROGRESS[],
259262
kwargs...
260263
)
261264
262265
Approximate the target `model` via the variational inference algorithm `algorithm` by starting from the initial variational approximation `q`.
263266
This is a thin wrapper around `AdvancedVI.optimize`.
264-
The default `algorithm` assumes `q` uses `AdvancedVI.MvLocationScale`, which can be constructed by invoking `q_fullrank_gaussian` or `q_meanfield_gaussian`.
267+
The default `algorithm`, `KLMinRepGradProxDescent` ([relevant docs](https://turinglang.org/AdvancedVI.jl/dev/klminrepgradproxdescent/)), assumes `q` uses `AdvancedVI.MvLocationScale`, which can be constructed by invoking `q_fullrank_gaussian` or `q_meanfield_gaussian`.
265268
For other variational families, refer to `AdvancedVI` to determine the best algorithm and options.
266269
267270
# Arguments
@@ -270,9 +273,9 @@ For other variational families, refer to `AdvancedVI` to determine the best algo
270273
- `max_iter`: Maximum number of steps.
271274
272275
# Keyword Arguments
276+
- `adtype`: Automatic differentiation backend to be applied to the log-density. The default value for `algorithm` also uses this backend for differentiation the variational objective.
273277
- `algorithm`: Variational inference algorithm.
274278
- `show_progress`: Whether to show the progress bar.
275-
- `adtype`: Automatic differentiation backend to be applied to the log-density. The default value for `algorithm` also uses this backend for differentiation the variational objective.
276279
277280
See the docs of `AdvancedVI.optimize` for additional keyword arguments.
278281
@@ -288,7 +291,9 @@ function vi(
288291
max_iter::Int,
289292
args...;
290293
adtype::ADTypes.AbstractADType=DEFAULT_ADTYPE,
291-
algorithm=KLMinRepGradProxDescent(adtype; n_samples=10),
294+
algorithm::AdvancedVI.AbstractVariationalAlgorithm=KLMinRepGradProxDescent(
295+
adtype; n_samples=10
296+
),
292297
show_progress::Bool=PROGRESS[],
293298
kwargs...,
294299
)

0 commit comments

Comments
 (0)