@@ -16,7 +16,7 @@ function single_sample_gradient_estimate!(
     accumulate_param_gradients!(var_trace, nothing, log_weight * scale_factor)
 
     # unbiased estimate of objective function, and trace
-    (log_weight, var_trace, model_trace)
+    return (log_weight, var_trace, model_trace)
 end
 
 function vimco_geometric_baselines(log_weights)
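For orientation, the scaling in `single_sample_gradient_estimate!` is the standard score-function (REINFORCE) form of the ELBO gradient: with z ~ q_φ and log weight log p(x, z) - log q_φ(z),

    ∇_φ ELBO = E_{q_φ}[ (log p(x, z) - log q_φ(z)) ∇_φ log q_φ(z) ]

so each call to `accumulate_param_gradients!` above adds `log_weight * scale_factor` times the score ∇_φ log q_φ(z) into the parameter gradient accumulators, and the `1/samples_per_iter` scale factor used by the callers below averages these contributions.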
@@ -29,12 +29,12 @@ function vimco_geometric_baselines(log_weights)
         baselines[i] = logsumexp(log_weights) - log(num_samples)
         log_weights[i] = temp
     end
-    baselines
+    return baselines
 end
 
 function logdiffexp(x, y)
     m = max(x, y)
-    m + log(exp(x - m) - exp(y - m))
+    return m + log(exp(x - m) - exp(y - m))
 end
 
 function vimco_arithmetic_baselines(log_weights)
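`logdiffexp` above computes log(exp(x) - exp(y)) without overflow by factoring out m = max(x, y); it assumes x ≥ y, since otherwise the difference is negative and the log is undefined. A minimal standalone check (the helper is restated so the snippet runs on its own):

    logdiffexp(x, y) = (m = max(x, y); m + log(exp(x - m) - exp(y - m)))

    # naive log(exp(1000) - exp(999)) overflows to Inf, but:
    logdiffexp(1000.0, 999.0)             # ≈ 999.5413
    1000.0 + log1p(-exp(999.0 - 1000.0))  # same value via a closed form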
@@ -46,7 +46,7 @@ function vimco_arithmetic_baselines(log_weights)
         log_f_hat = log_sum_f_without_i - log(num_samples - 1)
         baselines[i] = logsumexp(log_sum_f_without_i, log_f_hat) - log(num_samples)
     end
-    baselines
+    return baselines
 end
 
 # black box, VIMCO gradient estimator
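Both baseline routines implement VIMCO's leave-one-out control variates: for each sample i, its log weight is temporarily replaced by an average built from the other samples (the replacement itself is computed above these hunks), and the baseline is the resulting objective estimate. A standalone sketch of the geometric variant on a toy weight vector, assuming the replacement is the mean of the other log weights (i.e. the log of their geometric mean), with `logsumexp` restated inline:

    logsumexp(xs) = (m = maximum(xs); m + log(sum(exp.(xs .- m))))

    lw = [0.0, -1.0, -2.0]
    K = length(lw)
    baselines = similar(lw)
    for i in 1:K
        replaced = copy(lw)
        replaced[i] = (sum(lw) - lw[i]) / (K - 1)  # log of the geometric mean of the others
        baselines[i] = logsumexp(replaced) - log(K)
    end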
@@ -85,7 +85,7 @@ function multi_sample_gradient_estimate!(
 
     # collection of traces and normalized importance weights, and estimate of
     # objective function
-    (L, traces, weights_normalized)
+    return (L, traces, weights_normalized)
 end
 
 function _maybe_accumulate_param_grad!(trace, optimizer, scale_factor::Real)
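The body of `_maybe_accumulate_param_grad!` lies outside these hunks; given the `Union{CompositeOptimizer,Nothing}` model-optimizer arguments below, its likely shape is a guard that skips accumulation when no optimizer was supplied. A guess at that shape, for orientation only (not the committed code):

    function _maybe_accumulate_param_grad!(trace, optimizer, scale_factor::Real)
        if optimizer !== nothing
            # only accumulate model-parameter gradients if an optimizer will consume them
            accumulate_param_gradients!(trace, nothing, scale_factor)
        end
        return nothing
    end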
@@ -117,6 +117,7 @@ update the parameters of `model`.
 - `callback`: Callback function that takes `(iter, traces, elbo_estimate)`
   as input, where `iter` is the iteration number and `traces` are samples
   from `var_model` for that iteration.
+- `multithreaded`: if `true`, gradient estimation may use multiple threads.
 """
 function black_box_vi!(
         model::GenerativeFunction, model_args::Tuple,
@@ -125,31 +126,32 @@ function black_box_vi!(
         var_model::GenerativeFunction, var_model_args::Tuple,
         var_model_optimizer;
         iters=1000, samples_per_iter=100, verbose=false,
-        callback=(iter, traces, elbo_estimate) -> nothing)
+        callback=(iter, traces, elbo_estimate) -> nothing,
+        multithreaded=false)
 
     var_traces = Vector{Any}(undef, samples_per_iter)
     model_traces = Vector{Any}(undef, samples_per_iter)
+    log_weights = Vector{Float64}(undef, samples_per_iter)
     elbo_history = Vector{Float64}(undef, iters)
     for iter=1:iters
 
         # compute gradient estimate and objective function estimate
-        elbo_estimate = 0.0
-        # TODO multithread (note that this would require accumulate_param_gradients! to be threadsafe)
-        for sample=1:samples_per_iter
-
-            # accumulate the variational family gradients
-            (log_weight, var_trace, model_trace) = single_sample_gradient_estimate!(
-                var_model, var_model_args,
-                model, model_args, observations, 1/samples_per_iter)
-            elbo_estimate += (log_weight / samples_per_iter)
-
-            # accumulate the generative model gradients
-            _maybe_accumulate_param_grad!(model_trace, model_optimizer, 1.0 / samples_per_iter)
-
-            # record the traces
-            var_traces[sample] = var_trace
-            model_traces[sample] = model_trace
+        if multithreaded
+            Threads.@threads for i in 1:samples_per_iter
+                black_box_vi_iter!(
+                    var_traces, model_traces, log_weights, i, samples_per_iter,
+                    var_model, var_model_args,
+                    model, model_args, observations, model_optimizer)
+            end
+        else
+            for i in 1:samples_per_iter
+                black_box_vi_iter!(
+                    var_traces, model_traces, log_weights, i, samples_per_iter,
+                    var_model, var_model_args,
+                    model, model_args, observations, model_optimizer)
+            end
         end
+        elbo_estimate = sum(log_weights)
         elbo_history[iter] = elbo_estimate
 
         # print it
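A hypothetical call showing the new keyword in context. The model, variational family, and optimizer below are placeholders, and the `CompositeOptimizer` constructor shape is assumed (adapt it to the Gen version in use); start Julia with multiple threads (e.g. `julia -t 8`) for `multithreaded=true` to have any effect:

    using Gen

    @gen function model()
        x ~ normal(0.0, 1.0)
        y ~ normal(x, 0.1)
    end

    @gen function var_model()
        @param x_mu::Float64
        @param x_log_std::Float64
        x ~ normal(x_mu, exp(x_log_std))
    end

    init_param!(var_model, :x_mu, 0.0)
    init_param!(var_model, :x_log_std, 0.0)

    observations = choicemap((:y, 1.2))
    var_model_optimizer = CompositeOptimizer(FixedStepGradientDescent(0.001), var_model)

    (elbo, var_traces, elbo_history, _) = black_box_vi!(
        model, (), nothing, observations,   # `nothing`: no model optimizer, model params stay fixed
        var_model, (), var_model_optimizer;
        iters=200, samples_per_iter=16, multithreaded=true)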
@@ -167,9 +169,34 @@ function black_box_vi!(
         end
     end
 
-    (elbo_history[end], var_traces, elbo_history, model_traces)
+    return (elbo_history[end], var_traces, elbo_history, model_traces)
+end
+
+function black_box_vi_iter!(
+        var_traces::Vector, model_traces::Vector, log_weights::Vector{Float64},
+        i::Int, n::Int,
+        var_model::GenerativeFunction, var_model_args::Tuple,
+        model::GenerativeFunction, model_args::Tuple,
+        observations::ChoiceMap,
+        model_optimizer)
+
+    # accumulate the variational family gradients
+    (log_weight, var_trace, model_trace) = single_sample_gradient_estimate!(
+        var_model, var_model_args,
+        model, model_args, observations, 1.0 / n)
+    log_weights[i] = log_weight / n
+
+    # accumulate the generative model gradients
+    _maybe_accumulate_param_grad!(model_trace, model_optimizer, 1.0 / n)
+
+    # record the traces
+    var_traces[i] = var_trace
+    model_traces[i] = model_trace
+
+    return nothing
 end
 
+
 black_box_vi!(model::GenerativeFunction, model_args::Tuple,
     observations::ChoiceMap,
     var_model::GenerativeFunction, var_model_args::Tuple,
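`Threads.@threads` runs `black_box_vi_iter!` concurrently, so this relies on `accumulate_param_gradients!` being safe to call from multiple threads; the TODO removed above noted exactly that requirement. If a given Gen version did not guarantee it, one caller-side workaround would be to serialize the accumulation behind a lock, along these lines (hypothetical guard, not part of this commit):

    const GRAD_LOCK = ReentrantLock()

    # serialize gradient accumulation while keeping trace generation parallel
    function guarded_accumulate!(trace, scale_factor)
        lock(GRAD_LOCK) do
            accumulate_param_gradients!(trace, nothing, scale_factor)
        end
    end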
@@ -205,42 +232,45 @@ update the parameters of `model`.
 - `callback`: Callback function that takes `(iter, traces, elbo_estimate)`
   as input, where `iter` is the iteration number and `traces` are samples
   from `var_model` for that iteration.
+- `multithreaded`: if `true`, gradient estimation may use multiple threads.
 """
 function black_box_vimco!(
         model::GenerativeFunction, model_args::Tuple,
         model_optimizer::Union{CompositeOptimizer,Nothing}, observations::ChoiceMap,
         var_model::GenerativeFunction, var_model_args::Tuple,
         var_model_optimizer::CompositeOptimizer, grad_est_samples::Int;
         iters=1000, samples_per_iter=100, geometric=true, verbose=false,
-        callback=(iter, traces, elbo_estimate) -> nothing)
+        callback=(iter, traces, elbo_estimate) -> nothing,
+        multithreaded=false)
 
     resampled_var_traces = Vector{Any}(undef, samples_per_iter)
    model_traces = Vector{Any}(undef, samples_per_iter)
+    log_weights = Vector{Float64}(undef, samples_per_iter)
 
     iwelbo_history = Vector{Float64}(undef, iters)
     for iter=1:iters
 
         # compute gradient estimate and objective function estimate
-        iwelbo_estimate = 0.
-        for sample=1:samples_per_iter
-
-            # accumulate the variational family gradients
-            (est, original_var_traces, weights) = multi_sample_gradient_estimate!(
-                var_model, var_model_args,
-                model, model_args, observations, grad_est_samples,
-                1/samples_per_iter, geometric)
-            iwelbo_estimate += (est / samples_per_iter)
-
-            # record a variational trace obtained by resampling from the weighted collection
-            resampled_var_traces[sample] = original_var_traces[categorical(weights)]
-
-            # accumulate the generative model gradient estimator
-            for (var_trace, weight) in zip(original_var_traces, weights)
-                constraints = merge(observations, get_choices(var_trace))
-                (model_trace, _) = generate(model, model_args, constraints)
-                _maybe_accumulate_param_grad!(model_trace, model_optimizer, weight / samples_per_iter)
+        if multithreaded
+            Threads.@threads for i in 1:samples_per_iter
+                black_box_vimco_iter!(
+                    resampled_var_traces, log_weights,
+                    i, samples_per_iter,
+                    var_model, var_model_args, model, model_args,
+                    observations, geometric, grad_est_samples,
+                    model_optimizer)
+            end
+        else
+            for i in 1:samples_per_iter
+                black_box_vimco_iter!(
+                    resampled_var_traces, log_weights,
+                    i, samples_per_iter,
+                    var_model, var_model_args, model, model_args,
+                    observations, geometric, grad_est_samples,
+                    model_optimizer)
             end
         end
+        iwelbo_estimate = sum(log_weights)
         iwelbo_history[iter] = iwelbo_estimate
 
         # print it
@@ -262,6 +292,34 @@ function black_box_vimco!(
     (iwelbo_history[end], resampled_var_traces, iwelbo_history, model_traces)
 end
 
+function black_box_vimco_iter!(
+        resampled_var_traces::Vector, log_weights::Vector{Float64},
+        i::Int, samples_per_iter::Int,
+        var_model::GenerativeFunction, var_model_args::Tuple,
+        model::GenerativeFunction, model_args::Tuple,
+        observations::ChoiceMap, geometric::Bool, grad_est_samples::Int,
+        model_optimizer)
+
+    # accumulate the variational family gradients
+    (est, original_var_traces, weights) = multi_sample_gradient_estimate!(
+        var_model, var_model_args,
+        model, model_args, observations, grad_est_samples,
+        1/samples_per_iter, geometric)
+    log_weights[i] = est / samples_per_iter
+
+    # record a variational trace obtained by resampling from the weighted collection
+    resampled_var_traces[i] = original_var_traces[categorical(weights)]
+
+    # accumulate the generative model gradient estimator
+    for (var_trace, weight) in zip(original_var_traces, weights)
+        constraints = merge(observations, get_choices(var_trace))
+        (model_trace, _) = generate(model, model_args, constraints)
+        _maybe_accumulate_param_grad!(model_trace, model_optimizer, weight / samples_per_iter)
+    end
+
+    return nothing
+end
+
 black_box_vimco!(model::GenerativeFunction, model_args::Tuple,
     observations::ChoiceMap,
     var_model::GenerativeFunction, var_model_args::Tuple,
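By analogy with the `black_box_vi!` convenience method above, this method presumably forwards to the full `black_box_vimco!` with no model optimizer. A hypothetical call of the full method, reusing the placeholder model, variational family, and optimizer from the earlier sketch (`10` is `grad_est_samples`, the number of importance samples per VIMCO estimate; `geometric=false` selects the arithmetic baselines):

    (iwelbo, var_traces, iwelbo_history, _) = black_box_vimco!(
        model, (), nothing, observations,
        var_model, (), var_model_optimizer, 10;
        iters=200, samples_per_iter=16, geometric=false, multithreaded=true)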