Skip to content

Commit 14973ad

Browse files
mhauruAoife
authored andcommitted
Fix VI tutorial
1 parent df1f9cd commit 14973ad

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

tutorials/variational-inference/index.qmd

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -182,8 +182,8 @@ Usually, `q_avg` will perform better than the last-iterate `q_last`.
182182
For instance, we can compare the ELBO of the two:
183183
```{julia}
184184
@info("Objective of q_avg and q_last",
185-
ELBO_q_avg = estimate_objective(AdvancedVI.RepGradELBO(32), q_avg, Turing.Variational.make_logdensity(m)),
186-
ELBO_q_last = estimate_objective(AdvancedVI.RepGradELBO(32), q_last, Turing.Variational.make_logdensity(m))
185+
ELBO_q_avg = estimate_objective(AdvancedVI.RepGradELBO(32), q_avg, LogDensityFunction(m)),
186+
ELBO_q_last = estimate_objective(AdvancedVI.RepGradELBO(32), q_last, LogDensityFunction(m))
187187
)
188188
```
189189
We can see that `ELBO_q_avg` is slightly more optimal.
@@ -205,9 +205,9 @@ For example, the following callback function estimates the ELBO on `q_avg` every
205205
```{julia}
206206
function callback(; stat, averaged_params, restructure, kwargs...)
207207
if mod(stat.iteration, 10) == 1
208-
q_avg = restructure(averaged_params)
209-
obj = AdvancedVI.RepGradELBO(128)
210-
elbo_avg = estimate_objective(obj, q_avg, Turing.Variational.make_logdensity(m))
208+
q_avg = restructure(averaged_params)
209+
obj = AdvancedVI.RepGradELBO(128)
210+
elbo_avg = estimate_objective(obj, q_avg, LogDensityFunction(m))
211211
(elbo_avg = elbo_avg,)
212212
else
213213
nothing
@@ -223,7 +223,7 @@ q_mf, _, info_mf, _ = vi(m, q_init, n_iters; show_progress=false, callback=callb
223223

224224
Let's plot the result:
225225
```{julia}
226-
iters = 1:10:length(info_mf)
226+
iters = 1:10:length(info_mf)
227227
elbo_mf = [i.elbo_avg for i in info_mf[iters]]
228228
Plots.plot!(iters, elbo_mf, xlabel="Iterations", ylabel="ELBO", label="callback", ylims=(-200,Inf))
229229
```
@@ -247,7 +247,7 @@ _, _, info_adam, _ = vi(m, q_init, n_iters; show_progress=false, callback=callba
247247
```
248248

249249
```{julia}
250-
iters = 1:10:length(info_mf)
250+
iters = 1:10:length(info_mf)
251251
elbo_adam = [i.elbo_avg for i in info_adam[iters]]
252252
Plots.plot(iters, elbo_mf, xlabel="Iterations", ylabel="ELBO", label="DoWG")
253253
Plots.plot!(iters, elbo_adam, xlabel="Iterations", ylabel="ELBO", label="Adam")

0 commit comments

Comments
 (0)