Skip to content

Commit 6df0913

Browse files
authored
Update VI Tutorial (#625)
* update Turing version * update VI tutorial to automatically extract varnames
1 parent 2788f7c commit 6df0913

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

tutorials/variational-inference/index.qmd

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -332,10 +332,13 @@ avg[union(sym2range[:coefficients]...)]
332332

333333
For further convenience, we can wrap the samples into a `Chains` object to summarize the results.
334334
```{julia}
335-
varnames = vcat(["σ²", "intercept"], ["coefficients[$i]" for i in 1:n_vars])
335+
varinf = Turing.DynamicPPL.VarInfo(m)
336+
vns_and_values = Turing.DynamicPPL.varname_and_value_leaves(Turing.DynamicPPL.values_as(varinf, OrderedDict))
337+
varnames = map(first, vns_and_values)
336338
vi_chain = Chains(reshape(z', (size(z,2), size(z,1), 1)), varnames)
337339
```
338340
(Since we're drawing independent samples, we can simply ignore the ESS and Rhat metrics.)
341+
Unfortunately, extracting `varnames` is a bit verbose at the moment, but hopefully will become simpler in the near future.
339342

340343
Let's compare this against samples from `NUTS`:
341344

0 commit comments

Comments
 (0)