Skip to content

Commit 64ee5e0

Browse files
committed
add multi-threading to VI and importance sampling with tests
1 parent 22b9c98 commit 64ee5e0

File tree

6 files changed

+179
-102
lines changed

6 files changed

+179
-102
lines changed

src/inference/variational.jl

Lines changed: 100 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ function single_sample_gradient_estimate!(
1616
accumulate_param_gradients!(var_trace, nothing, log_weight * scale_factor)
1717

1818
# unbiased estimate of objective function, and trace
19-
(log_weight, var_trace, model_trace)
19+
return (log_weight, var_trace, model_trace)
2020
end
2121

2222
function vimco_geometric_baselines(log_weights)
@@ -29,12 +29,12 @@ function vimco_geometric_baselines(log_weights)
2929
baselines[i] = logsumexp(log_weights) - log(num_samples)
3030
log_weights[i] = temp
3131
end
32-
baselines
32+
return baselines
3333
end
3434

3535
"""
    logdiffexp(x, y)

Return `log(exp(x) - exp(y))`, computed in log space.

The result is evaluated as `x + log(-expm1(y - x))`, which only ever
exponentiates the difference `y - x`, so large-magnitude inputs do not
overflow, and which stays accurate when `y` is very close to `x` (where
forming `1 - exp(y - x)` directly would cancel to zero).

# Arguments
- `x`: logarithm of the minuend.
- `y`: logarithm of the subtrahend; must satisfy `y <= x`.

# Returns
`log(exp(x) - exp(y))`. When `x == y` the result is `-Inf`; when `y == -Inf`
the result is `x`.

# Throws
- `DomainError` (raised by `log`) for real inputs with `y > x`, since the
  difference `exp(x) - exp(y)` is then negative.
"""
function logdiffexp(x, y)
    # log(exp(x) - exp(y)) = x + log(1 - exp(y - x)) = x + log(-expm1(y - x)).
    # `expm1` avoids the catastrophic cancellation of `1 - exp(d)` for small `d`.
    return x + log(-expm1(y - x))
end
3939

4040
function vimco_arithmetic_baselines(log_weights)
@@ -46,7 +46,7 @@ function vimco_arithmetic_baselines(log_weights)
4646
log_f_hat = log_sum_f_without_i - log(num_samples - 1)
4747
baselines[i] = logsumexp(log_sum_f_without_i, log_f_hat) - log(num_samples)
4848
end
49-
baselines
49+
return baselines
5050
end
5151

5252
# black box, VIMCO gradient estimator
@@ -85,7 +85,7 @@ function multi_sample_gradient_estimate!(
8585

8686
# collection of traces and normalized importance weights, and estimate of
8787
# objective function
88-
(L, traces, weights_normalized)
88+
return (L, traces, weights_normalized)
8989
end
9090

9191
function _maybe_accumulate_param_grad!(trace, optimizer, scale_factor::Real)
@@ -117,6 +117,7 @@ update the parameters of `model`.
117117
- `callback`: Callback function that takes `(iter, traces, elbo_estimate)`
118118
as input, where `iter` is the iteration number and `traces` are samples
119119
from `var_model` for that iteration.
120+
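- `multithreaded`: If `true`, the `samples_per_iter` gradient estimates of each iteration are computed in a `Threads.@threads` loop; otherwise they are computed sequentially (default `false`).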
120121
"""
121122
function black_box_vi!(
122123
model::GenerativeFunction, model_args::Tuple,
@@ -125,31 +126,32 @@ function black_box_vi!(
125126
var_model::GenerativeFunction, var_model_args::Tuple,
126127
var_model_optimizer;
127128
iters=1000, samples_per_iter=100, verbose=false,
128-
callback=(iter, traces, elbo_estimate) -> nothing)
129+
callback=(iter, traces, elbo_estimate) -> nothing,
130+
multithreaded=false)
129131

130132
var_traces = Vector{Any}(undef, samples_per_iter)
131133
model_traces = Vector{Any}(undef, samples_per_iter)
134+
log_weights = Vector{Float64}(undef, samples_per_iter)
132135
elbo_history = Vector{Float64}(undef, iters)
133136
for iter=1:iters
134137

135138
# compute gradient estimate and objective function estimate
136-
elbo_estimate = 0.0
137-
# TODO multithread (note that this would require accumulate_param_gradients! to be threadsafe)
138-
for sample=1:samples_per_iter
139-
140-
# accumulate the variational family gradients
141-
(log_weight, var_trace, model_trace) = single_sample_gradient_estimate!(
142-
var_model, var_model_args,
143-
model, model_args, observations, 1/samples_per_iter)
144-
elbo_estimate += (log_weight / samples_per_iter)
145-
146-
# accumulate the generative model gradients
147-
_maybe_accumulate_param_grad!(model_trace, model_optimizer, 1.0 / samples_per_iter)
148-
149-
# record the traces
150-
var_traces[sample] = var_trace
151-
model_traces[sample] = model_trace
139+
if multithreaded
140+
Threads.@threads for i in 1:samples_per_iter
141+
black_box_vi_iter!(
142+
var_traces, model_traces, log_weights, i, samples_per_iter,
143+
var_model, var_model_args,
144+
model, model_args, observations, model_optimizer)
145+
end
146+
else
147+
for i in 1:samples_per_iter
148+
black_box_vi_iter!(
149+
var_traces, model_traces, log_weights, i, samples_per_iter,
150+
var_model, var_model_args,
151+
model, model_args, observations, model_optimizer)
152+
end
152153
end
154+
elbo_estimate = sum(log_weights)
153155
elbo_history[iter] = elbo_estimate
154156

155157
# print it
@@ -167,9 +169,34 @@ function black_box_vi!(
167169
end
168170
end
169171

170-
(elbo_history[end], var_traces, elbo_history, model_traces)
172+
return (elbo_history[end], var_traces, elbo_history, model_traces)
173+
end
174+
175+
"""
    black_box_vi_iter!(var_traces, model_traces, log_weights, i, n,
                       var_model, var_model_args, model, model_args,
                       observations, model_optimizer)

Process sample `i` (out of `n`) of a single black-box VI iteration.

Draws one gradient estimate via `single_sample_gradient_estimate!` (scaled by
`1.0 / n`), forwards the resulting model trace to
`_maybe_accumulate_param_grad!` with the same scale, and writes the outputs
for this sample into slot `i` of the three output vectors:

- `log_weights[i]` receives `log_weight / n`, i.e. this sample's contribution
  to the objective estimate (the caller obtains the estimate by summing
  `log_weights`).
- `var_traces[i]` and `model_traces[i]` receive the two traces returned by the
  estimator.

Only index `i` of the output vectors is written, so distinct values of `i`
never touch the same slot.
NOTE(review): gradient accumulation happens inside the called helpers and is
shared across samples; it is assumed to be safe to call from several threads —
confirm before relying on the multithreaded path.

Returns `nothing`.
"""
function black_box_vi_iter!(
        var_traces::Vector, model_traces::Vector, log_weights::Vector{Float64},
        i::Int, n::Int,
        var_model::GenerativeFunction, var_model_args::Tuple,
        model::GenerativeFunction, model_args::Tuple,
        observations::ChoiceMap,
        model_optimizer)

    # every per-sample gradient contribution is averaged over the n samples
    per_sample_scale = 1.0 / n

    # variational-family gradients, plus the traces and log weight of the draw
    estimate = single_sample_gradient_estimate!(
        var_model, var_model_args,
        model, model_args, observations, per_sample_scale)
    (sample_log_weight, sampled_var_trace, sampled_model_trace) = estimate
    log_weights[i] = sample_log_weight / n

    # generative-model gradients (the helper decides whether to accumulate)
    _maybe_accumulate_param_grad!(sampled_model_trace, model_optimizer, per_sample_scale)

    # store both traces for this sample
    var_traces[i] = sampled_var_trace
    model_traces[i] = sampled_model_trace

    return nothing
end
172198

199+
173200
black_box_vi!(model::GenerativeFunction, model_args::Tuple,
174201
observations::ChoiceMap,
175202
var_model::GenerativeFunction, var_model_args::Tuple,
@@ -205,42 +232,45 @@ update the parameters of `model`.
205232
- `callback`: Callback function that takes `(iter, traces, elbo_estimate)`
206233
as input, where `iter` is the iteration number and `traces` are samples
207234
from `var_model` for that iteration.
235+
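- `multithreaded`: If `true`, the `samples_per_iter` gradient estimates of each iteration are computed in a `Threads.@threads` loop; otherwise they are computed sequentially (default `false`).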
208236
"""
209237
function black_box_vimco!(
210238
model::GenerativeFunction, model_args::Tuple,
211239
model_optimizer::Union{CompositeOptimizer,Nothing}, observations::ChoiceMap,
212240
var_model::GenerativeFunction, var_model_args::Tuple,
213241
var_model_optimizer::CompositeOptimizer, grad_est_samples::Int;
214242
iters=1000, samples_per_iter=100, geometric=true, verbose=false,
215-
callback=(iter, traces, elbo_estimate) -> nothing)
243+
callback=(iter, traces, elbo_estimate) -> nothing,
244+
multithreaded=false)
216245

217246
resampled_var_traces = Vector{Any}(undef, samples_per_iter)
218247
model_traces = Vector{Any}(undef, samples_per_iter)
248+
log_weights = Vector{Float64}(undef, samples_per_iter)
219249

220250
iwelbo_history = Vector{Float64}(undef, iters)
221251
for iter=1:iters
222252

223253
# compute gradient estimate and objective function estimate
224-
iwelbo_estimate = 0.
225-
for sample=1:samples_per_iter
226-
227-
# accumulate the variational family gradients
228-
(est, original_var_traces, weights) = multi_sample_gradient_estimate!(
229-
var_model, var_model_args,
230-
model, model_args, observations, grad_est_samples,
231-
1/samples_per_iter, geometric)
232-
iwelbo_estimate += (est / samples_per_iter)
233-
234-
# record a variational trace obtained by resampling from the weighted collection
235-
resampled_var_traces[sample] = original_var_traces[categorical(weights)]
236-
237-
# accumulate the generative model gradient estimator
238-
for (var_trace, weight) in zip(original_var_traces, weights)
239-
constraints = merge(observations, get_choices(var_trace))
240-
(model_trace, _) = generate(model, model_args, constraints)
241-
_maybe_accumulate_param_grad!(model_trace, model_optimizer, weight / samples_per_iter)
254+
if multithreaded
255+
Threads.@threads for i in 1:samples_per_iter
256+
black_box_vimco_iter!(
257+
resampled_var_traces, log_weights,
258+
i, samples_per_iter,
259+
var_model, var_model_args, model, model_args,
260+
observations, geometric, grad_est_samples,
261+
model_optimizer)
262+
end
263+
else
264+
for i in 1:samples_per_iter
265+
black_box_vimco_iter!(
266+
resampled_var_traces, log_weights,
267+
i, samples_per_iter,
268+
var_model, var_model_args, model, model_args,
269+
observations, geometric, grad_est_samples,
270+
model_optimizer)
242271
end
243272
end
273+
iwelbo_estimate = sum(log_weights)
244274
iwelbo_history[iter] = iwelbo_estimate
245275

246276
# print it
@@ -262,6 +292,34 @@ function black_box_vimco!(
262292
(iwelbo_history[end], resampled_var_traces, iwelbo_history, model_traces)
263293
end
264294

295+
"""
    black_box_vimco_iter!(resampled_var_traces, log_weights, i, samples_per_iter,
                          var_model, var_model_args, model, model_args,
                          observations, geometric, grad_est_samples,
                          model_optimizer)

Process sample `i` (out of `samples_per_iter`) of a single VIMCO iteration.

Steps, in order:

1. Call `multi_sample_gradient_estimate!` with `grad_est_samples` samples, a
   scale factor of `1 / samples_per_iter`, and the `geometric` flag; it yields
   an objective estimate, a collection of variational traces, and their
   weights.
2. Store `estimate / samples_per_iter` in `log_weights[i]` (the caller sums
   `log_weights` to obtain the objective estimate for the iteration).
3. Pick one of the variational traces with `categorical(weights)` and store it
   in `resampled_var_traces[i]`.
4. For every variational trace, generate a model trace constrained to
   `observations` merged with that trace's choices, and pass it to
   `_maybe_accumulate_param_grad!` scaled by `weight / samples_per_iter`.

The generated model traces are used only for gradient accumulation; they are
not stored anywhere by this function.
NOTE(review): gradient accumulation inside the called helpers is shared across
samples and is assumed to be safe to call from several threads — confirm
before relying on the multithreaded path.

Returns `nothing`.
"""
function black_box_vimco_iter!(
        resampled_var_traces::Vector, log_weights::Vector{Float64},
        i::Int, samples_per_iter::Int,
        var_model::GenerativeFunction, var_model_args::Tuple,
        model::GenerativeFunction, model_args::Tuple,
        observations::ChoiceMap, geometric::Bool, grad_est_samples::Int,
        model_optimizer)

    # variational-family gradients and the weighted collection of traces
    (objective_est, candidate_traces, candidate_weights) =
        multi_sample_gradient_estimate!(
            var_model, var_model_args,
            model, model_args, observations, grad_est_samples,
            1 / samples_per_iter, geometric)
    log_weights[i] = objective_est / samples_per_iter

    # resample a single variational trace from the weighted collection
    chosen = categorical(candidate_weights)
    resampled_var_traces[i] = candidate_traces[chosen]

    # generative-model gradient estimator: one constrained model trace per
    # variational trace, weighted by that trace's weight
    for (candidate, w) in zip(candidate_traces, candidate_weights)
        constraints = merge(observations, get_choices(candidate))
        (model_trace, _) = generate(model, model_args, constraints)
        _maybe_accumulate_param_grad!(model_trace, model_optimizer, w / samples_per_iter)
    end

    return nothing
end
322+
265323
black_box_vimco!(model::GenerativeFunction, model_args::Tuple,
266324
observations::ChoiceMap,
267325
var_model::GenerativeFunction, var_model_args::Tuple,

src/optimization.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,7 @@ Thread-safe (multiple threads can increment the gradient of the same parameter c
375375
function increment_gradient!(
        id::Tuple{GenerativeFunction,Symbol}, increment,
        store::JuliaParameterStore=default_julia_parameter_store)
    # Look up the gradient accumulator registered for parameter `id` in `store`
    # and add `increment` into it in place.
    # NOTE(review): the (id, store) argument order follows the updated call in
    # this change — confirm it matches get_gradient_accumulator's signature.
    in_place_add!(get_gradient_accumulator(id, store), increment)
    return nothing
end
@@ -555,5 +555,3 @@ function assess(gen_fn::GenerativeFunction, args::Tuple, choices::ChoiceMap)
555555
end
556556

557557
propose(gen_fn::GenerativeFunction, args::Tuple) = propose(gen_fn, args, default_parameter_context)
558-
559-

test/inference/importance_sampling.jl

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,33 +15,36 @@
1515

1616
n = 4
1717

18-
(traces, log_weights, lml_est) = importance_sampling(model, (), observations, n)
19-
@test length(traces) == n
20-
@test length(log_weights) == n
21-
@test isapprox(logsumexp(log_weights), 0., atol=1e-14)
22-
@test !isnan(lml_est)
23-
for trace in traces
24-
@test get_choices(trace)[:y] == y
18+
for multithreaded in [false, true]
19+
(traces, log_weights, lml_est) = importance_sampling(
20+
model, (), observations, n; multithreaded=multithreaded)
21+
@test length(traces) == n
22+
@test length(log_weights) == n
23+
@test isapprox(logsumexp(log_weights), 0., atol=1e-14)
24+
@test !isnan(lml_est)
25+
for trace in traces
26+
@test get_choices(trace)[:y] == y
27+
end
2528
end
2629

27-
(traces, log_weights, lml_est) = importance_sampling(model, (), observations, proposal, (), n)
28-
@test length(traces) == n
29-
@test length(log_weights) == n
30-
@test isapprox(logsumexp(log_weights), 0., atol=1e-14)
31-
@test !isnan(lml_est)
32-
for trace in traces
33-
@test get_choices(trace)[:y] == y
30+
for multithreaded in [false, true]
31+
(traces, log_weights, lml_est) = importance_sampling(
32+
model, (), observations, proposal, (), n;
33+
multithreaded=multithreaded)
34+
@test length(traces) == n
35+
@test length(log_weights) == n
36+
@test isapprox(logsumexp(log_weights), 0., atol=1e-14)
37+
@test !isnan(lml_est)
38+
for trace in traces
39+
@test get_choices(trace)[:y] == y
40+
end
3441
end
3542

3643
(trace, lml_est) = importance_resampling(model, (), observations, n)
37-
@test isapprox(logsumexp(log_weights), 0., atol=1e-14)
3844
@test !isnan(lml_est)
3945
@test get_choices(trace)[:y] == y
4046

4147
(trace, lml_est) = importance_resampling(model, (), observations, proposal, (), n)
42-
@test isapprox(logsumexp(log_weights), 0., atol=1e-14)
4348
@test !isnan(lml_est)
4449
@test get_choices(trace)[:y] == y
4550
end
46-
47-

test/inference/variational.jl

Lines changed: 36 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -17,40 +17,46 @@
1717
end
1818
register_parameters!(approx, [:slope_mu, :slope_log_std, :intercept_mu, :intercept_log_std])
1919

20-
# to regular black box variational inference
21-
init_parameter!((approx, :slope_mu), 0.0)
22-
init_parameter!((approx, :slope_log_std), 0.0)
23-
init_parameter!((approx, :intercept_mu), 0.0)
24-
init_parameter!((approx, :intercept_log_std), 0.0)
25-
2620
observations = choicemap()
27-
optimizer = init_optimizer(DecayStepGradientDescent(1, 100000), approx)
2821
optimizer = init_optimizer(DecayStepGradientDescent(1., 1000), approx)
29-
black_box_vi!(model, (), observations, approx, (), optimizer;
30-
iters=2000, samples_per_iter=100, verbose=false)
31-
slope_mu = get_parameter_value((approx, :slope_mu))
32-
slope_log_std = get_parameter_value((approx, :slope_log_std))
33-
intercept_mu = get_parameter_value((approx, :intercept_mu))
34-
intercept_log_std = get_parameter_value((approx, :intercept_log_std))
35-
@test isapprox(slope_mu, -1., atol=0.001)
36-
@test isapprox(slope_log_std, 0.5, atol=0.001)
37-
@test isapprox(intercept_mu, 1., atol=0.001)
38-
@test isapprox(intercept_log_std, 2.0, atol=0.001)
22+
23+
# test regular black box variational inference
24+
for multithreaded in [false, true]
25+
init_parameter!((approx, :slope_mu), 0.0)
26+
init_parameter!((approx, :slope_log_std), 0.0)
27+
init_parameter!((approx, :intercept_mu), 0.0)
28+
init_parameter!((approx, :intercept_log_std), 0.0)
29+
black_box_vi!(model, (), observations, approx, (), optimizer;
30+
iters=2000, samples_per_iter=100, verbose=false, multithreaded=multithreaded)
31+
32+
slope_mu = get_parameter_value((approx, :slope_mu))
33+
slope_log_std = get_parameter_value((approx, :slope_log_std))
34+
intercept_mu = get_parameter_value((approx, :intercept_mu))
35+
intercept_log_std = get_parameter_value((approx, :intercept_log_std))
36+
@test isapprox(slope_mu, -1., atol=0.001)
37+
@test isapprox(slope_log_std, 0.5, atol=0.001)
38+
@test isapprox(intercept_mu, 1., atol=0.001)
39+
@test isapprox(intercept_log_std, 2.0, atol=0.001)
40+
end
3941

4042
# smoke test for black box variational inference with Monte Carlo objectives
41-
init_parameter!((approx, :slope_mu), 0.0)
42-
init_parameter!((approx, :slope_log_std), 0.0)
43-
init_parameter!((approx, :intercept_mu), 0.0)
44-
init_parameter!((approx, :intercept_log_std), 0.0)
45-
black_box_vimco!(model, (), observations, approx, (), optimizer, 20;
46-
iters=50, samples_per_iter=100, verbose=false, geometric=false)
47-
48-
init_parameter!((approx, :slope_mu), 0.0)
49-
init_parameter!((approx, :slope_log_std), 0.0)
50-
init_parameter!((approx, :intercept_mu), 0.0)
51-
init_parameter!((approx, :intercept_log_std), 0.0)
52-
black_box_vimco!(model, (), observations, approx, (), optimizer, 20;
53-
iters=50, samples_per_iter=100, verbose=false, geometric=true)
43+
for multithreaded in [false, true]
44+
init_parameter!((approx, :slope_mu), 0.0)
45+
init_parameter!((approx, :slope_log_std), 0.0)
46+
init_parameter!((approx, :intercept_mu), 0.0)
47+
init_parameter!((approx, :intercept_log_std), 0.0)
48+
black_box_vimco!(model, (), observations, approx, (), optimizer, 20;
49+
iters=50, samples_per_iter=100, verbose=false, geometric=false,
50+
multithreaded=multithreaded)
51+
52+
init_parameter!((approx, :slope_mu), 0.0)
53+
init_parameter!((approx, :slope_log_std), 0.0)
54+
init_parameter!((approx, :intercept_mu), 0.0)
55+
init_parameter!((approx, :intercept_log_std), 0.0)
56+
black_box_vimco!(model, (), observations, approx, (), optimizer, 20;
57+
iters=50, samples_per_iter=100, verbose=false, geometric=true,
58+
multithreaded=multithreaded)
59+
end
5460

5561
end
5662

0 commit comments

Comments
 (0)