@@ -182,8 +182,8 @@ Usually, `q_avg` will perform better than the last-iterate `q_last`.
182
182
For instance, we can compare the ELBO of the two:
183
183
``` {julia}
184
184
@info("Objective of q_avg and q_last",
185
- ELBO_q_avg = estimate_objective(AdvancedVI.RepGradELBO(32), q_avg, Turing.Variational.make_logdensity (m)),
186
- ELBO_q_last = estimate_objective(AdvancedVI.RepGradELBO(32), q_last, Turing.Variational.make_logdensity (m))
185
+ ELBO_q_avg = estimate_objective(AdvancedVI.RepGradELBO(32), q_avg, LogDensityFunction (m)),
186
+ ELBO_q_last = estimate_objective(AdvancedVI.RepGradELBO(32), q_last, LogDensityFunction (m))
187
187
)
188
188
```
189
189
We can see that ` ELBO_q_avg ` is slightly more optimal.
@@ -205,9 +205,9 @@ For example, the following callback function estimates the ELBO on `q_avg` every
205
205
``` {julia}
206
206
function callback(; stat, averaged_params, restructure, kwargs...)
207
207
if mod(stat.iteration, 10) == 1
208
- q_avg = restructure(averaged_params)
209
- obj = AdvancedVI.RepGradELBO(128)
210
- elbo_avg = estimate_objective(obj, q_avg, Turing.Variational.make_logdensity (m))
208
+ q_avg = restructure(averaged_params)
209
+ obj = AdvancedVI.RepGradELBO(128)
210
+ elbo_avg = estimate_objective(obj, q_avg, LogDensityFunction (m))
211
211
(elbo_avg = elbo_avg,)
212
212
else
213
213
nothing
@@ -223,7 +223,7 @@ q_mf, _, info_mf, _ = vi(m, q_init, n_iters; show_progress=false, callback=callb
223
223
224
224
Let's plot the result:
225
225
``` {julia}
226
- iters = 1:10:length(info_mf)
226
+ iters = 1:10:length(info_mf)
227
227
elbo_mf = [i.elbo_avg for i in info_mf[iters]]
228
228
Plots.plot!(iters, elbo_mf, xlabel="Iterations", ylabel="ELBO", label="callback", ylims=(-200,Inf))
229
229
```
@@ -247,7 +247,7 @@ _, _, info_adam, _ = vi(m, q_init, n_iters; show_progress=false, callback=callba
247
247
```
248
248
249
249
``` {julia}
250
- iters = 1:10:length(info_mf)
250
+ iters = 1:10:length(info_mf)
251
251
elbo_adam = [i.elbo_avg for i in info_adam[iters]]
252
252
Plots.plot(iters, elbo_mf, xlabel="Iterations", ylabel="ELBO", label="DoWG")
253
253
Plots.plot!(iters, elbo_adam, xlabel="Iterations", ylabel="ELBO", label="Adam")
0 commit comments