Skip to content

Commit fdd684e

Browse files
committed
Add completed VI tutorial.
1 parent 07ae12e commit fdd684e

File tree

2 files changed

+330
-0
lines changed

2 files changed

+330
-0
lines changed

docs/pages.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ pages = [
44
"Getting Started" => "tutorials/getting_started.md",
55
"Introduction to Modeling in Gen" => "tutorials/modeling_in_gen.md",
66
"Object Tracking with SMC" => "tutorials/smc.md",
7+
"Variational Inference in Gen" => "tutorials/vi.md",
78
"Learning Generative Functions" => "tutorials/learning_gen_fns.md",
89
"Speeding Up Inference with the SML" => "tutorials/scaling_with_sml.md",
910
],

docs/src/tutorials/vi.md

Lines changed: 329 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,329 @@
1+
# [Variational Inference in Gen](@id vi_tutorial)
2+
3+
Variational inference (VI) involves optimizing the parameters of a variational family to maximize a lower bound on the marginal likelihood called the ELBO. In Gen, variational families are represented as generative functions, and variational inference typically involves optimizing the trainable parameters of generative functions.
4+
5+
```@setup vi_tutorial
6+
using Gen, Random
7+
Random.seed!(0)
8+
```
9+
10+
## A Simple Example of VI
11+
12+
Let's begin with a simple example that illustrates how to use Gen's [`black_box_vi!`](@ref) function to perform variational inference. In variational inference, we have a target distribution ``P(x)`` that we wish to approximate with some variational distribution ``Q(x; \phi)`` with trainable parameters ``\phi``.
13+
14+
In many cases, this target distribution is a posterior distribution ``P(x | y)`` given a fixed set of observations ``y``. But in this example, we assume we know ``P(x)`` exactly, and optimize ``\phi`` so that ``Q(x; \phi)`` fits ``P(x)``.
15+
16+
We first define the **target distribution** ``P(x)`` as a normal distribution with
17+
with a mean of `-1` and a standard deviation of `exp(0.5)`:
18+
19+
```@example vi_tutorial
20+
@gen function target()
21+
x ~ normal(-1, exp(0.5))
22+
end
23+
nothing # hide
24+
```
25+
26+
We now define a **variational family**, also known as a *guide*, as a generative function ``Q(x; \phi)`` parameterized by a set of trainable parameters ``\phi``. This requires (i) picking the functional form of the variational distribution (e.g. normal, Cauchy, etc.), (ii) choosing how the distribution is parameterized.
27+
28+
Our target distribution is normal, so we make our variational family normally distributed as well. We also define two variational parameters, `x_mu` and `x_log_std`, which are the mean and log standard deviation of our variational distribution.
29+
30+
```@example vi_tutorial
31+
@gen function approx()
32+
@param x_mu::Float64
33+
@param x_log_std::Float64
34+
x ~ normal(x_mu, exp(x_log_std))
35+
end
36+
nothing # hide
37+
```
38+
39+
Since `x_mu` and `x_log_std`are not fixed to particular values, this generative function defines a *family* of distributions, not just one. Note that we intentionally chose to parameterize the distribution by the log standard deviation `x_log_std`, so that every parameter has full support over the real line, and we can perform unconstrained optimization of the parameters.
40+
41+
To perform variational inference, we need to initialize the variational parameters to their starting values:
42+
43+
```@example vi_tutorial
44+
init_param!(approx, :x_mu, 0.0)
45+
init_param!(approx, :x_log_std, 0.0)
46+
nothing # hide
47+
```
48+
49+
Now we can use the [`black_box_vi!`](@ref) function to perform variational inference using [`GradientDescent`](@ref) to update the variational parameters.
50+
51+
```@example vi_tutorial
52+
observations = choicemap()
53+
param_update = ParamUpdate(GradientDescent(1., 1000), approx)
54+
black_box_vi!(target, (), observations, approx, (), param_update;
55+
iters=200, samples_per_iter=100, verbose=false)
56+
nothing # hide
57+
```
58+
59+
We can now inspect the resulting variational parameters, and see if we have recovered the parameters of the target distribution:
60+
61+
```@example vi_tutorial
62+
x_mu = get_param(approx, :x_mu)
63+
x_log_std = get_param(approx, :x_log_std)
64+
@show x_mu x_log_std;
65+
nothing # hide
66+
```
67+
68+
As expected, we have recovered the parameters of the target distribution.
69+
70+
## Posterior Inference with VI
71+
72+
In the above example, we used a target distribution ``P(x)`` that we had full knowledge about. When performing posterior inference, however, we typically only have the ability to sample from a generative model ``x, y \sim P(x) P(y | x)``, and to evaluate the joint probability ``P(x, y)``, but not the ability to evaluate or sample from the posterior ``P(x | y)`` for a fixed obesrvation ``y``.
73+
74+
Variational inference can address this by approximating ``P(x | y)`` with ``Q(x; \phi)``, allowing us to sample and evaluate ``Q(x; \phi)`` instead. This is done by maximizing a quantity known as the **evidence lower bound** or **ELBO**, which is a lower bound on the log marginal likelihood ``\log P(y)`` of the observations ``y``. The ELBO can be written in multiple equivalent forms:
75+
76+
```math
77+
\begin{aligned}
78+
\operatorname{ELBO}(\phi; y)
79+
&= \mathbb{E}_{x \sim Q(x; \phi)}\left[\log \frac{P(x, y)}{Q(x; \phi)}\right] \\
80+
&= \mathbb{E}_{x \sim Q(x; \phi)}[\log P(x, y)] + \operatorname{H}[Q(x; \phi)] \\
81+
&= \log P(y) - \operatorname{KL}[Q(x; \phi) || P(x | y)]
82+
\end{aligned}
83+
```
84+
85+
Here, ``\operatorname{H}[Q(x; \phi)]`` is the entropy of the variational distribution ``Q(x; \phi)``, and ``\operatorname{KL}[Q(x; \phi) || P(x | y)]`` is the Kullback-Leibler divergence between the variational distribution ``Q(x; \phi)`` and the target distribution ``P(x | y)``. From the third line, we can see that the ELBO is a lower bound on ``\log P(y)``, and that maximizing the ELBO is equivalent to minimizing the KL divergence between ``Q(x; \phi)`` and ``P(x | y)``.
86+
87+
Let's test this for a generative model ``P(x, y)`` where it is possible (with a bit of work) to analytically calculate the posterior ``P(y | x)``:
88+
89+
```@example vi_tutorial
90+
@gen function model(n::Int)
91+
x ~ normal(0, 1)
92+
for i in 1:n
93+
{(:y, i)} ~ normal(x, 0.5)
94+
end
95+
end
96+
nothing # hide
97+
```
98+
99+
In this normal-normal model, an unknown mean ``x`` is sampled from a ``\operatorname{Normal}(0, 1)`` prior. Then we draw ``n`` datapoints ``y_{1:n}`` from a normal distribution centered around ``x`` with a standard deviation of 0.5. Our task is to infer the posterior distribution over ``x`` given that we have observed ``y_{1:n}``. We'll reuse the same variational family as before:
100+
101+
```@example vi_tutorial
102+
@gen function approx()
103+
@param x_mu::Float64
104+
@param x_log_std::Float64
105+
x ~ normal(x_mu, exp(x_log_std))
106+
end
107+
nothing # hide
108+
```
109+
110+
Suppose we observe ``n = 6`` datapoints ``y_{1:6}`` with the following values:
111+
```@example vi_tutorial
112+
ys = [3.12, 2.25, 2.21, 1.55, 2.15, 1.06]
113+
nothing # hide
114+
```
115+
116+
It is possible to show analytically that the posterior ``P(x | y_{1:n})`` is normally distributed with mean ``\mu_n = \frac{4n}{1 + 4n} \bar y`` and standard deviation ``\sigma_n = \frac{1}{\sqrt{1 + 4n}}``, where ``\bar y`` is the mean of ``y_{1:n}``:
117+
118+
```@example vi_tutorial
119+
n = length(ys)
120+
x_mu_expected = 4*n / (1 + 4*n) * (sum(ys) / n)
121+
x_std_expected = 1/(sqrt((1 + 4*n)))
122+
@show x_mu_expected x_std_expected;
123+
nothing # hide
124+
```
125+
126+
Let's see whether variational inference can reproduce these values. We first construct a choicemap of our observations:
127+
128+
```@example vi_tutorial
129+
observations = choicemap()
130+
for (i, y) in enumerate(ys)
131+
observations[(:y, i)] = y
132+
end
133+
nothing # hide
134+
```
135+
136+
Next, we configure our [`GradientDescent`](@ref) optimizer. Since this is a more complicated optimization proplem, we use a smaller initial step size of 0.01:
137+
138+
```@example vi_tutorial
139+
step_size_init = 0.01
140+
step_size_beta = 1000
141+
update_config = GradientDescent(step_size_init, step_size_beta)
142+
nothing # hide
143+
```
144+
145+
We then initialize the parameters of our variational approximation, and pass our model, observations, and variational family to [`black_box_vi!`](@ref).
146+
147+
```@example vi_tutorial
148+
init_param!(approx, :x_mu, 0.0)
149+
init_param!(approx, :x_log_std, 0.0)
150+
param_update = ParamUpdate(update_config, approx);
151+
elbo_est, _, elbo_history =
152+
black_box_vi!(model, (n,), observations, approx, (), param_update;
153+
iters=500, samples_per_iter=200, verbose=false);
154+
nothing # hide
155+
```
156+
157+
As expected, the ELBO estimate increases over time, eventually converging to a value around -9.9:
158+
159+
```@example vi_tutorial
160+
for t in [1; 50:50:500]
161+
println("iter $(lpad(t, 3)): elbo est. = $(elbo_history[t])")
162+
end
163+
println("final elbo est. = $elbo_est")
164+
```
165+
166+
Inspecting the resulting variational parameters, we find that they are reasonable approximations to the parameters of the true posterior:
167+
168+
```@example vi_tutorial
169+
x_mu_approx = get_param(approx, :x_mu)
170+
Δx_mu = x_mu_approx - x_mu_expected
171+
172+
x_log_std_approx = get_param(approx, :x_log_std)
173+
x_std_approx = exp(x_log_std_approx)
174+
Δx_std = x_std_approx - x_std_expected
175+
176+
@show (x_mu_approx, Δx_mu) (x_std_approx, Δx_std);
177+
nothing # hide
178+
```
179+
180+
## Amortized Variational Inference
181+
182+
In standard variational inference, we have to optimize the variational parameters ``\phi`` for each new inference problem. Depending on how difficult the optimization problem is, this may be costly.
183+
184+
As an alternative, we can perform **amortized variational inference**: Instead of optimizing ``\phi`` for each set of observations ``y`` that we encounter, we learn a *function* ``f_\varphi(y)`` that outputs a set of distribution parameters ``\phi_y`` for each ``y``, and optimize the parameters of the function ``\varphi``. We do this over a dataset of ``K`` independently distributed observation sets ``Y = \{y^1, ..., y^K\}``, maximizing the expected ELBO over this dataset:
185+
186+
```math
187+
\begin{aligned}
188+
\operatorname{A-ELBO}(\varphi; Y)
189+
&= \frac{1}{K} \sum_{k=1}^{K} \operatorname{ELBO}(\varphi; y^k) \\
190+
&= \frac{1}{K} \left[\log P(Y) - \sum_{k=1}^{K} \operatorname{KL}[Q(x; f_{\varphi}(y^k)) || P(x | y^k)] \right]
191+
\end{aligned}
192+
```
193+
194+
We will perform amortized VI over the same generative `model` we defined earlier:
195+
196+
```@example vi_tutorial
197+
@gen function model(n::Int)
198+
x ~ normal(0, 1)
199+
for i in 1:n
200+
{(:y, i)} ~ normal(x, 0.5)
201+
end
202+
end
203+
nothing # hide
204+
```
205+
206+
Since amortized VI is performed over a dataset of `K` observation sets ``\{y^1, ..., y^K\}``, where each ``y^k`` has ``n`` datapoints ``(y^k_1, ..., y^k_n)`` , we need to nest `model` within a [`Map`](@ref) combinator that repeats `model` ``K`` times:
207+
208+
```@example vi_tutorial
209+
mapped_model = Map(model)
210+
nothing # hide
211+
```
212+
213+
Let's generate a synthetic dataset of ``K = 10`` observation sets, each with ``n = 6`` datapoints:
214+
215+
```@example vi_tutorial
216+
# Simulate 10 observation sets of length 6
217+
K, n = 10, 6
218+
mapped_trace = simulate(mapped_model, (fill(n, K),))
219+
observations = get_choices(mapped_trace)
220+
221+
# Select just the `y` values, excluding the generated `x` values
222+
sel = select((k => (:y, i) for i in 1:n for k in 1:K)...)
223+
observations = get_selected(observations, sel)
224+
all_ys = [[observations[k => (:y, i)] for i in 1:n] for k in 1:K]
225+
nothing # hide
226+
```
227+
228+
Now let's define our amortized approximation, which takes in an observation set `ys`, and computes the parameters of a normal distribution over `x` as a function of `ys`:
229+
230+
```@example vi_tutorial
231+
@gen function amortized_approx(ys)
232+
@param x_mu_bias::Float64
233+
@param x_mu_coeff::Float64
234+
@param x_log_std::Float64
235+
x_mu = x_mu_bias + x_mu_coeff * sum(ys)
236+
x ~ normal(x_mu, exp(x_log_std))
237+
return (x_mu, x_log_std)
238+
end
239+
nothing # hide
240+
```
241+
242+
Similar to our `model`, we need to wrap this variational approximation in a [`Map`](@ref) combinator:
243+
244+
```@example vi_tutorial
245+
mapped_approx = Map(amortized_approx)
246+
nothing # hide
247+
```
248+
249+
In our choice of function ``f_\varphi(y)``, we exploit the fact that the posterior mean `x_mu` should depend on the sum of the values in `ys`, along with the knowledge that `x_log_std` does not depend on `ys`. We could have chosen a more complex function, such as full-rank linear regression, or a neural network, but this would make optimization more difficult. Given this choice of function, the optimal parameters ``\varphi^*`` can be computed analytically:
250+
251+
```@example vi_tutorial
252+
n = 6
253+
254+
x_mu_bias_optimal = 0.0
255+
x_mu_coeff_optimal = 4 / (1 + 4*n)
256+
257+
x_std_optimal = 1/(sqrt((1 + 4*n)))
258+
x_log_std_optimal = log(x_std_optimal)
259+
260+
@show x_mu_bias_optimal x_mu_coeff_optimal x_log_std_optimal;
261+
nothing # hide
262+
```
263+
264+
We can now fit our variational approximation via [`black_box_vi!`](@ref): We initialize the variational parameters, then configure our parameter update to update the parameters of `amortized_approx`:
265+
266+
```@example vi_tutorial
267+
# Configure parameter update to optimize the parameters of `amortized_approx`
268+
step_size_init = 1e-4
269+
step_size_beta = 1000
270+
update_config = GradientDescent(step_size_init, step_size_beta)
271+
272+
# Initialize the amortized variational parameters, then the parameter update
273+
init_param!(amortized_approx, :x_mu_bias, 0.0);
274+
init_param!(amortized_approx, :x_mu_coeff, 0.0);
275+
init_param!(amortized_approx, :x_log_std, 0.0);
276+
param_update = ParamUpdate(update_config, amortized_approx);
277+
278+
# Run amortized black-box variational inference over the synthetic observations
279+
mapped_model_args = (fill(n, K), )
280+
mapped_approx_args = (all_ys, )
281+
elbo_est, _, elbo_history =
282+
black_box_vi!(mapped_model, mapped_model_args, observations,
283+
mapped_approx, mapped_approx_args, param_update;
284+
iters=500, samples_per_iter=100, verbose=false);
285+
nothing # hide
286+
```
287+
288+
Once again, the ELBO estimate increases and eventually converges:
289+
290+
```@example vi_tutorial
291+
for t in [1; 50:50:500]
292+
println("iter $(lpad(t, 3)): elbo est. = $(elbo_history[t])")
293+
end
294+
println("final elbo est. = $elbo_est")
295+
```
296+
297+
Our amortized variational parameters ``\varphi`` are also fairly close to their optimal values ``\varphi^*``:
298+
299+
```@example vi_tutorial
300+
x_mu_bias = get_param(amortized_approx, :x_mu_bias)
301+
Δx_mu_bias = x_mu_bias - x_mu_bias_optimal
302+
303+
x_mu_coeff = get_param(amortized_approx, :x_mu_coeff)
304+
Δx_mu_coeff = x_mu_coeff - x_mu_coeff_optimal
305+
306+
x_log_std = get_param(amortized_approx, :x_log_std)
307+
Δx_log_std = x_log_std - x_log_std_optimal
308+
309+
@show (x_mu_bias, Δx_mu_bias) (x_mu_coeff, Δx_mu_coeff) (x_log_std, Δx_log_std);
310+
nothing # hide
311+
```
312+
313+
If we now call `amortized_approx` with our observation set `ys` from the previous section, we should get something close to what standard variational inference produced by optimizing the paramaters of `approx` directly:
314+
315+
```@example vi_tutorial
316+
x_mu_amortized, x_log_std_amortized = amortized_approx(ys)
317+
x_std_amortized = exp(x_log_std_amortized)
318+
319+
@show x_mu_amortized x_std_amortized;
320+
@show x_mu_approx x_std_approx;
321+
@show x_mu_expected x_std_expected;
322+
nothing # hide
323+
```
324+
325+
Both amortized VI and standard VI produce parameter estimates that are reasonably close to the paramters of the true posterior.
326+
327+
## Reparametrization Trick
328+
329+
To use the reparametrization trick to reduce the variance of gradient estimators, users currently need to write two versions of their variational family, one that is reparametrized and one that is not. Gen.jl does not currently include inference library support for this. We plan to add automated support for reparametrization and other variance reduction techniques in the future.

0 commit comments

Comments
 (0)