
Commit 9ad7e97 (parent 38e6571)

add callbacks VI and training; fix docstrings
2 files changed, +43 -11 lines

src/inference/train.jl (21 additions, 5 deletions)
@@ -3,17 +3,29 @@
     update::ParamUpdate,
     num_epoch, epoch_size, num_minibatch, minibatch_size; verbose::Bool=false)
 
-Train the given generative function to maximize the expected conditional log probability (density) that `gen_fn` generates the assignment `constraints` given inputs, where the expectation is taken under the output distribution of `data_generator`.
+Train the given generative function to maximize the expected conditional log
+probability (density) that `gen_fn` generates the assignment `constraints`
+given inputs, where the expectation is taken under the output distribution of
+`data_generator`.
+
+The function `data_generator` is a function of no arguments that returns a
+tuple `(inputs, constraints)` where `inputs` is a `Tuple` of inputs (arguments)
+to `gen_fn`, and `constraints` is a `ChoiceMap`.
 
-The function `data_generator` is a function of no arguments that returns a tuple `(inputs, constraints)` where `inputs` is a `Tuple` of inputs (arguments) to `gen_fn`, and `constraints` is an `ChoiceMap`.
 `conf` configures the optimization algorithm used.
+
 `param_lists` is a map from generative function to lists of its parameters.
-This is equivalent to minimizing the expected KL divergence from the conditional distribution `constraints | inputs` of the data generator to the distribution represented by the generative function, where the expectation is taken under the marginal distribution on `inputs` determined by the data generator.
+This is equivalent to minimizing the expected KL divergence from the
+conditional distribution `constraints | inputs` of the data generator to the
+distribution represented by the generative function, where the expectation is
+taken under the marginal distribution on `inputs` determined by the data
+generator.
 """
 function train!(gen_fn::GenerativeFunction, data_generator::Function,
         update::ParamUpdate;
         num_epoch=1, epoch_size=1, num_minibatch=1, minibatch_size=1,
-        evaluation_size=epoch_size, verbose=false)
+        evaluation_size=epoch_size, verbose=false,
+        callback=(epoch, minibatch, minibatch_objective) -> nothing)
 
     history = Vector{Float64}(undef, num_epoch)
     for epoch=1:num_epoch
@@ -37,11 +49,15 @@ function train!(gen_fn::GenerativeFunction, data_generator::Function,
             minibatch_idx = permuted[1:minibatch_size]
             minibatch_inputs = epoch_inputs[minibatch_idx]
             minibatch_choice_maps = epoch_choice_maps[minibatch_idx]
+            minibatch_objective = 0.0
             for (inputs, constraints) in zip(minibatch_inputs, minibatch_choice_maps)
-                (trace, _) = generate(gen_fn, inputs, constraints)
+                (trace, weight) = generate(gen_fn, inputs, constraints)
+                minibatch_objective += weight
                 accumulate_param_gradients!(trace)
            end
            apply!(update)
+            minibatch_objective /= minibatch_size
+            callback(epoch, minibatch, minibatch_objective)
        end
 
        # evaluate score on held out data
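To illustrate the new hook: a minimal sketch of calling `train!` with a callback that records each minibatch objective (the average `generate` weight over the minibatch, per the loop above). The toy model, data generator, and update configuration here are assumptions for the example, not part of this commit.

using Gen

# Hypothetical model with one trainable parameter `mu` (illustration only).
@gen function model()
    @param mu::Float64
    @trace(normal(mu, 1.0), :y)
end
init_param!(model, :mu, 0.0)

# Data generator: returns (inputs, constraints); inputs are empty here.
function data_generator()
    y = 3.0 + randn()            # pretend the data comes from N(3, 1)
    ((), choicemap((:y, y)))
end

update = ParamUpdate(FixedStepGradientDescent(0.01), model)

# Record the per-minibatch objective via the new callback keyword.
objectives = Float64[]
train!(model, data_generator, update;
       num_epoch=10, epoch_size=100, num_minibatch=10, minibatch_size=10,
       callback=(epoch, minibatch, objective) -> push!(objectives, objective))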

src/inference/variational.jl (22 additions, 6 deletions)
@@ -94,16 +94,20 @@ end
     observations::ChoiceMap,
     var_model::GenerativeFunction, var_model_args::Tuple,
     update::ParamUpdate;
-    iters=1000, samples_per_iter=100, verbose=false)
+    iters=1000, samples_per_iter=100, verbose=false,
+    callback=(iter, traces, elbo_estimate) -> nothing)
 
-Fit the parameters of a generative function (`var_model`) to the posterior distribution implied by the given model and observations using stochastic gradient methods.
+Fit the parameters of a generative function (`var_model`) to the posterior
+distribution implied by the given model and observations using stochastic
+gradient methods.
 """
 function black_box_vi!(
     model::GenerativeFunction, model_args::Tuple,
     observations::ChoiceMap,
     var_model::GenerativeFunction, var_model_args::Tuple,
     update::ParamUpdate;
-    iters=1000, samples_per_iter=100, verbose=false)
+    iters=1000, samples_per_iter=100, verbose=false,
+    callback=(iter, traces, elbo_estimate) -> nothing)
 
     traces = Vector{Any}(undef, samples_per_iter)
     elbo_history = Vector{Float64}(undef, iters)
@@ -126,6 +130,9 @@ function black_box_vi!(
         # print it
         verbose && println("iter $iter; est objective: $elbo_estimate")
 
+        # callback
+        callback(iter, traces, elbo_estimate)
+
         # do an update
         apply!(update)
     end
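For the variational routine, a similar sketch: the callback receives the iteration number, the traces sampled from `var_model` that iteration, and the ELBO estimate. The model and variational family below are hypothetical stand-ins, not code from this commit.

using Gen

# Hypothetical model and variational family (illustration only).
@gen function model()
    x = @trace(normal(0.0, 1.0), :x)
    @trace(normal(x, 0.1), :y)
end

@gen function var_model()
    @param mu::Float64
    @param log_std::Float64
    @trace(normal(mu, exp(log_std)), :x)
end
init_param!(var_model, :mu, 0.0)
init_param!(var_model, :log_std, 0.0)

observations = choicemap((:y, 2.0))
update = ParamUpdate(FixedStepGradientDescent(0.001), var_model)

# Log the ELBO estimate each iteration through the new callback.
elbo_log = Float64[]
black_box_vi!(model, (), observations, var_model, (), update;
              iters=200, samples_per_iter=50,
              callback=(iter, traces, elbo) -> push!(elbo_log, elbo))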
@@ -139,17 +146,23 @@ end
     observations::ChoiceMap,
     var_model::GenerativeFunction, var_model_args::Tuple,
     update::ParamUpdate, num_samples::Int;
-    iters=1000, samples_per_iter=100, verbose=false)
+    iters=1000, samples_per_iter=100, verbose=false,
+    callback=(iter, traces, elbo_estimate) -> nothing)
 
-Fit the parameters of a generative function (`var_model`) to the posterior distribution implied by the given model and observations using stochastic gradient methods applied to the [Variational Inference with Monte Carlo Objectives](https://arxiv.org/abs/1602.06725) lower bound on the marginal likelihood.
+Fit the parameters of a generative function (`var_model`) to the posterior
+distribution implied by the given model and observations using stochastic
+gradient methods applied to the [Variational Inference with Monte Carlo
+Objectives](https://arxiv.org/abs/1602.06725) lower bound on the marginal
+likelihood.
 """
 function black_box_vimco!(
     model::GenerativeFunction, model_args::Tuple,
     observations::ChoiceMap,
     var_model::GenerativeFunction, var_model_args::Tuple,
     update::ParamUpdate, num_samples::Int;
     iters=1000, samples_per_iter=100, verbose=false,
-    geometric=true)
+    geometric=true,
+    callback=(iter, traces, elbo_estimate) -> nothing)
 
     traces = Vector{Any}(undef, samples_per_iter)
     iwelbo_history = Vector{Float64}(undef, iters)
@@ -172,6 +185,9 @@ function black_box_vimco!(
         # print it
         verbose && println("iter $iter; est objective: $iwelbo_estimate")
 
+        # callback
+        callback(iter, traces, iwelbo_estimate)
+
         # do an update
         apply!(update)
     end
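`black_box_vimco!` gains the same hook; the only signature differences are the positional `num_samples` argument and the `geometric` keyword, and the callback receives the IWELBO estimate in place of the ELBO. Continuing the hypothetical sketch above:

# Same toy model and variational family as before, now with VIMCO
# using 10 samples per objective estimate.
black_box_vimco!(model, (), observations, var_model, (), update, 10;
                 iters=200, samples_per_iter=50, geometric=true,
                 callback=(iter, traces, iwelbo) -> push!(elbo_log, iwelbo))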
